Compare commits
4 Commits
v0.5.4_dev
...
0.5.3rc0
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7993ed8ddd | ||
|
|
443a1b4ab3 | ||
|
|
852a49c5cc | ||
|
|
8f7453e3af |
@@ -57,7 +57,7 @@ dependencies = [
|
|||||||
"uvicorn",
|
"uvicorn",
|
||||||
"uvloop",
|
"uvloop",
|
||||||
"xgrammar==0.1.24",
|
"xgrammar==0.1.24",
|
||||||
"sgl-kernel==0.3.13",
|
"sgl-kernel==0.3.11",
|
||||||
"torch==2.8.0",
|
"torch==2.8.0",
|
||||||
"torchaudio==2.8.0",
|
"torchaudio==2.8.0",
|
||||||
"torchvision",
|
"torchvision",
|
||||||
@@ -67,7 +67,7 @@ dependencies = [
|
|||||||
"tiktoken",
|
"tiktoken",
|
||||||
"anthropic>=0.20.0",
|
"anthropic>=0.20.0",
|
||||||
"torch_memory_saver==0.0.8",
|
"torch_memory_saver==0.0.8",
|
||||||
"nvidia-cutlass-dsl==4.2.1",
|
"nvidia-cutlass-dsl==4.2.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
@@ -103,8 +103,8 @@ dev = ["sglang[test]", "sglang[decord]"]
|
|||||||
"srt/layers/moe/fused_moe_triton/configs/*/*.json",
|
"srt/layers/moe/fused_moe_triton/configs/*/*.json",
|
||||||
"srt/layers/quantization/configs/*.json",
|
"srt/layers/quantization/configs/*.json",
|
||||||
"srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp",
|
"srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp",
|
||||||
"srt/speculative/cpp_ngram/*.cpp",
|
"srt/speculative/cpp_lookahead/*.cpp",
|
||||||
"srt/speculative/cpp_ngram/*.h",
|
"srt/speculative/cpp_lookahead/*.h",
|
||||||
]
|
]
|
||||||
|
|
||||||
[tool.setuptools.packages.find]
|
[tool.setuptools.packages.find]
|
||||||
|
|||||||
@@ -65,30 +65,29 @@ tracing = [
|
|||||||
|
|
||||||
srt = [
|
srt = [
|
||||||
"sglang[runtime_common]",
|
"sglang[runtime_common]",
|
||||||
"sgl-kernel==0.3.13",
|
"sgl-kernel==0.3.11",
|
||||||
"torch==2.8.0",
|
"torch==2.8.0",
|
||||||
"torchaudio==2.8.0",
|
"torchaudio==2.8.0",
|
||||||
"torchvision",
|
"torchvision",
|
||||||
"cuda-python",
|
"cuda-python",
|
||||||
"flashinfer_python==0.4.0rc1",
|
"flashinfer_python==0.3.1",
|
||||||
]
|
]
|
||||||
|
|
||||||
blackwell = [
|
blackwell = [
|
||||||
"sglang[runtime_common]",
|
"sglang[runtime_common]",
|
||||||
"sgl-kernel==0.3.13",
|
"sgl-kernel==0.3.11",
|
||||||
"torch==2.8.0",
|
"torch==2.8.0",
|
||||||
"torchaudio==2.8.0",
|
"torchaudio==2.8.0",
|
||||||
"torchvision",
|
"torchvision",
|
||||||
"cuda-python",
|
"cuda-python",
|
||||||
"flashinfer_python==0.4.0rc1",
|
"flashinfer_python==0.3.1",
|
||||||
"nvidia-cutlass-dsl==4.2.1",
|
"nvidia-cutlass-dsl==4.2.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
# HIP (Heterogeneous-computing Interface for Portability) for AMD
|
# HIP (Heterogeneous-computing Interface for Portability) for AMD
|
||||||
# => base docker rocm/vllm-dev:20250114, not from public vllm whl
|
# => base docker rocm/vllm-dev:20250114, not from public vllm whl
|
||||||
srt_hip = [
|
srt_hip = [
|
||||||
"sglang[runtime_common]",
|
"sglang[runtime_common]",
|
||||||
"torch",
|
|
||||||
"petit_kernel==0.0.2",
|
"petit_kernel==0.0.2",
|
||||||
"wave-lang==3.7.0",
|
"wave-lang==3.7.0",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -443,9 +443,11 @@ def latency_test_run_once(
|
|||||||
|
|
||||||
if profile:
|
if profile:
|
||||||
profiler.stop()
|
profiler.stop()
|
||||||
trace_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}_prefill.trace.json.gz"
|
profile_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}_prefill.trace.json.gz"
|
||||||
_save_profile_trace_results(profiler, trace_filename)
|
_save_profile_trace_results(profiler, profile_filename)
|
||||||
rank_print(f"torch profiler chrome trace for prefill saved to {trace_filename}")
|
rank_print(
|
||||||
|
f"torch profiler chrome trace for prefill saved to {profile_filename}"
|
||||||
|
)
|
||||||
|
|
||||||
# Decode
|
# Decode
|
||||||
decode_latencies = []
|
decode_latencies = []
|
||||||
@@ -477,10 +479,10 @@ def latency_test_run_once(
|
|||||||
|
|
||||||
if profile and i == output_len / 2:
|
if profile and i == output_len / 2:
|
||||||
profiler.stop()
|
profiler.stop()
|
||||||
trace_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}_decode.trace.json.gz"
|
profile_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}_decode.trace.json.gz"
|
||||||
_save_profile_trace_results(profiler, trace_filename)
|
_save_profile_trace_results(profiler, profile_filename)
|
||||||
rank_print(
|
rank_print(
|
||||||
f"torch profiler chrome trace for decoding 1 token saved to {trace_filename}"
|
f"torch profiler chrome trace for decoding 1 token saved to {profile_filename}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Record decode timing from 2nd output
|
# Record decode timing from 2nd output
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ python3 -m sglang.bench_one_batch_server --model meta-llama/Meta-Llama-3.1-8B --
|
|||||||
|
|
||||||
python3 -m sglang.bench_one_batch_server --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8
|
python3 -m sglang.bench_one_batch_server --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8
|
||||||
python3 -m sglang.bench_one_batch_server --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8 --show-report --profile --profile-by-stage
|
python3 -m sglang.bench_one_batch_server --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8 --show-report --profile --profile-by-stage
|
||||||
python3 -m sglang.bench_one_batch_server --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8 --output-path results.json --profile
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
@@ -20,17 +19,12 @@ import multiprocessing
|
|||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import requests
|
import requests
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from sglang.bench_serving import (
|
from sglang.bench_serving import get_tokenizer, sample_random_requests
|
||||||
get_tokenizer,
|
|
||||||
sample_mmmu_requests,
|
|
||||||
sample_random_requests,
|
|
||||||
)
|
|
||||||
from sglang.profiler import run_profile
|
from sglang.profiler import run_profile
|
||||||
from sglang.srt.entrypoints.http_server import launch_server
|
from sglang.srt.entrypoints.http_server import launch_server
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
@@ -38,108 +32,6 @@ from sglang.srt.utils import is_blackwell, kill_process_tree
|
|||||||
from sglang.test.test_utils import is_in_ci, write_github_step_summary
|
from sglang.test.test_utils import is_in_ci, write_github_step_summary
|
||||||
|
|
||||||
|
|
||||||
class ProfileLinks(BaseModel):
|
|
||||||
"""Pydantic model for profile trace links."""
|
|
||||||
|
|
||||||
extend: Optional[str] = None
|
|
||||||
decode: Optional[str] = None
|
|
||||||
|
|
||||||
|
|
||||||
class BenchmarkResult(BaseModel):
|
|
||||||
"""Pydantic model for benchmark results table data, for a single isl and osl"""
|
|
||||||
|
|
||||||
model_path: str
|
|
||||||
run_name: str
|
|
||||||
batch_size: int
|
|
||||||
input_len: int
|
|
||||||
output_len: int
|
|
||||||
latency: float
|
|
||||||
ttft: float
|
|
||||||
input_throughput: float
|
|
||||||
output_throughput: float
|
|
||||||
overall_throughput: float
|
|
||||||
last_gen_throughput: float
|
|
||||||
acc_length: Optional[float] = None
|
|
||||||
profile_links: Optional[ProfileLinks] = None
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def help_str() -> str:
|
|
||||||
return f"""
|
|
||||||
Note: To view the traces through perfetto-ui, please:
|
|
||||||
1. open with Google Chrome
|
|
||||||
2. allow popup
|
|
||||||
"""
|
|
||||||
|
|
||||||
def to_markdown_row(
|
|
||||||
self, trace_dir, base_url: str = "", relay_base: str = ""
|
|
||||||
) -> str:
|
|
||||||
"""Convert this benchmark result to a markdown table row."""
|
|
||||||
# Calculate costs (assuming H100 pricing for now)
|
|
||||||
hourly_cost_per_gpu = 2 # $2/hour for one H100
|
|
||||||
hourly_cost = hourly_cost_per_gpu * 1 # Assuming tp_size = 1 for simplicity
|
|
||||||
input_util = 0.7
|
|
||||||
accept_length = (
|
|
||||||
round(self.acc_length, 2) if self.acc_length is not None else "n/a"
|
|
||||||
)
|
|
||||||
itl = 1 / (self.output_throughput / self.batch_size) * 1000
|
|
||||||
input_cost = 1e6 / (self.input_throughput * input_util) / 3600 * hourly_cost
|
|
||||||
output_cost = 1e6 / self.output_throughput / 3600 * hourly_cost
|
|
||||||
|
|
||||||
def get_perfetto_relay_link_from_trace_file(trace_file: str):
|
|
||||||
import os
|
|
||||||
from urllib.parse import quote
|
|
||||||
|
|
||||||
rel_path = os.path.relpath(trace_file, trace_dir)
|
|
||||||
raw_file_link = f"{base_url}/{rel_path}"
|
|
||||||
relay_link = (
|
|
||||||
f"{relay_base}?src={quote(raw_file_link, safe='')}"
|
|
||||||
if relay_base and quote
|
|
||||||
else raw_file_link
|
|
||||||
)
|
|
||||||
return relay_link
|
|
||||||
|
|
||||||
# Handle profile links
|
|
||||||
profile_link = "NA | NA"
|
|
||||||
if self.profile_links:
|
|
||||||
if self.profile_links.extend or self.profile_links.decode:
|
|
||||||
# Create a combined link or use the first available one
|
|
||||||
trace_files = [self.profile_links.extend, self.profile_links.decode]
|
|
||||||
trace_files_relay_links = [
|
|
||||||
f"[trace]({get_perfetto_relay_link_from_trace_file(trace_file)})"
|
|
||||||
for trace_file in trace_files
|
|
||||||
]
|
|
||||||
|
|
||||||
profile_link = " | ".join(trace_files_relay_links)
|
|
||||||
|
|
||||||
# Build the row
|
|
||||||
return f"| {self.batch_size} | {self.input_len} | {self.latency:.2f} | {self.input_throughput:.2f} | {self.output_throughput:.2f} | {accept_length} | {itl:.2f} | {input_cost:.2f} | {output_cost:.2f} | {profile_link} |\n"
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def generate_markdown_report(
|
|
||||||
cls, trace_dir, results: List["BenchmarkResult"]
|
|
||||||
) -> str:
|
|
||||||
"""Generate a markdown report from a list of BenchmarkResult object from a single run."""
|
|
||||||
import os
|
|
||||||
|
|
||||||
summary = f"### {results[0].model_path}\n"
|
|
||||||
|
|
||||||
# summary += (
|
|
||||||
# f"Input lens: {result.input_len}. Output lens: {result.output_len}.\n"
|
|
||||||
# )
|
|
||||||
summary += "| batch size | input len | latency (s) | input throughput (tok/s) | output throughput (tok/s) | acc length | ITL (ms) | input cost ($/1M) | output cost ($/1M) | profile (extend) | profile (decode)|\n"
|
|
||||||
summary += "| ---------- | --------- | ----------- | ------------------------- | ------------------------- | ---------- | -------- | ----------------- | ------------------ | --------------- | -------------- |\n"
|
|
||||||
|
|
||||||
# all results should share the same isl & osl
|
|
||||||
for result in results:
|
|
||||||
base_url = os.getenv("TRACE_BASE_URL", "").rstrip("/")
|
|
||||||
relay_base = os.getenv("PERFETTO_RELAY_URL", "").rstrip("/")
|
|
||||||
relay_base = "https://docs.sglang.ai/ci-data/pages/perfetto_relay.html"
|
|
||||||
# base_url = "https://github.com/sgl-project/ci-data/traces"
|
|
||||||
summary += result.to_markdown_row(trace_dir, base_url, relay_base)
|
|
||||||
|
|
||||||
return summary
|
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class BenchArgs:
|
class BenchArgs:
|
||||||
run_name: str = "default"
|
run_name: str = "default"
|
||||||
@@ -158,12 +50,8 @@ class BenchArgs:
|
|||||||
profile: bool = False
|
profile: bool = False
|
||||||
profile_steps: int = 3
|
profile_steps: int = 3
|
||||||
profile_by_stage: bool = False
|
profile_by_stage: bool = False
|
||||||
profile_filename_prefix: str = None
|
|
||||||
append_to_github_summary: bool = True
|
|
||||||
dataset_path: str = ""
|
dataset_path: str = ""
|
||||||
parallel_batch: bool = False
|
parallel_batch: bool = False
|
||||||
dataset_name: str = "random"
|
|
||||||
output_path: Optional[str] = None
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def add_cli_args(parser: argparse.ArgumentParser):
|
def add_cli_args(parser: argparse.ArgumentParser):
|
||||||
@@ -179,13 +67,6 @@ class BenchArgs:
|
|||||||
"--output-len", type=int, nargs="+", default=BenchArgs.output_len
|
"--output-len", type=int, nargs="+", default=BenchArgs.output_len
|
||||||
)
|
)
|
||||||
parser.add_argument("--temperature", type=float, default=BenchArgs.temperature)
|
parser.add_argument("--temperature", type=float, default=BenchArgs.temperature)
|
||||||
parser.add_argument(
|
|
||||||
"--dataset-name",
|
|
||||||
type=str,
|
|
||||||
default=BenchArgs.dataset_name,
|
|
||||||
choices=["mmmu", "random"],
|
|
||||||
help="Name of the dataset to benchmark on.",
|
|
||||||
)
|
|
||||||
parser.add_argument("--return-logprob", action="store_true")
|
parser.add_argument("--return-logprob", action="store_true")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--client-stream-interval",
|
"--client-stream-interval",
|
||||||
@@ -215,36 +96,14 @@ class BenchArgs:
|
|||||||
help="Path to the dataset.",
|
help="Path to the dataset.",
|
||||||
)
|
)
|
||||||
parser.add_argument("--parallel-batch", action="store_true")
|
parser.add_argument("--parallel-batch", action="store_true")
|
||||||
parser.add_argument(
|
|
||||||
"--profile-filename-prefix",
|
|
||||||
type=str,
|
|
||||||
default=BenchArgs.profile_filename_prefix,
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--no-append-to-github-summary",
|
|
||||||
action="store_false",
|
|
||||||
dest="append_to_github_summary",
|
|
||||||
help="Disable appending the output of this run to github ci summary",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--output-path",
|
|
||||||
type=str,
|
|
||||||
default=BenchArgs.output_path,
|
|
||||||
help="Path to save benchmark results as JSON format. If not specified, results will only be saved to result-filename.",
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_cli_args(cls, args: argparse.Namespace):
|
def from_cli_args(cls, args: argparse.Namespace):
|
||||||
# use the default value's type to cast the args into correct types.
|
# use the default value's type to cast the args into correct types.
|
||||||
attrs = [(attr.name, type(attr.default)) for attr in dataclasses.fields(cls)]
|
attrs = [(attr.name, type(attr.default)) for attr in dataclasses.fields(cls)]
|
||||||
kwargs = {}
|
return cls(
|
||||||
for attr, attr_type in attrs:
|
**{attr: attr_type(getattr(args, attr)) for attr, attr_type in attrs}
|
||||||
val = getattr(args, attr)
|
)
|
||||||
if attr_type is type(None):
|
|
||||||
kwargs[attr] = val
|
|
||||||
else:
|
|
||||||
kwargs[attr] = attr_type(val)
|
|
||||||
return cls(**kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def launch_server_internal(server_args):
|
def launch_server_internal(server_args):
|
||||||
@@ -289,35 +148,23 @@ def run_one_case(
|
|||||||
run_name: str,
|
run_name: str,
|
||||||
result_filename: str,
|
result_filename: str,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
dataset_name="",
|
|
||||||
profile: bool = False,
|
profile: bool = False,
|
||||||
profile_steps: int = 3,
|
profile_steps: int = 3,
|
||||||
profile_by_stage: bool = False,
|
profile_by_stage: bool = False,
|
||||||
profile_filename_prefix: str = None,
|
|
||||||
dataset_path: str = "",
|
dataset_path: str = "",
|
||||||
parallel_batch: bool = False,
|
parallel_batch: bool = False,
|
||||||
):
|
):
|
||||||
requests.post(url + "/flush_cache")
|
requests.post(url + "/flush_cache")
|
||||||
# TODO: reuse bench_serving.get_dataset ?
|
input_requests = sample_random_requests(
|
||||||
if dataset_name == "mmmu":
|
input_len=input_len,
|
||||||
input_requests = sample_mmmu_requests(
|
output_len=output_len,
|
||||||
num_requests=batch_size,
|
num_prompts=batch_size,
|
||||||
tokenizer=tokenizer,
|
range_ratio=1.0,
|
||||||
fixed_output_len=output_len,
|
tokenizer=tokenizer,
|
||||||
apply_chat_template=True,
|
dataset_path=dataset_path,
|
||||||
random_sample=False,
|
random_sample=True,
|
||||||
)
|
return_text=False,
|
||||||
elif dataset_name == "random":
|
)
|
||||||
input_requests = sample_random_requests(
|
|
||||||
input_len=input_len,
|
|
||||||
output_len=output_len,
|
|
||||||
num_prompts=batch_size,
|
|
||||||
range_ratio=1.0,
|
|
||||||
tokenizer=tokenizer,
|
|
||||||
dataset_path=dataset_path,
|
|
||||||
random_sample=True,
|
|
||||||
return_text=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
use_structured_outputs = False
|
use_structured_outputs = False
|
||||||
if use_structured_outputs:
|
if use_structured_outputs:
|
||||||
@@ -334,48 +181,26 @@ def run_one_case(
|
|||||||
|
|
||||||
profile_link = None
|
profile_link = None
|
||||||
if profile:
|
if profile:
|
||||||
output_dir, profile_name = None, None
|
|
||||||
if profile_filename_prefix:
|
|
||||||
output_dir = os.path.dirname(profile_filename_prefix)
|
|
||||||
profile_name = os.path.basename(profile_filename_prefix)
|
|
||||||
profile_link: str = run_profile(
|
profile_link: str = run_profile(
|
||||||
url,
|
url, profile_steps, ["CPU", "GPU"], None, None, profile_by_stage
|
||||||
profile_steps,
|
|
||||||
["CPU", "GPU"],
|
|
||||||
output_dir,
|
|
||||||
profile_name,
|
|
||||||
profile_by_stage,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
tic = time.perf_counter()
|
tic = time.perf_counter()
|
||||||
|
|
||||||
payload = {
|
|
||||||
"sampling_params": {
|
|
||||||
"temperature": temperature,
|
|
||||||
"max_new_tokens": output_len,
|
|
||||||
"ignore_eos": True,
|
|
||||||
"json_schema": json_schema,
|
|
||||||
"stream_interval": stream_interval,
|
|
||||||
},
|
|
||||||
"return_logprob": return_logprob,
|
|
||||||
"stream": True,
|
|
||||||
**({"parallel_batch": parallel_batch} if parallel_batch else {}),
|
|
||||||
}
|
|
||||||
if dataset_name == "mmmu":
|
|
||||||
# vlm
|
|
||||||
input_ids = []
|
|
||||||
for input_req in input_requests:
|
|
||||||
input_ids += [tokenizer.encode(input_req.prompt)]
|
|
||||||
payload["image_data"] = [req.image_data for req in input_requests]
|
|
||||||
|
|
||||||
else:
|
|
||||||
input_ids = [req.prompt for req in input_requests]
|
|
||||||
|
|
||||||
payload["input_ids"] = input_ids
|
|
||||||
|
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
url + "/generate",
|
url + "/generate",
|
||||||
json=payload,
|
json={
|
||||||
|
"input_ids": [req.prompt for req in input_requests],
|
||||||
|
"sampling_params": {
|
||||||
|
"temperature": temperature,
|
||||||
|
"max_new_tokens": output_len,
|
||||||
|
"ignore_eos": True,
|
||||||
|
"json_schema": json_schema,
|
||||||
|
"stream_interval": stream_interval,
|
||||||
|
},
|
||||||
|
"return_logprob": return_logprob,
|
||||||
|
"stream": True,
|
||||||
|
**({"parallel_batch": parallel_batch} if parallel_batch else {}),
|
||||||
|
},
|
||||||
stream=True,
|
stream=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -439,100 +264,10 @@ def run_one_case(
|
|||||||
overall_throughput,
|
overall_throughput,
|
||||||
last_gen_throughput,
|
last_gen_throughput,
|
||||||
acc_length,
|
acc_length,
|
||||||
profile_link,
|
profile_link if profile else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def save_results_as_json(result: List[Tuple], bench_args: BenchArgs, model: str):
|
|
||||||
"""Save benchmark results as JSON using Pydantic models."""
|
|
||||||
json_results = []
|
|
||||||
|
|
||||||
# Generate all parameter combinations to match with results
|
|
||||||
param_combinations = list(
|
|
||||||
itertools.product(
|
|
||||||
bench_args.batch_size, bench_args.input_len, bench_args.output_len
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
for i, (
|
|
||||||
batch_size,
|
|
||||||
latency,
|
|
||||||
ttft,
|
|
||||||
input_throughput,
|
|
||||||
output_throughput,
|
|
||||||
overall_throughput,
|
|
||||||
last_gen_throughput,
|
|
||||||
acc_length,
|
|
||||||
profile_link,
|
|
||||||
) in enumerate(result):
|
|
||||||
# Get the corresponding parameters for this result
|
|
||||||
bs, input_len, output_len = param_combinations[i]
|
|
||||||
|
|
||||||
# Parse profile links if available
|
|
||||||
profile_links = None
|
|
||||||
if profile_link:
|
|
||||||
profile_links = parse_profile_links(
|
|
||||||
profile_link, batch_size, input_len, output_len
|
|
||||||
)
|
|
||||||
|
|
||||||
benchmark_result = BenchmarkResult(
|
|
||||||
model_path=model,
|
|
||||||
run_name=bench_args.run_name,
|
|
||||||
batch_size=batch_size,
|
|
||||||
input_len=input_len,
|
|
||||||
output_len=output_len,
|
|
||||||
latency=latency,
|
|
||||||
ttft=ttft,
|
|
||||||
input_throughput=input_throughput,
|
|
||||||
output_throughput=output_throughput,
|
|
||||||
overall_throughput=overall_throughput,
|
|
||||||
last_gen_throughput=last_gen_throughput,
|
|
||||||
acc_length=acc_length,
|
|
||||||
profile_links=profile_links,
|
|
||||||
)
|
|
||||||
json_results.append(benchmark_result.model_dump())
|
|
||||||
|
|
||||||
# Save to JSON file
|
|
||||||
with open(bench_args.output_path, "w", encoding="utf-8") as f:
|
|
||||||
json.dump(json_results, f, indent=2, ensure_ascii=False)
|
|
||||||
|
|
||||||
print(f"Results saved as JSON to {bench_args.output_path}")
|
|
||||||
|
|
||||||
|
|
||||||
def parse_profile_links(
|
|
||||||
profile_dir: str, batch_size: int, input_len: int, output_len: int
|
|
||||||
) -> Optional[ProfileLinks]:
|
|
||||||
"""Parse profile directory to extract extend and decode trace file links."""
|
|
||||||
if not profile_dir or not os.path.exists(profile_dir):
|
|
||||||
return None
|
|
||||||
|
|
||||||
extend_link = None
|
|
||||||
decode_link = None
|
|
||||||
|
|
||||||
# Look for extend/prefill trace files
|
|
||||||
for file in os.listdir(profile_dir):
|
|
||||||
if file.endswith(".trace.json.gz") or file.endswith(".trace.json"):
|
|
||||||
if "extend" in file.lower() or "prefill" in file.lower():
|
|
||||||
extend_link = os.path.join(profile_dir, file)
|
|
||||||
elif "decode" in file.lower():
|
|
||||||
decode_link = os.path.join(profile_dir, file)
|
|
||||||
|
|
||||||
# If no specific extend/decode files found, try to find files with batch/input/output info
|
|
||||||
if not extend_link or not decode_link:
|
|
||||||
for file in os.listdir(profile_dir):
|
|
||||||
if file.endswith(".trace.json.gz") or file.endswith(".trace.json"):
|
|
||||||
if f"_batch{batch_size}_input{input_len}_output{output_len}_" in file:
|
|
||||||
if "prefill" in file.lower() or "extend" in file.lower():
|
|
||||||
extend_link = os.path.join(profile_dir, file)
|
|
||||||
elif "decode" in file.lower():
|
|
||||||
decode_link = os.path.join(profile_dir, file)
|
|
||||||
|
|
||||||
if extend_link or decode_link:
|
|
||||||
return ProfileLinks(extend=extend_link, decode=decode_link)
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def get_report_summary(
|
def get_report_summary(
|
||||||
result: List[Tuple], server_args: ServerArgs, bench_args: BenchArgs
|
result: List[Tuple], server_args: ServerArgs, bench_args: BenchArgs
|
||||||
):
|
):
|
||||||
@@ -623,7 +358,6 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
|
|||||||
return_logprob=bench_args.return_logprob,
|
return_logprob=bench_args.return_logprob,
|
||||||
stream_interval=bench_args.client_stream_interval,
|
stream_interval=bench_args.client_stream_interval,
|
||||||
input_len_step_percentage=bench_args.input_len_step_percentage,
|
input_len_step_percentage=bench_args.input_len_step_percentage,
|
||||||
dataset_name=bench_args.dataset_name,
|
|
||||||
run_name="",
|
run_name="",
|
||||||
result_filename="",
|
result_filename="",
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
@@ -650,12 +384,10 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
|
|||||||
stream_interval=bench_args.client_stream_interval,
|
stream_interval=bench_args.client_stream_interval,
|
||||||
input_len_step_percentage=bench_args.input_len_step_percentage,
|
input_len_step_percentage=bench_args.input_len_step_percentage,
|
||||||
run_name=bench_args.run_name,
|
run_name=bench_args.run_name,
|
||||||
dataset_name=bench_args.dataset_name,
|
|
||||||
result_filename=bench_args.result_filename,
|
result_filename=bench_args.result_filename,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
dataset_path=bench_args.dataset_path,
|
dataset_path=bench_args.dataset_path,
|
||||||
parallel_batch=bench_args.parallel_batch,
|
parallel_batch=bench_args.parallel_batch,
|
||||||
profile_filename_prefix=bench_args.profile_filename_prefix,
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -678,13 +410,11 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
|
|||||||
run_name=bench_args.run_name,
|
run_name=bench_args.run_name,
|
||||||
result_filename=bench_args.result_filename,
|
result_filename=bench_args.result_filename,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
dataset_name=bench_args.dataset_name,
|
|
||||||
profile=bench_args.profile,
|
profile=bench_args.profile,
|
||||||
profile_steps=bench_args.profile_steps,
|
profile_steps=bench_args.profile_steps,
|
||||||
profile_by_stage=bench_args.profile_by_stage,
|
profile_by_stage=bench_args.profile_by_stage,
|
||||||
dataset_path=bench_args.dataset_path,
|
dataset_path=bench_args.dataset_path,
|
||||||
parallel_batch=bench_args.parallel_batch,
|
parallel_batch=bench_args.parallel_batch,
|
||||||
profile_filename_prefix=bench_args.profile_filename_prefix,
|
|
||||||
)[-1],
|
)[-1],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -697,16 +427,13 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
|
|||||||
|
|
||||||
print(f"\nResults are saved to {bench_args.result_filename}")
|
print(f"\nResults are saved to {bench_args.result_filename}")
|
||||||
|
|
||||||
# Save results as JSON if output_path is specified
|
|
||||||
if bench_args.output_path:
|
|
||||||
save_results_as_json(result, bench_args, model=server_args.model_path)
|
|
||||||
|
|
||||||
if not bench_args.show_report:
|
if not bench_args.show_report:
|
||||||
return
|
return
|
||||||
|
|
||||||
summary = get_report_summary(result, server_args, bench_args)
|
summary = get_report_summary(result, server_args, bench_args)
|
||||||
|
print(summary)
|
||||||
|
|
||||||
if is_in_ci() and bench_args.append_to_github_summary:
|
if is_in_ci():
|
||||||
write_github_step_summary(summary)
|
write_github_step_summary(summary)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -208,10 +208,6 @@ async def async_request_openai_completions(
|
|||||||
"ignore_eos": not args.disable_ignore_eos,
|
"ignore_eos": not args.disable_ignore_eos,
|
||||||
**request_func_input.extra_request_body,
|
**request_func_input.extra_request_body,
|
||||||
}
|
}
|
||||||
|
|
||||||
if request_func_input.image_data:
|
|
||||||
payload.update({"image_data": request_func_input.image_data})
|
|
||||||
|
|
||||||
headers = get_auth_headers()
|
headers = get_auth_headers()
|
||||||
|
|
||||||
output = RequestFuncOutput.init_new(request_func_input)
|
output = RequestFuncOutput.init_new(request_func_input)
|
||||||
@@ -1763,9 +1759,7 @@ async def benchmark(
|
|||||||
pbar.close()
|
pbar.close()
|
||||||
|
|
||||||
if "sglang" in backend:
|
if "sglang" in backend:
|
||||||
server_info = requests.get(
|
server_info = requests.get(base_url + "/get_server_info")
|
||||||
base_url + "/get_server_info", headers=get_auth_headers()
|
|
||||||
)
|
|
||||||
if server_info.status_code == 200:
|
if server_info.status_code == 200:
|
||||||
server_info_json = server_info.json()
|
server_info_json = server_info.json()
|
||||||
if "decode" in server_info_json:
|
if "decode" in server_info_json:
|
||||||
|
|||||||
@@ -124,8 +124,6 @@ class Envs:
|
|||||||
SGLANG_TEST_REQUEST_TIME_STATS = EnvBool(False)
|
SGLANG_TEST_REQUEST_TIME_STATS = EnvBool(False)
|
||||||
SGLANG_DISABLE_TP_MEMORY_INBALANCE_CHECK = EnvBool(False)
|
SGLANG_DISABLE_TP_MEMORY_INBALANCE_CHECK = EnvBool(False)
|
||||||
SGLANG_DISABLE_REQUEST_LOGGING = EnvBool(False)
|
SGLANG_DISABLE_REQUEST_LOGGING = EnvBool(False)
|
||||||
SGLANG_SIMULATE_ACC_LEN = EnvFloat(-1)
|
|
||||||
SGLANG_SIMULATE_ACC_METHOD = EnvStr("multinomial")
|
|
||||||
|
|
||||||
# Model Parallel
|
# Model Parallel
|
||||||
SGLANG_USE_MESSAGE_QUEUE_BROADCASTER = EnvBool(True)
|
SGLANG_USE_MESSAGE_QUEUE_BROADCASTER = EnvBool(True)
|
||||||
@@ -37,8 +37,8 @@ class GlobalConfig:
|
|||||||
)
|
)
|
||||||
# Runtime constants: others
|
# Runtime constants: others
|
||||||
self.retract_decode_steps = 20
|
self.retract_decode_steps = 20
|
||||||
self.flashinfer_workspace_size = int(
|
self.flashinfer_workspace_size = os.environ.get(
|
||||||
os.environ.get("FLASHINFER_WORKSPACE_SIZE", 384 * 1024 * 1024)
|
"FLASHINFER_WORKSPACE_SIZE", 384 * 1024 * 1024
|
||||||
)
|
)
|
||||||
|
|
||||||
# Output tokenization configs
|
# Output tokenization configs
|
||||||
|
|||||||
@@ -7,23 +7,9 @@ from sglang.srt.entrypoints.http_server import launch_server
|
|||||||
from sglang.srt.server_args import prepare_server_args
|
from sglang.srt.server_args import prepare_server_args
|
||||||
from sglang.srt.utils import kill_process_tree
|
from sglang.srt.utils import kill_process_tree
|
||||||
|
|
||||||
MOVE_ENVS_WARN = """
|
|
||||||
########################################################################
|
|
||||||
# For contributors and developers: #
|
|
||||||
# Please move environment variable definitions to sglang.srt.environ #
|
|
||||||
# using the following pattern: #
|
|
||||||
# SGLANG_XXX = EnvBool(False) #
|
|
||||||
# #
|
|
||||||
########################################################################
|
|
||||||
"""
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
server_args = prepare_server_args(sys.argv[1:])
|
server_args = prepare_server_args(sys.argv[1:])
|
||||||
|
|
||||||
from sglang.srt.server_args import print_deprecated_warning
|
|
||||||
|
|
||||||
print_deprecated_warning(MOVE_ENVS_WARN)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
launch_server(server_args)
|
launch_server(server_args)
|
||||||
finally:
|
finally:
|
||||||
|
|||||||
@@ -24,8 +24,6 @@ class LoadFormat(str, enum.Enum):
|
|||||||
JAX = "jax"
|
JAX = "jax"
|
||||||
REMOTE = "remote"
|
REMOTE = "remote"
|
||||||
REMOTE_INSTANCE = "remote_instance"
|
REMOTE_INSTANCE = "remote_instance"
|
||||||
RDMA = "rdma"
|
|
||||||
LOCAL_CACHED = "local_cached"
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -49,7 +47,6 @@ class LoadConfig:
|
|||||||
checkpoints.
|
checkpoints.
|
||||||
decryption_key_file: If set, decrypts the output files with a password read
|
decryption_key_file: If set, decrypts the output files with a password read
|
||||||
from this file (after PBKDF2).
|
from this file (after PBKDF2).
|
||||||
decrypt_max_concurrency: The maximum number of concurrent processes to decrypt the safetensor files. -1 means no limit.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
load_format: Union[str, LoadFormat] = LoadFormat.AUTO
|
load_format: Union[str, LoadFormat] = LoadFormat.AUTO
|
||||||
@@ -57,11 +54,6 @@ class LoadConfig:
|
|||||||
model_loader_extra_config: Optional[Union[str, dict]] = field(default_factory=dict)
|
model_loader_extra_config: Optional[Union[str, dict]] = field(default_factory=dict)
|
||||||
ignore_patterns: Optional[Union[List[str], str]] = None
|
ignore_patterns: Optional[Union[List[str], str]] = None
|
||||||
decryption_key_file: Optional[str] = None
|
decryption_key_file: Optional[str] = None
|
||||||
decrypt_max_concurrency: int = -1
|
|
||||||
tp_rank: Optional[int] = None
|
|
||||||
remote_instance_weight_loader_seed_instance_ip: Optional[str] = None
|
|
||||||
remote_instance_weight_loader_seed_instance_service_port: Optional[int] = None
|
|
||||||
remote_instance_weight_loader_send_weights_group_ports: Optional[List[int]] = None
|
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
model_loader_extra_config = self.model_loader_extra_config or {}
|
model_loader_extra_config = self.model_loader_extra_config or {}
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ from sglang.srt.hf_transformers_utils import (
|
|||||||
)
|
)
|
||||||
from sglang.srt.layers.quantization import QUANTIZATION_METHODS
|
from sglang.srt.layers.quantization import QUANTIZATION_METHODS
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
from sglang.srt.utils import get_bool_env_var, is_hip, retry
|
from sglang.srt.utils import get_bool_env_var, is_hip
|
||||||
from sglang.utils import is_in_ci
|
from sglang.utils import is_in_ci
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -48,6 +48,30 @@ class ModelImpl(str, Enum):
|
|||||||
TRANSFORMERS = "transformers"
|
TRANSFORMERS = "transformers"
|
||||||
|
|
||||||
|
|
||||||
|
def is_deepseek_nsa(config: PretrainedConfig) -> bool:
|
||||||
|
return (
|
||||||
|
config.architectures is not None
|
||||||
|
and config.architectures[0]
|
||||||
|
in ["DeepseekV3ForCausalLM", "DeepseekV32ForCausalLM"]
|
||||||
|
and getattr(config, "index_topk", None) is not None
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_nsa_index_head_dim(config: PretrainedConfig) -> int:
|
||||||
|
assert is_deepseek_nsa(config)
|
||||||
|
return config.index_head_dim
|
||||||
|
|
||||||
|
|
||||||
|
def get_nsa_index_topk(config: PretrainedConfig) -> int:
|
||||||
|
assert is_deepseek_nsa(config)
|
||||||
|
return config.index_topk
|
||||||
|
|
||||||
|
|
||||||
|
def get_nsa_index_n_heads(config: PretrainedConfig) -> int:
|
||||||
|
assert is_deepseek_nsa(config)
|
||||||
|
return config.index_n_heads
|
||||||
|
|
||||||
|
|
||||||
class ModelConfig:
|
class ModelConfig:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -64,20 +88,35 @@ class ModelConfig:
|
|||||||
is_draft_model: bool = False,
|
is_draft_model: bool = False,
|
||||||
hybrid_kvcache_ratio: Optional[float] = None,
|
hybrid_kvcache_ratio: Optional[float] = None,
|
||||||
model_impl: Union[str, ModelImpl] = ModelImpl.AUTO,
|
model_impl: Union[str, ModelImpl] = ModelImpl.AUTO,
|
||||||
|
tp_rank: Optional[int] = None,
|
||||||
|
remote_instance_weight_loader_seed_instance_ip: Optional[str] = None,
|
||||||
|
remote_instance_weight_loader_seed_instance_service_port: Optional[int] = None,
|
||||||
|
remote_instance_weight_loader_send_weights_group_ports: Optional[
|
||||||
|
List[int]
|
||||||
|
] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
# Parse args
|
# Parse args
|
||||||
self.model_path = model_path
|
self.model_path = model_path
|
||||||
self.revision = revision
|
self.revision = revision
|
||||||
self.quantization = quantization
|
self.quantization = quantization
|
||||||
self.is_draft_model = is_draft_model
|
|
||||||
self.model_impl = model_impl
|
self.model_impl = model_impl
|
||||||
|
self.tp_rank = tp_rank
|
||||||
|
self.remote_instance_weight_loader_seed_instance_ip = (
|
||||||
|
remote_instance_weight_loader_seed_instance_ip
|
||||||
|
)
|
||||||
|
self.remote_instance_weight_loader_seed_instance_service_port = (
|
||||||
|
remote_instance_weight_loader_seed_instance_service_port
|
||||||
|
)
|
||||||
|
self.remote_instance_weight_loader_send_weights_group_ports = (
|
||||||
|
remote_instance_weight_loader_send_weights_group_ports
|
||||||
|
)
|
||||||
|
|
||||||
# Get hf config
|
self.maybe_pull_model_tokenizer_from_remote()
|
||||||
self._maybe_pull_model_tokenizer_from_remote()
|
|
||||||
self.model_override_args = json.loads(model_override_args)
|
self.model_override_args = json.loads(model_override_args)
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
if override_config_file and override_config_file.strip():
|
if override_config_file and override_config_file.strip():
|
||||||
kwargs["_configuration_file"] = override_config_file.strip()
|
kwargs["_configuration_file"] = override_config_file.strip()
|
||||||
|
|
||||||
self.hf_config = get_config(
|
self.hf_config = get_config(
|
||||||
self.model_path,
|
self.model_path,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
@@ -85,7 +124,7 @@ class ModelConfig:
|
|||||||
model_override_args=self.model_override_args,
|
model_override_args=self.model_override_args,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
self.hf_text_config = get_hf_text_config(self.hf_config)
|
|
||||||
self.hf_generation_config = get_generation_config(
|
self.hf_generation_config = get_generation_config(
|
||||||
self.model_path,
|
self.model_path,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
@@ -93,25 +132,7 @@ class ModelConfig:
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Set enable_multimodal
|
self.hf_text_config = get_hf_text_config(self.hf_config)
|
||||||
if enable_multimodal is None:
|
|
||||||
mm_disabled_models = [
|
|
||||||
"Gemma3ForConditionalGeneration",
|
|
||||||
"Llama4ForConditionalGeneration",
|
|
||||||
"Step3VLForConditionalGeneration",
|
|
||||||
]
|
|
||||||
if self.hf_config.architectures[0] in mm_disabled_models:
|
|
||||||
enable_multimodal = False
|
|
||||||
logger.info(
|
|
||||||
f"Multimodal is disabled for {self.hf_config.model_type}. To enable it, set --enable-multimodal."
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
enable_multimodal = True
|
|
||||||
|
|
||||||
# Config draft model
|
|
||||||
self._config_draft_model()
|
|
||||||
|
|
||||||
# Check model type
|
|
||||||
self.attention_chunk_size = getattr(
|
self.attention_chunk_size = getattr(
|
||||||
self.hf_text_config, "attention_chunk_size", None
|
self.hf_text_config, "attention_chunk_size", None
|
||||||
)
|
)
|
||||||
@@ -127,70 +148,20 @@ class ModelConfig:
|
|||||||
self.hf_config.architectures, self.hf_text_config.num_hidden_layers
|
self.hf_config.architectures, self.hf_text_config.num_hidden_layers
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.is_generation = is_generation_model(
|
|
||||||
self.hf_config.architectures, is_embedding
|
|
||||||
)
|
|
||||||
self.is_multimodal = enable_multimodal and is_multimodal_model(
|
|
||||||
self.hf_config.architectures
|
|
||||||
)
|
|
||||||
self.is_multimodal_gen = enable_multimodal and is_multimodal_gen_model(
|
|
||||||
self.hf_config.architectures
|
|
||||||
)
|
|
||||||
self.is_image_gen = enable_multimodal and is_image_gen_model(
|
|
||||||
self.hf_config.architectures
|
|
||||||
)
|
|
||||||
self.is_audio_model = enable_multimodal and is_audio_model(
|
|
||||||
self.hf_config.architectures
|
|
||||||
)
|
|
||||||
self.is_multimodal_chunked_prefill_supported = (
|
|
||||||
enable_multimodal
|
|
||||||
and is_multimodal_chunked_prefill_supported(self.hf_config.architectures)
|
|
||||||
)
|
|
||||||
self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures)
|
|
||||||
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
|
|
||||||
|
|
||||||
# Derive context length and model shapes
|
if enable_multimodal is None:
|
||||||
self._derive_context_length(context_length)
|
mm_disabled_models = [
|
||||||
self._derive_model_shapes()
|
"Gemma3ForConditionalGeneration",
|
||||||
|
"Llama4ForConditionalGeneration",
|
||||||
# Verify quantization
|
"Step3VLForConditionalGeneration",
|
||||||
self._verify_quantization()
|
]
|
||||||
|
if self.hf_config.architectures[0] in mm_disabled_models:
|
||||||
# Verify dual-chunk attention config
|
enable_multimodal = False
|
||||||
self._verify_dual_chunk_attention_config()
|
logger.info(
|
||||||
|
f"Multimodal is disabled for {self.hf_config.model_type}. To enable it, set --enable-multimodal."
|
||||||
# Cache attributes
|
)
|
||||||
self.hf_eos_token_id = self._get_hf_eos_token_id()
|
else:
|
||||||
|
enable_multimodal = True
|
||||||
# multimodal
|
|
||||||
self.image_token_id = getattr(
|
|
||||||
self.hf_config, "image_token_id", None
|
|
||||||
) or getattr(self.hf_config, "image_token_index", None)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def from_server_args(
|
|
||||||
server_args: ServerArgs,
|
|
||||||
model_path: str = None,
|
|
||||||
model_revision: str = None,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
return ModelConfig(
|
|
||||||
model_path=model_path or server_args.model_path,
|
|
||||||
trust_remote_code=server_args.trust_remote_code,
|
|
||||||
revision=model_revision or server_args.revision,
|
|
||||||
context_length=server_args.context_length,
|
|
||||||
model_override_args=server_args.json_model_override_args,
|
|
||||||
is_embedding=server_args.is_embedding,
|
|
||||||
enable_multimodal=server_args.enable_multimodal,
|
|
||||||
dtype=server_args.dtype,
|
|
||||||
quantization=server_args.quantization,
|
|
||||||
hybrid_kvcache_ratio=server_args.hybrid_kvcache_ratio,
|
|
||||||
model_impl=server_args.model_impl,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _config_draft_model(self):
|
|
||||||
is_draft_model = self.is_draft_model
|
|
||||||
|
|
||||||
if (
|
if (
|
||||||
is_draft_model
|
is_draft_model
|
||||||
@@ -225,10 +196,31 @@ class ModelConfig:
|
|||||||
self.hf_config.architectures[0] = "Qwen3NextForCausalLMMTP"
|
self.hf_config.architectures[0] = "Qwen3NextForCausalLMMTP"
|
||||||
self.hf_config.num_nextn_predict_layers = 1
|
self.hf_config.num_nextn_predict_layers = 1
|
||||||
|
|
||||||
def _derive_context_length(self, context_length: int):
|
# Check model type
|
||||||
is_draft_model = self.is_draft_model
|
self.is_generation = is_generation_model(
|
||||||
derived_context_len = get_context_length(self.hf_text_config)
|
self.hf_config.architectures, is_embedding
|
||||||
|
)
|
||||||
|
self.is_multimodal = enable_multimodal and is_multimodal_model(
|
||||||
|
self.hf_config.architectures
|
||||||
|
)
|
||||||
|
self.is_multimodal_gen = enable_multimodal and is_multimodal_gen_model(
|
||||||
|
self.hf_config.architectures
|
||||||
|
)
|
||||||
|
self.is_image_gen = enable_multimodal and is_image_gen_model(
|
||||||
|
self.hf_config.architectures
|
||||||
|
)
|
||||||
|
self.is_audio_model = enable_multimodal and is_audio_model(
|
||||||
|
self.hf_config.architectures
|
||||||
|
)
|
||||||
|
self.is_multimodal_chunked_prefill_supported = (
|
||||||
|
enable_multimodal
|
||||||
|
and is_multimodal_chunked_prefill_supported(self.hf_config.architectures)
|
||||||
|
)
|
||||||
|
self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures)
|
||||||
|
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
|
||||||
|
|
||||||
|
# Derive context length
|
||||||
|
derived_context_len = get_context_length(self.hf_text_config)
|
||||||
if context_length is not None:
|
if context_length is not None:
|
||||||
if context_length > derived_context_len:
|
if context_length > derived_context_len:
|
||||||
reason = "Target model's" if is_draft_model else "User-specified"
|
reason = "Target model's" if is_draft_model else "User-specified"
|
||||||
@@ -242,11 +234,6 @@ class ModelConfig:
|
|||||||
):
|
):
|
||||||
logger.warning(msg)
|
logger.warning(msg)
|
||||||
self.context_len = context_length
|
self.context_len = context_length
|
||||||
if is_draft_model:
|
|
||||||
self.hf_text_config.max_position_embeddings = context_length
|
|
||||||
logger.warning(
|
|
||||||
f"Overriding the draft model's max_position_embeddings to {context_length}."
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"{msg} To allow overriding this maximum, set the env var SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1"
|
f"{msg} To allow overriding this maximum, set the env var SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1"
|
||||||
@@ -256,10 +243,6 @@ class ModelConfig:
|
|||||||
else:
|
else:
|
||||||
self.context_len = derived_context_len
|
self.context_len = derived_context_len
|
||||||
|
|
||||||
# Transfer context_len to HuggingFace config so models can access it
|
|
||||||
self.hf_config.context_len = self.context_len
|
|
||||||
|
|
||||||
def _derive_model_shapes(self):
|
|
||||||
# Unify the config keys for hf_text_config
|
# Unify the config keys for hf_text_config
|
||||||
self.head_dim = getattr(
|
self.head_dim = getattr(
|
||||||
self.hf_text_config,
|
self.hf_text_config,
|
||||||
@@ -270,6 +253,7 @@ class ModelConfig:
|
|||||||
# FIXME: temporary special judge for MLA architecture
|
# FIXME: temporary special judge for MLA architecture
|
||||||
if (
|
if (
|
||||||
"DeepseekV2ForCausalLM" in self.hf_config.architectures
|
"DeepseekV2ForCausalLM" in self.hf_config.architectures
|
||||||
|
or "DeepseekV32ForCausalLM" in self.hf_config.architectures
|
||||||
or "DeepseekV3ForCausalLM" in self.hf_config.architectures
|
or "DeepseekV3ForCausalLM" in self.hf_config.architectures
|
||||||
or "DeepseekV3ForCausalLMNextN" in self.hf_config.architectures
|
or "DeepseekV3ForCausalLMNextN" in self.hf_config.architectures
|
||||||
or "LongcatFlashForCausalLM" in self.hf_config.architectures
|
or "LongcatFlashForCausalLM" in self.hf_config.architectures
|
||||||
@@ -282,6 +266,11 @@ class ModelConfig:
|
|||||||
self.qk_nope_head_dim = self.hf_config.qk_nope_head_dim
|
self.qk_nope_head_dim = self.hf_config.qk_nope_head_dim
|
||||||
self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim
|
self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim
|
||||||
self.v_head_dim = self.hf_config.v_head_dim
|
self.v_head_dim = self.hf_config.v_head_dim
|
||||||
|
self.index_head_dim = (
|
||||||
|
get_nsa_index_head_dim(self.hf_config)
|
||||||
|
if is_deepseek_nsa(self.hf_config)
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
# Handle rope scaling with yarn
|
# Handle rope scaling with yarn
|
||||||
self.scaling = 1 / math.sqrt(self.qk_nope_head_dim + self.qk_rope_head_dim)
|
self.scaling = 1 / math.sqrt(self.qk_nope_head_dim + self.qk_rope_head_dim)
|
||||||
@@ -354,6 +343,45 @@ class ModelConfig:
|
|||||||
)
|
)
|
||||||
self.vocab_size = self.hf_text_config.vocab_size
|
self.vocab_size = self.hf_text_config.vocab_size
|
||||||
|
|
||||||
|
# Verify quantization
|
||||||
|
self._verify_quantization()
|
||||||
|
|
||||||
|
# Verify dual-chunk attention config
|
||||||
|
self._verify_dual_chunk_attention_config()
|
||||||
|
|
||||||
|
# Cache attributes
|
||||||
|
self.hf_eos_token_id = self.get_hf_eos_token_id()
|
||||||
|
|
||||||
|
# multimodal
|
||||||
|
self.image_token_id = getattr(
|
||||||
|
self.hf_config, "image_token_id", None
|
||||||
|
) or getattr(self.hf_config, "image_token_index", None)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_server_args(
|
||||||
|
server_args: ServerArgs,
|
||||||
|
model_path: str = None,
|
||||||
|
model_revision: str = None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
return ModelConfig(
|
||||||
|
model_path=model_path or server_args.model_path,
|
||||||
|
trust_remote_code=server_args.trust_remote_code,
|
||||||
|
revision=model_revision or server_args.revision,
|
||||||
|
context_length=server_args.context_length,
|
||||||
|
model_override_args=server_args.json_model_override_args,
|
||||||
|
is_embedding=server_args.is_embedding,
|
||||||
|
enable_multimodal=server_args.enable_multimodal,
|
||||||
|
dtype=server_args.dtype,
|
||||||
|
quantization=server_args.quantization,
|
||||||
|
hybrid_kvcache_ratio=server_args.hybrid_kvcache_ratio,
|
||||||
|
model_impl=server_args.model_impl,
|
||||||
|
remote_instance_weight_loader_seed_instance_ip=server_args.remote_instance_weight_loader_seed_instance_ip,
|
||||||
|
remote_instance_weight_loader_seed_instance_service_port=server_args.remote_instance_weight_loader_seed_instance_service_port,
|
||||||
|
remote_instance_weight_loader_send_weights_group_ports=server_args.remote_instance_weight_loader_send_weights_group_ports,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
def get_total_num_attention_heads(self) -> int:
|
def get_total_num_attention_heads(self) -> int:
|
||||||
return self.num_attention_heads
|
return self.num_attention_heads
|
||||||
|
|
||||||
@@ -454,31 +482,13 @@ class ModelConfig:
|
|||||||
from huggingface_hub import HfApi
|
from huggingface_hub import HfApi
|
||||||
|
|
||||||
hf_api = HfApi()
|
hf_api = HfApi()
|
||||||
|
if hf_api.file_exists(self.model_path, "hf_quant_config.json"):
|
||||||
def check_hf_quant_config():
|
|
||||||
return hf_api.file_exists(
|
|
||||||
self.model_path, "hf_quant_config.json"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Retry HF API call up to 3 times
|
|
||||||
file_exists = retry(
|
|
||||||
check_hf_quant_config,
|
|
||||||
max_retry=2,
|
|
||||||
initial_delay=1.0,
|
|
||||||
max_delay=5.0,
|
|
||||||
)
|
|
||||||
|
|
||||||
if file_exists:
|
|
||||||
quant_cfg = modelopt_quant_config
|
quant_cfg = modelopt_quant_config
|
||||||
|
|
||||||
except huggingface_hub.errors.OfflineModeIsEnabled:
|
except huggingface_hub.errors.OfflineModeIsEnabled:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Offline mode is enabled, skipping hf_quant_config.json check"
|
"Offline mode is enabled, skipping hf_quant_config.json check"
|
||||||
)
|
)
|
||||||
except Exception as e:
|
pass
|
||||||
logger.warning(
|
|
||||||
f"Failed to check hf_quant_config.json: {self.model_path} {e}"
|
|
||||||
)
|
|
||||||
|
|
||||||
elif os.path.exists(os.path.join(self.model_path, "hf_quant_config.json")):
|
elif os.path.exists(os.path.join(self.model_path, "hf_quant_config.json")):
|
||||||
quant_config_file = os.path.join(
|
quant_config_file = os.path.join(
|
||||||
@@ -606,7 +616,7 @@ class ModelConfig:
|
|||||||
"sparse_attention_enabled"
|
"sparse_attention_enabled"
|
||||||
] = True
|
] = True
|
||||||
|
|
||||||
def _get_hf_eos_token_id(self) -> Optional[Set[int]]:
|
def get_hf_eos_token_id(self) -> Optional[Set[int]]:
|
||||||
eos_ids = getattr(self.hf_config, "eos_token_id", None)
|
eos_ids = getattr(self.hf_config, "eos_token_id", None)
|
||||||
if eos_ids is not None:
|
if eos_ids is not None:
|
||||||
# it can be either int or list of int
|
# it can be either int or list of int
|
||||||
@@ -626,7 +636,7 @@ class ModelConfig:
|
|||||||
eos_ids = eos_ids | generation_eos_ids
|
eos_ids = eos_ids | generation_eos_ids
|
||||||
return eos_ids
|
return eos_ids
|
||||||
|
|
||||||
def _maybe_pull_model_tokenizer_from_remote(self) -> None:
|
def maybe_pull_model_tokenizer_from_remote(self) -> None:
|
||||||
"""
|
"""
|
||||||
Pull the model config files to a temporary
|
Pull the model config files to a temporary
|
||||||
directory in case of remote.
|
directory in case of remote.
|
||||||
@@ -769,8 +779,6 @@ multimodal_model_archs = [
|
|||||||
"Qwen2AudioForConditionalGeneration",
|
"Qwen2AudioForConditionalGeneration",
|
||||||
"Qwen2VLForConditionalGeneration",
|
"Qwen2VLForConditionalGeneration",
|
||||||
"Qwen2_5_VLForConditionalGeneration",
|
"Qwen2_5_VLForConditionalGeneration",
|
||||||
"Qwen3VLForConditionalGeneration",
|
|
||||||
"Qwen3VLMoeForConditionalGeneration",
|
|
||||||
"KimiVLForConditionalGeneration",
|
"KimiVLForConditionalGeneration",
|
||||||
"InternVLChatModel",
|
"InternVLChatModel",
|
||||||
"InternS1ForConditionalGeneration",
|
"InternS1ForConditionalGeneration",
|
||||||
|
|||||||
@@ -1,586 +0,0 @@
|
|||||||
from typing import Optional, Union
|
|
||||||
|
|
||||||
from transformers import PretrainedConfig
|
|
||||||
from transformers.modeling_rope_utils import rope_config_validation
|
|
||||||
|
|
||||||
|
|
||||||
class Qwen3VLVisionConfig(PretrainedConfig):
|
|
||||||
model_type = "qwen3_vl"
|
|
||||||
base_config_key = "vision_config"
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
depth=27,
|
|
||||||
hidden_size=1152,
|
|
||||||
hidden_act="gelu_pytorch_tanh",
|
|
||||||
intermediate_size=4304,
|
|
||||||
num_heads=16,
|
|
||||||
in_channels=3,
|
|
||||||
patch_size=16,
|
|
||||||
spatial_merge_size=2,
|
|
||||||
temporal_patch_size=2,
|
|
||||||
out_hidden_size=3584,
|
|
||||||
num_position_embeddings=2304,
|
|
||||||
deepstack_visual_indexes=[8, 16, 24],
|
|
||||||
initializer_range=0.02,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
|
|
||||||
self.depth = depth
|
|
||||||
self.hidden_size = hidden_size
|
|
||||||
self.hidden_act = hidden_act
|
|
||||||
self.intermediate_size = intermediate_size
|
|
||||||
self.num_heads = num_heads
|
|
||||||
self.in_channels = in_channels
|
|
||||||
self.patch_size = patch_size
|
|
||||||
self.spatial_merge_size = spatial_merge_size
|
|
||||||
self.temporal_patch_size = temporal_patch_size
|
|
||||||
self.out_hidden_size = out_hidden_size
|
|
||||||
self.num_position_embeddings = num_position_embeddings
|
|
||||||
self.initializer_range = initializer_range
|
|
||||||
self.deepstack_visual_indexes = deepstack_visual_indexes
|
|
||||||
|
|
||||||
|
|
||||||
class Qwen3VLTextConfig(PretrainedConfig):
|
|
||||||
r"""
|
|
||||||
This is the configuration class to store the configuration of a [`Qwen3VLTextModel`]. It is used to instantiate a
|
|
||||||
Qwen3-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration
|
|
||||||
with the defaults will yield a similar configuration to that of
|
|
||||||
Qwen3-VL-4B-Instruct [Qwen/Qwen3-VL-4B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-4B-Instruct).
|
|
||||||
|
|
||||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
|
||||||
documentation from [`PretrainedConfig`] for more information.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
vocab_size (`int`, *optional*, defaults to 151936):
|
|
||||||
Vocabulary size of the Qwen3VL model. Defines the number of different tokens that can be represented by the
|
|
||||||
`inputs_ids` passed when calling [`Qwen3VLModel`]
|
|
||||||
hidden_size (`int`, *optional*, defaults to 4096):
|
|
||||||
Dimension of the hidden representations.
|
|
||||||
intermediate_size (`int`, *optional*, defaults to 22016):
|
|
||||||
Dimension of the MLP representations.
|
|
||||||
num_hidden_layers (`int`, *optional*, defaults to 32):
|
|
||||||
Number of hidden layers in the Transformer encoder.
|
|
||||||
num_attention_heads (`int`, *optional*, defaults to 32):
|
|
||||||
Number of attention heads for each attention layer in the Transformer encoder.
|
|
||||||
num_key_value_heads (`int`, *optional*, defaults to 32):
|
|
||||||
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
|
||||||
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
|
||||||
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
|
||||||
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
|
||||||
by meanpooling all the original heads within that group. For more details, check out [this
|
|
||||||
paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `32`.
|
|
||||||
head_dim (`int`, *optional*, defaults to 128):
|
|
||||||
The dimension of the head. If not specified, will default to `hidden_size // num_attention_heads`.
|
|
||||||
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
|
||||||
The non-linear activation function (function or string) in the decoder.
|
|
||||||
max_position_embeddings (`int`, *optional*, defaults to 128000):
|
|
||||||
The maximum sequence length that this model might ever be used with.
|
|
||||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
|
||||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
|
||||||
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
|
|
||||||
The epsilon used by the rms normalization layers.
|
|
||||||
use_cache (`bool`, *optional*, defaults to `True`):
|
|
||||||
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
|
||||||
relevant if `config.is_decoder=True`.
|
|
||||||
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
|
||||||
Whether the model's input and output word embeddings should be tied.
|
|
||||||
rope_theta (`float`, *optional*, defaults to 5000000.0):
|
|
||||||
The base period of the RoPE embeddings.
|
|
||||||
rope_scaling (`Dict`, *optional*):
|
|
||||||
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
|
|
||||||
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
|
|
||||||
accordingly.
|
|
||||||
Expected contents:
|
|
||||||
`rope_type` (`str`):
|
|
||||||
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
|
|
||||||
'llama3'], with 'default' being the original RoPE implementation.
|
|
||||||
`factor` (`float`, *optional*):
|
|
||||||
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
|
|
||||||
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
|
|
||||||
original maximum pre-trained length.
|
|
||||||
`original_max_position_embeddings` (`int`, *optional*):
|
|
||||||
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
|
|
||||||
pretraining.
|
|
||||||
`attention_factor` (`float`, *optional*):
|
|
||||||
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
|
|
||||||
computation. If unspecified, it defaults to value recommended by the implementation, using the
|
|
||||||
`factor` field to infer the suggested value.
|
|
||||||
`beta_fast` (`float`, *optional*):
|
|
||||||
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
|
|
||||||
ramp function. If unspecified, it defaults to 32.
|
|
||||||
`beta_slow` (`float`, *optional*):
|
|
||||||
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
|
|
||||||
ramp function. If unspecified, it defaults to 1.
|
|
||||||
`short_factor` (`list[float]`, *optional*):
|
|
||||||
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
|
|
||||||
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
|
||||||
size divided by the number of attention heads divided by 2
|
|
||||||
`long_factor` (`list[float]`, *optional*):
|
|
||||||
Only used with 'longrope'. The scaling factor to be applied to long contexts (<
|
|
||||||
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
|
||||||
size divided by the number of attention heads divided by 2
|
|
||||||
`low_freq_factor` (`float`, *optional*):
|
|
||||||
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
|
|
||||||
`high_freq_factor` (`float`, *optional*):
|
|
||||||
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
|
|
||||||
attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
|
|
||||||
Whether to use a bias in the query, key, value and output projection layers during self-attention.
|
|
||||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
|
||||||
The dropout ratio for the attention probabilities.
|
|
||||||
|
|
||||||
```python
|
|
||||||
>>> from transformers import Qwen3VLTextModel, Qwen3VLTextConfig
|
|
||||||
|
|
||||||
>>> # Initializing a Qwen3VL style configuration
|
|
||||||
>>> configuration = Qwen3VLTextConfig()
|
|
||||||
|
|
||||||
>>> # Initializing a model from the Qwen3-VL-7B style configuration
|
|
||||||
>>> model = Qwen3VLTextModel(configuration)
|
|
||||||
|
|
||||||
>>> # Accessing the model configuration
|
|
||||||
>>> configuration = model.config
|
|
||||||
```"""
|
|
||||||
|
|
||||||
model_type = "qwen3_vl_text"
|
|
||||||
base_config_key = "text_config"
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
vocab_size=151936,
|
|
||||||
hidden_size=4096,
|
|
||||||
intermediate_size=22016,
|
|
||||||
num_hidden_layers=32,
|
|
||||||
num_attention_heads=32,
|
|
||||||
num_key_value_heads=32,
|
|
||||||
head_dim=128,
|
|
||||||
hidden_act="silu",
|
|
||||||
max_position_embeddings=128000,
|
|
||||||
initializer_range=0.02,
|
|
||||||
rms_norm_eps=1e-6,
|
|
||||||
use_cache=True,
|
|
||||||
tie_word_embeddings=False,
|
|
||||||
rope_theta=5000000.0,
|
|
||||||
rope_scaling=None,
|
|
||||||
attention_bias=False,
|
|
||||||
attention_dropout=0.0,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
self.vocab_size = vocab_size
|
|
||||||
self.max_position_embeddings = max_position_embeddings
|
|
||||||
self.hidden_size = hidden_size
|
|
||||||
self.intermediate_size = intermediate_size
|
|
||||||
self.num_hidden_layers = num_hidden_layers
|
|
||||||
self.num_attention_heads = num_attention_heads
|
|
||||||
|
|
||||||
# for backward compatibility
|
|
||||||
if num_key_value_heads is None:
|
|
||||||
num_key_value_heads = num_attention_heads
|
|
||||||
|
|
||||||
self.num_key_value_heads = num_key_value_heads
|
|
||||||
self.head_dim = head_dim
|
|
||||||
self.hidden_act = hidden_act
|
|
||||||
self.initializer_range = initializer_range
|
|
||||||
self.rms_norm_eps = rms_norm_eps
|
|
||||||
self.use_cache = use_cache
|
|
||||||
self.rope_theta = rope_theta
|
|
||||||
self.rope_scaling = rope_scaling
|
|
||||||
self.attention_bias = attention_bias
|
|
||||||
self.attention_dropout = attention_dropout
|
|
||||||
|
|
||||||
rope_config_validation(self, ignore_keys={"mrope_section", "mrope_interleaved"})
|
|
||||||
|
|
||||||
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
class Qwen3VLConfig(PretrainedConfig):
|
|
||||||
r"""
|
|
||||||
This is the configuration class to store the configuration of a [`Qwen3VLModel`]. It is used to instantiate a
|
|
||||||
Qwen3-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration
|
|
||||||
with the defaults will yield a similar configuration to that of
|
|
||||||
Qwen3-VL-4B-Instruct [Qwen/Qwen3-VL-4B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-4B-Instruct).
|
|
||||||
|
|
||||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
|
||||||
documentation from [`PretrainedConfig`] for more information.
|
|
||||||
|
|
||||||
|
|
||||||
Args:
|
|
||||||
text_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen3VLTextConfig`):
|
|
||||||
The config object or dictionary of the text backbone.
|
|
||||||
vision_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen3VLVisionConfig`):
|
|
||||||
The config object or dictionary of the vision backbone.
|
|
||||||
image_token_id (`int`, *optional*, defaults to 151655):
|
|
||||||
The image token index to encode the image prompt.
|
|
||||||
video_token_id (`int`, *optional*, defaults to 151656):
|
|
||||||
The video token index to encode the image prompt.
|
|
||||||
vision_start_token_id (`int`, *optional*, defaults to 151652):
|
|
||||||
The start token index to encode the image prompt.
|
|
||||||
vision_end_token_id (`int`, *optional*, defaults to 151653):
|
|
||||||
The end token index to encode the image prompt.
|
|
||||||
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
|
||||||
Whether to tie the word embeddings.
|
|
||||||
|
|
||||||
```python
|
|
||||||
>>> from transformers import Qwen3VLForConditionalGeneration, Qwen3VLConfig
|
|
||||||
|
|
||||||
>>> # Initializing a Qwen3-VL style configuration
|
|
||||||
>>> configuration = Qwen3VLConfig()
|
|
||||||
|
|
||||||
>>> # Initializing a model from the Qwen3-VL-4B style configuration
|
|
||||||
>>> model = Qwen3VLForConditionalGeneration(configuration)
|
|
||||||
|
|
||||||
>>> # Accessing the model configuration
|
|
||||||
>>> configuration = model.config
|
|
||||||
```"""
|
|
||||||
|
|
||||||
model_type = "qwen3_vl"
|
|
||||||
sub_configs = {
|
|
||||||
"vision_config": Qwen3VLVisionConfig,
|
|
||||||
"text_config": Qwen3VLTextConfig,
|
|
||||||
}
|
|
||||||
keys_to_ignore_at_inference = ["past_key_values"]
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
text_config=None,
|
|
||||||
vision_config=None,
|
|
||||||
image_token_id=151655,
|
|
||||||
video_token_id=151656,
|
|
||||||
vision_start_token_id=151652,
|
|
||||||
vision_end_token_id=151653,
|
|
||||||
tie_word_embeddings=False,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
if isinstance(vision_config, dict):
|
|
||||||
self.vision_config = self.sub_configs["vision_config"](**vision_config)
|
|
||||||
elif vision_config is None:
|
|
||||||
self.vision_config = self.sub_configs["vision_config"]()
|
|
||||||
|
|
||||||
if isinstance(text_config, dict):
|
|
||||||
self.text_config = self.sub_configs["text_config"](**text_config)
|
|
||||||
elif text_config is None:
|
|
||||||
self.text_config = self.sub_configs["text_config"]()
|
|
||||||
|
|
||||||
self.image_token_id = image_token_id
|
|
||||||
self.video_token_id = video_token_id
|
|
||||||
self.vision_start_token_id = vision_start_token_id
|
|
||||||
self.vision_end_token_id = vision_end_token_id
|
|
||||||
super().__init__(**kwargs, tie_word_embeddings=tie_word_embeddings)
|
|
||||||
|
|
||||||
|
|
||||||
class Qwen3VLMoeTextConfig(PretrainedConfig):
|
|
||||||
r"""
|
|
||||||
This is the configuration class to store the configuration of a [`Qwen3VLMoeTextModel`]. It is used to instantiate a
|
|
||||||
Qwen3-VL-MOE model according to the specified arguments, defining the model architecture. Instantiating a configuration
|
|
||||||
with the defaults will yield a similar configuration to that of
|
|
||||||
Qwen3-VL-30B-A3B-Instruct [Qwen/Qwen3-VL-30B-A3B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-30B-A3B-Instruct).
|
|
||||||
|
|
||||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
|
||||||
documentation from [`PretrainedConfig`] for more information.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
vocab_size (`int`, *optional*, defaults to 151936):
|
|
||||||
Vocabulary size of the Qwen2MoE model. Defines the number of different tokens that can be represented by the
|
|
||||||
`inputs_ids` passed when calling [`Qwen2MoeModel`]
|
|
||||||
hidden_size (`int`, *optional*, defaults to 2048):
|
|
||||||
Dimension of the hidden representations.
|
|
||||||
intermediate_size (`int`, *optional*, defaults to 5632):
|
|
||||||
Dimension of the MLP representations.
|
|
||||||
num_hidden_layers (`int`, *optional*, defaults to 24):
|
|
||||||
Number of hidden layers in the Transformer encoder.
|
|
||||||
num_attention_heads (`int`, *optional*, defaults to 16):
|
|
||||||
Number of attention heads for each attention layer in the Transformer encoder.
|
|
||||||
num_key_value_heads (`int`, *optional*, defaults to 16):
|
|
||||||
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
|
||||||
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
|
||||||
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
|
||||||
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
|
||||||
by meanpooling all the original heads within that group. For more details checkout [this
|
|
||||||
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`.
|
|
||||||
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
|
||||||
The non-linear activation function (function or string) in the decoder.
|
|
||||||
max_position_embeddings (`int`, *optional*, defaults to 128000):
|
|
||||||
The maximum sequence length that this model might ever be used with.
|
|
||||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
|
||||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
|
||||||
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
|
|
||||||
The epsilon used by the rms normalization layers.
|
|
||||||
use_cache (`bool`, *optional*, defaults to `True`):
|
|
||||||
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
|
||||||
relevant if `config.is_decoder=True`.
|
|
||||||
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
|
||||||
Whether the model's input and output word embeddings should be tied.
|
|
||||||
rope_theta (`float`, *optional*, defaults to 5000000.0):
|
|
||||||
The base period of the RoPE embeddings.
|
|
||||||
attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
|
|
||||||
Whether to use a bias in the query, key, value and output projection layers during self-attention.
|
|
||||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
|
||||||
The dropout ratio for the attention probabilities.
|
|
||||||
decoder_sparse_step (`int`, *optional*, defaults to 1):
|
|
||||||
The frequency of the MoE layer.
|
|
||||||
moe_intermediate_size (`int`, *optional*, defaults to 1408):
|
|
||||||
Intermediate size of the routed expert.
|
|
||||||
num_experts_per_tok (`int`, *optional*, defaults to 4):
|
|
||||||
Number of selected experts.
|
|
||||||
num_experts (`int`, *optional*, defaults to 60):
|
|
||||||
Number of routed experts.
|
|
||||||
norm_topk_prob (`bool`, *optional*, defaults to `True`):
|
|
||||||
Whether to normalize the topk probabilities.
|
|
||||||
mlp_only_layers (`List[int]`, *optional*, defaults to `[]`):
|
|
||||||
Indicate which layers use Qwen3VLMoeMLP rather than Qwen3VLMoeSparseMoeBlock
|
|
||||||
The list contains layer index, from 0 to num_layers-1 if we have num_layers layers
|
|
||||||
If `mlp_only_layers` is empty, `decoder_sparse_step` is used to determine the sparsity.
|
|
||||||
rope_scaling (`Dict`, *optional*):
|
|
||||||
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
|
|
||||||
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
|
|
||||||
accordingly.
|
|
||||||
Expected contents:
|
|
||||||
`rope_type` (`str`):
|
|
||||||
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
|
|
||||||
'llama3'], with 'default' being the original RoPE implementation.
|
|
||||||
`factor` (`float`, *optional*):
|
|
||||||
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
|
|
||||||
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
|
|
||||||
original maximum pre-trained length.
|
|
||||||
`original_max_position_embeddings` (`int`, *optional*):
|
|
||||||
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
|
|
||||||
pretraining.
|
|
||||||
`attention_factor` (`float`, *optional*):
|
|
||||||
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
|
|
||||||
computation. If unspecified, it defaults to value recommended by the implementation, using the
|
|
||||||
`factor` field to infer the suggested value.
|
|
||||||
`beta_fast` (`float`, *optional*):
|
|
||||||
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
|
|
||||||
ramp function. If unspecified, it defaults to 32.
|
|
||||||
`beta_slow` (`float`, *optional*):
|
|
||||||
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
|
|
||||||
ramp function. If unspecified, it defaults to 1.
|
|
||||||
`short_factor` (`List[float]`, *optional*):
|
|
||||||
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
|
|
||||||
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
|
||||||
size divided by the number of attention heads divided by 2
|
|
||||||
`long_factor` (`List[float]`, *optional*):
|
|
||||||
Only used with 'longrope'. The scaling factor to be applied to long contexts (<
|
|
||||||
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
|
||||||
size divided by the number of attention heads divided by 2
|
|
||||||
`low_freq_factor` (`float`, *optional*):
|
|
||||||
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
|
|
||||||
`high_freq_factor` (`float`, *optional*):
|
|
||||||
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
|
|
||||||
head_dim (`int`, *optional*):
|
|
||||||
The dimension of the head. If not specified, will default to `hidden_size // num_attention_heads`.
|
|
||||||
|
|
||||||
```python
|
|
||||||
>>> from transformers import Qwen3VLMoeForConditionalGeneration, Qwen3VLMoeConfig
|
|
||||||
|
|
||||||
>>> # Initializing a Qwen3VLMoe style configuration
|
|
||||||
>>> configuration = Qwen3VLMoeConfig()
|
|
||||||
|
|
||||||
>>> # Initializing a model from the Qwen3-VL-30B-A3B style configuration
|
|
||||||
>>> model = Qwen3VLMoeForConditionalGeneration(configuration)
|
|
||||||
|
|
||||||
>>> # Accessing the model configuration
|
|
||||||
>>> configuration = model.config
|
|
||||||
```"""
|
|
||||||
|
|
||||||
model_type = "qwen3_vl_moe_text"
|
|
||||||
base_config_key = "text_config"
|
|
||||||
keys_to_ignore_at_inference = ["past_key_values"]
|
|
||||||
# Default tensor parallel plan for base model `Qwen3VLMoe`
|
|
||||||
base_model_tp_plan = {
|
|
||||||
"layers.*.self_attn.q_proj": "colwise",
|
|
||||||
"layers.*.self_attn.k_proj": "colwise",
|
|
||||||
"layers.*.self_attn.v_proj": "colwise",
|
|
||||||
"layers.*.self_attn.o_proj": "rowwise",
|
|
||||||
"layers.*.mlp.gate_proj": "colwise",
|
|
||||||
"layers.*.mlp.up_proj": "colwise",
|
|
||||||
"layers.*.mlp.down_proj": "rowwise",
|
|
||||||
}
|
|
||||||
base_model_pp_plan = {
|
|
||||||
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
|
||||||
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
|
||||||
"norm": (["hidden_states"], ["hidden_states"]),
|
|
||||||
}
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
vocab_size=151936,
|
|
||||||
hidden_size=2048,
|
|
||||||
intermediate_size=5632,
|
|
||||||
num_hidden_layers=24,
|
|
||||||
num_attention_heads=16,
|
|
||||||
num_key_value_heads=16,
|
|
||||||
hidden_act="silu",
|
|
||||||
max_position_embeddings=128000,
|
|
||||||
initializer_range=0.02,
|
|
||||||
rms_norm_eps=1e-6,
|
|
||||||
use_cache=True,
|
|
||||||
tie_word_embeddings=False,
|
|
||||||
rope_theta=5000000.0,
|
|
||||||
attention_bias=False,
|
|
||||||
attention_dropout=0.0,
|
|
||||||
decoder_sparse_step=1,
|
|
||||||
moe_intermediate_size=1408,
|
|
||||||
num_experts_per_tok=4,
|
|
||||||
num_experts=60,
|
|
||||||
norm_topk_prob=True,
|
|
||||||
mlp_only_layers=None,
|
|
||||||
rope_scaling=None,
|
|
||||||
head_dim=None,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
self.vocab_size = vocab_size
|
|
||||||
self.max_position_embeddings = max_position_embeddings
|
|
||||||
self.hidden_size = hidden_size
|
|
||||||
self.intermediate_size = intermediate_size
|
|
||||||
self.num_hidden_layers = num_hidden_layers
|
|
||||||
self.num_attention_heads = num_attention_heads
|
|
||||||
|
|
||||||
# for backward compatibility
|
|
||||||
if num_key_value_heads is None:
|
|
||||||
num_key_value_heads = num_attention_heads
|
|
||||||
|
|
||||||
self.num_key_value_heads = num_key_value_heads
|
|
||||||
self.hidden_act = hidden_act
|
|
||||||
self.initializer_range = initializer_range
|
|
||||||
self.rms_norm_eps = rms_norm_eps
|
|
||||||
self.use_cache = use_cache
|
|
||||||
self.rope_theta = rope_theta
|
|
||||||
self.attention_bias = attention_bias
|
|
||||||
self.attention_dropout = attention_dropout
|
|
||||||
self.rope_scaling = rope_scaling
|
|
||||||
self.head_dim = head_dim or hidden_size // num_attention_heads
|
|
||||||
|
|
||||||
rope_config_validation(self, ignore_keys={"mrope_section", "mrope_interleaved"})
|
|
||||||
|
|
||||||
# MoE arguments
|
|
||||||
self.decoder_sparse_step = decoder_sparse_step
|
|
||||||
self.moe_intermediate_size = moe_intermediate_size
|
|
||||||
self.num_experts_per_tok = num_experts_per_tok
|
|
||||||
self.num_experts = num_experts
|
|
||||||
self.norm_topk_prob = norm_topk_prob
|
|
||||||
self.mlp_only_layers = [] if mlp_only_layers is None else mlp_only_layers
|
|
||||||
|
|
||||||
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
class Qwen3VLMoeVisionConfig(PretrainedConfig):
|
|
||||||
model_type = "qwen3_vl_moe"
|
|
||||||
base_config_key = "vision_config"
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
depth=27,
|
|
||||||
hidden_size=1152,
|
|
||||||
hidden_act="gelu_pytorch_tanh",
|
|
||||||
intermediate_size=4304,
|
|
||||||
num_heads=16,
|
|
||||||
in_channels=3,
|
|
||||||
patch_size=16,
|
|
||||||
spatial_merge_size=2,
|
|
||||||
temporal_patch_size=2,
|
|
||||||
out_hidden_size=3584,
|
|
||||||
num_position_embeddings=2304,
|
|
||||||
deepstack_visual_indexes=[8, 16, 24],
|
|
||||||
initializer_range=0.02,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
|
|
||||||
self.depth = depth
|
|
||||||
self.hidden_size = hidden_size
|
|
||||||
self.hidden_act = hidden_act
|
|
||||||
self.intermediate_size = intermediate_size
|
|
||||||
self.num_heads = num_heads
|
|
||||||
self.in_channels = in_channels
|
|
||||||
self.patch_size = patch_size
|
|
||||||
self.spatial_merge_size = spatial_merge_size
|
|
||||||
self.temporal_patch_size = temporal_patch_size
|
|
||||||
self.out_hidden_size = out_hidden_size
|
|
||||||
self.num_position_embeddings = num_position_embeddings
|
|
||||||
self.initializer_range = initializer_range
|
|
||||||
self.deepstack_visual_indexes = deepstack_visual_indexes
|
|
||||||
|
|
||||||
|
|
||||||
class Qwen3VLMoeConfig(PretrainedConfig):
|
|
||||||
r"""
|
|
||||||
This is the configuration class to store the configuration of a [`Qwen3VLMoeModel`]. It is used to instantiate a
|
|
||||||
Qwen3-VL-MOE model according to the specified arguments, defining the model architecture. Instantiating a configuration
|
|
||||||
with the defaults will yield a similar configuration to that of
|
|
||||||
Qwen3-VL-30B-A3B-Instruct [Qwen/Qwen3-VL-30B-A3B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-30B-A3B-Instruct).
|
|
||||||
|
|
||||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
|
||||||
documentation from [`PretrainedConfig`] for more information.
|
|
||||||
|
|
||||||
|
|
||||||
Args:
|
|
||||||
text_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen3VLMoeTextConfig`):
|
|
||||||
The config object or dictionary of the text backbone.
|
|
||||||
vision_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Qwen3VLMoeVisionConfig`):
|
|
||||||
The config object or dictionary of the vision backbone.
|
|
||||||
image_token_id (`int`, *optional*, defaults to 151655):
|
|
||||||
The image token index to encode the image prompt.
|
|
||||||
video_token_id (`int`, *optional*, defaults to 151656):
|
|
||||||
The video token index to encode the image prompt.
|
|
||||||
vision_start_token_id (`int`, *optional*, defaults to 151652):
|
|
||||||
The start token index to encode the image prompt.
|
|
||||||
vision_end_token_id (`int`, *optional*, defaults to 151653):
|
|
||||||
The end token index to encode the image prompt.
|
|
||||||
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
|
||||||
Whether to tie the word embeddings.
|
|
||||||
|
|
||||||
```python
|
|
||||||
>>> from transformers import Qwen3VLMoeForConditionalGeneration, Qwen3VLMoeConfig
|
|
||||||
|
|
||||||
>>> # Initializing a Qwen3-VL-MOE style configuration
|
|
||||||
>>> configuration = Qwen3VLMoeConfig()
|
|
||||||
|
|
||||||
>>> # Initializing a model from the Qwen3-VL-30B-A3B style configuration
|
|
||||||
>>> model = Qwen3VLMoeForConditionalGeneration(configuration)
|
|
||||||
|
|
||||||
>>> # Accessing the model configuration
|
|
||||||
>>> configuration = model.config
|
|
||||||
```"""
|
|
||||||
|
|
||||||
model_type = "qwen3_vl_moe"
|
|
||||||
sub_configs = {
|
|
||||||
"vision_config": Qwen3VLMoeVisionConfig,
|
|
||||||
"text_config": Qwen3VLMoeTextConfig,
|
|
||||||
}
|
|
||||||
keys_to_ignore_at_inference = ["past_key_values"]
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
text_config=None,
|
|
||||||
vision_config=None,
|
|
||||||
image_token_id=151655,
|
|
||||||
video_token_id=151656,
|
|
||||||
vision_start_token_id=151652,
|
|
||||||
vision_end_token_id=151653,
|
|
||||||
tie_word_embeddings=False,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
if isinstance(vision_config, dict):
|
|
||||||
self.vision_config = self.sub_configs["vision_config"](**vision_config)
|
|
||||||
elif vision_config is None:
|
|
||||||
self.vision_config = self.sub_configs["vision_config"]()
|
|
||||||
|
|
||||||
if isinstance(text_config, dict):
|
|
||||||
self.text_config = self.sub_configs["text_config"](**text_config)
|
|
||||||
elif text_config is None:
|
|
||||||
self.text_config = self.sub_configs["text_config"]()
|
|
||||||
|
|
||||||
self.image_token_id = image_token_id
|
|
||||||
self.video_token_id = video_token_id
|
|
||||||
self.vision_start_token_id = vision_start_token_id
|
|
||||||
self.vision_end_token_id = vision_end_token_id
|
|
||||||
super().__init__(**kwargs, tie_word_embeddings=tie_word_embeddings)
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"Qwen3VLMoeConfig",
|
|
||||||
"Qwen3VLMoeVisionConfig",
|
|
||||||
"Qwen3VLConfig",
|
|
||||||
"Qwen3VLVisionConfig",
|
|
||||||
]
|
|
||||||
@@ -2,9 +2,19 @@ import logging
|
|||||||
import os
|
import os
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
|
from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
|
||||||
from sglang.srt.disaggregation.utils import DisaggregationMode
|
from sglang.srt.disaggregation.utils import DisaggregationMode
|
||||||
|
|
||||||
|
try:
|
||||||
|
from mf_adapter import TransferEngine
|
||||||
|
|
||||||
|
import_error = None
|
||||||
|
except ImportError as e:
|
||||||
|
import_error = e
|
||||||
|
pass
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -13,12 +23,11 @@ class AscendTransferEngine(MooncakeTransferEngine):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self, hostname: str, npu_id: int, disaggregation_mode: DisaggregationMode
|
self, hostname: str, npu_id: int, disaggregation_mode: DisaggregationMode
|
||||||
):
|
):
|
||||||
try:
|
if import_error is not None:
|
||||||
from mf_adapter import TransferEngine
|
logger.warning(
|
||||||
except ImportError as e:
|
|
||||||
raise ImportError(
|
|
||||||
"Please install mf_adapter, for details, see docs/backend/pd_disaggregation.md"
|
"Please install mf_adapter, for details, see docs/backend/pd_disaggregation.md"
|
||||||
) from e
|
)
|
||||||
|
raise import_error
|
||||||
|
|
||||||
self.engine = TransferEngine()
|
self.engine = TransferEngine()
|
||||||
self.hostname = hostname
|
self.hostname = hostname
|
||||||
@@ -37,12 +46,29 @@ class AscendTransferEngine(MooncakeTransferEngine):
|
|||||||
self.initialize()
|
self.initialize()
|
||||||
|
|
||||||
def initialize(self) -> None:
|
def initialize(self) -> None:
|
||||||
|
from sglang.srt.layers.dp_attention import (
|
||||||
|
get_tensor_model_parallel_world_size,
|
||||||
|
get_tp_group,
|
||||||
|
)
|
||||||
|
|
||||||
|
transfer_protocol = self._get_transfer_protocol()
|
||||||
|
if transfer_protocol is None or transfer_protocol == "sdma":
|
||||||
|
trans_op_type = TransferEngine.TransDataOpType.SDMA
|
||||||
|
else:
|
||||||
|
trans_op_type = TransferEngine.TransDataOpType.DEVICE_RDMA
|
||||||
|
"""with device RDMA for PD transfer"""
|
||||||
|
tmp_tensor = torch.zeros(1, device="npu")
|
||||||
|
output_tensor_list = [
|
||||||
|
torch.empty_like(tmp_tensor)
|
||||||
|
for _ in range(get_tensor_model_parallel_world_size())
|
||||||
|
]
|
||||||
|
# Initialize hccl in advance through all_gather to avoid conflicts with rdma initialization.
|
||||||
|
torch.distributed.all_gather(
|
||||||
|
output_tensor_list, tmp_tensor, group=get_tp_group().device_group
|
||||||
|
)
|
||||||
"""Initialize the ascend transfer instance."""
|
"""Initialize the ascend transfer instance."""
|
||||||
ret_value = self.engine.initialize(
|
ret_value = self.engine.initialize(
|
||||||
self.store_url,
|
self.store_url, self.session_id, self.role, self.npu_id, trans_op_type
|
||||||
self.session_id,
|
|
||||||
self.role,
|
|
||||||
self.npu_id,
|
|
||||||
)
|
)
|
||||||
if ret_value != 0:
|
if ret_value != 0:
|
||||||
logger.error("Ascend Transfer Engine initialization failed.")
|
logger.error("Ascend Transfer Engine initialization failed.")
|
||||||
@@ -56,3 +82,15 @@ class AscendTransferEngine(MooncakeTransferEngine):
|
|||||||
ret_value = -1
|
ret_value = -1
|
||||||
if ret_value != 0:
|
if ret_value != 0:
|
||||||
logger.debug(f"Ascend memory registration for ptr {ptrs} failed.")
|
logger.debug(f"Ascend memory registration for ptr {ptrs} failed.")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_transfer_protocol():
|
||||||
|
protocol = os.getenv("ASCEND_MF_TRANSFER_PROTOCOL")
|
||||||
|
allowed_protocols = {"device_rdma", "sdma"}
|
||||||
|
if protocol and protocol.lower() in allowed_protocols:
|
||||||
|
return protocol.lower()
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"Invalid or no transfer protocol specified, using default protocol."
|
||||||
|
)
|
||||||
|
return None
|
||||||
@@ -95,6 +95,14 @@ class CommonKVManager(BaseKVManager):
|
|||||||
def _bind_server_socket(self):
|
def _bind_server_socket(self):
|
||||||
self.server_socket.bind(format_tcp_address(self.local_ip, self.rank_port))
|
self.server_socket.bind(format_tcp_address(self.local_ip, self.rank_port))
|
||||||
|
|
||||||
|
@cache
|
||||||
|
def _connect(self, endpoint: str, is_ipv6: bool = False):
|
||||||
|
socket = zmq.Context().socket(zmq.PUSH)
|
||||||
|
if is_ipv6:
|
||||||
|
socket.setsockopt(zmq.IPV6, 1)
|
||||||
|
socket.connect(endpoint)
|
||||||
|
return socket
|
||||||
|
|
||||||
def _register_to_bootstrap(self):
|
def _register_to_bootstrap(self):
|
||||||
"""Register KVSender to bootstrap server via HTTP POST."""
|
"""Register KVSender to bootstrap server via HTTP POST."""
|
||||||
if self.dist_init_addr:
|
if self.dist_init_addr:
|
||||||
@@ -148,33 +156,6 @@ class CommonKVManager(BaseKVManager):
|
|||||||
socket.connect(endpoint)
|
socket.connect(endpoint)
|
||||||
return socket
|
return socket
|
||||||
|
|
||||||
def get_mha_kv_ptrs_with_pp(
|
|
||||||
self, src_kv_ptrs: List[int], dst_kv_ptrs: List[int]
|
|
||||||
) -> Tuple[List[int], List[int], List[int], List[int], int]:
|
|
||||||
# pp is not supported on the decode side yet
|
|
||||||
start_layer = self.kv_args.prefill_start_layer
|
|
||||||
num_kv_layers = len(src_kv_ptrs) // 2
|
|
||||||
end_layer = start_layer + num_kv_layers
|
|
||||||
dst_num_total_layers = len(dst_kv_ptrs) // 2
|
|
||||||
src_k_ptrs = src_kv_ptrs[:num_kv_layers]
|
|
||||||
src_v_ptrs = src_kv_ptrs[num_kv_layers:]
|
|
||||||
dst_k_ptrs = dst_kv_ptrs[start_layer:end_layer]
|
|
||||||
dst_v_ptrs = dst_kv_ptrs[
|
|
||||||
dst_num_total_layers + start_layer : dst_num_total_layers + end_layer
|
|
||||||
]
|
|
||||||
layers_current_pp_stage = len(src_k_ptrs)
|
|
||||||
return src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage
|
|
||||||
|
|
||||||
def get_mla_kv_ptrs_with_pp(
|
|
||||||
self, src_kv_ptrs: List[int], dst_kv_ptrs: List[int]
|
|
||||||
) -> Tuple[List[int], List[int], int]:
|
|
||||||
# pp is not supported on the decode side yet
|
|
||||||
start_layer = self.kv_args.prefill_start_layer
|
|
||||||
end_layer = start_layer + len(src_kv_ptrs)
|
|
||||||
sliced_dst_kv_ptrs = dst_kv_ptrs[start_layer:end_layer]
|
|
||||||
layers_current_pp_stage = len(src_kv_ptrs)
|
|
||||||
return src_kv_ptrs, sliced_dst_kv_ptrs, layers_current_pp_stage
|
|
||||||
|
|
||||||
|
|
||||||
class CommonKVSender(BaseKVSender):
|
class CommonKVSender(BaseKVSender):
|
||||||
|
|
||||||
|
|||||||
@@ -609,21 +609,15 @@ class DecodeTransferQueue:
|
|||||||
idx = decode_req.metadata_buffer_index
|
idx = decode_req.metadata_buffer_index
|
||||||
(
|
(
|
||||||
output_id,
|
output_id,
|
||||||
cached_tokens,
|
|
||||||
output_token_logprobs_val,
|
output_token_logprobs_val,
|
||||||
output_token_logprobs_idx,
|
output_token_logprobs_idx,
|
||||||
output_top_logprobs_val,
|
output_top_logprobs_val,
|
||||||
output_top_logprobs_idx,
|
output_top_logprobs_idx,
|
||||||
output_topk_p,
|
|
||||||
output_topk_index,
|
|
||||||
output_hidden_states,
|
output_hidden_states,
|
||||||
) = self.metadata_buffers.get_buf(idx)
|
) = self.metadata_buffers.get_buf(idx)
|
||||||
|
|
||||||
decode_req.req.output_ids.append(output_id[0].item())
|
decode_req.req.output_ids.append(output_id[0].item())
|
||||||
decode_req.req.cached_tokens = cached_tokens[0].item()
|
|
||||||
if not self.spec_algorithm.is_none():
|
if not self.spec_algorithm.is_none():
|
||||||
decode_req.req.output_topk_p = output_topk_p
|
|
||||||
decode_req.req.output_topk_index = output_topk_index
|
|
||||||
decode_req.req.hidden_states_tensor = output_hidden_states
|
decode_req.req.hidden_states_tensor = output_hidden_states
|
||||||
if decode_req.req.return_logprob:
|
if decode_req.req.return_logprob:
|
||||||
decode_req.req.output_token_logprobs_val.append(
|
decode_req.req.output_token_logprobs_val.append(
|
||||||
@@ -713,15 +707,12 @@ class SchedulerDisaggregationDecodeMixin:
|
|||||||
elif prepare_mlp_sync_flag:
|
elif prepare_mlp_sync_flag:
|
||||||
batch, _ = self._prepare_idle_batch_and_run(None)
|
batch, _ = self._prepare_idle_batch_and_run(None)
|
||||||
|
|
||||||
queue_size = (
|
if batch is None and (
|
||||||
len(self.waiting_queue)
|
len(self.waiting_queue)
|
||||||
+ len(self.disagg_decode_transfer_queue.queue)
|
+ len(self.disagg_decode_transfer_queue.queue)
|
||||||
+ len(self.disagg_decode_prealloc_queue.queue)
|
+ len(self.disagg_decode_prealloc_queue.queue)
|
||||||
)
|
== 0
|
||||||
if self.server_args.disaggregation_decode_enable_offload_kvcache:
|
):
|
||||||
queue_size += len(self.decode_offload_manager.ongoing_offload)
|
|
||||||
|
|
||||||
if batch is None and queue_size == 0:
|
|
||||||
self.self_check_during_idle()
|
self.self_check_during_idle()
|
||||||
|
|
||||||
self.last_batch = batch
|
self.last_batch = batch
|
||||||
@@ -790,15 +781,12 @@ class SchedulerDisaggregationDecodeMixin:
|
|||||||
)
|
)
|
||||||
self.process_batch_result(tmp_batch, tmp_result)
|
self.process_batch_result(tmp_batch, tmp_result)
|
||||||
|
|
||||||
queue_size = (
|
if batch is None and (
|
||||||
len(self.waiting_queue)
|
len(self.waiting_queue)
|
||||||
+ len(self.disagg_decode_transfer_queue.queue)
|
+ len(self.disagg_decode_transfer_queue.queue)
|
||||||
+ len(self.disagg_decode_prealloc_queue.queue)
|
+ len(self.disagg_decode_prealloc_queue.queue)
|
||||||
)
|
== 0
|
||||||
if self.server_args.disaggregation_decode_enable_offload_kvcache:
|
):
|
||||||
queue_size += len(self.decode_offload_manager.ongoing_offload)
|
|
||||||
|
|
||||||
if batch is None and queue_size == 0:
|
|
||||||
self.self_check_during_idle()
|
self.self_check_during_idle()
|
||||||
|
|
||||||
self.last_batch = batch
|
self.last_batch = batch
|
||||||
@@ -917,6 +905,3 @@ class SchedulerDisaggregationDecodeMixin:
|
|||||||
self.disagg_decode_transfer_queue.pop_transferred()
|
self.disagg_decode_transfer_queue.pop_transferred()
|
||||||
) # the requests which kv has arrived
|
) # the requests which kv has arrived
|
||||||
self.waiting_queue.extend(alloc_reqs)
|
self.waiting_queue.extend(alloc_reqs)
|
||||||
|
|
||||||
if self.server_args.disaggregation_decode_enable_offload_kvcache:
|
|
||||||
self.decode_offload_manager.check_offload_progress()
|
|
||||||
|
|||||||
@@ -1,185 +0,0 @@
|
|||||||
import logging
|
|
||||||
import threading
|
|
||||||
import time
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from sglang import ServerArgs
|
|
||||||
from sglang.srt.managers.cache_controller import HiCacheController
|
|
||||||
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
|
|
||||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
|
||||||
from sglang.srt.mem_cache.memory_pool import (
|
|
||||||
MHATokenToKVPool,
|
|
||||||
MLATokenToKVPool,
|
|
||||||
ReqToTokenPool,
|
|
||||||
)
|
|
||||||
from sglang.srt.mem_cache.memory_pool_host import (
|
|
||||||
MHATokenToKVPoolHost,
|
|
||||||
MLATokenToKVPoolHost,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class DecodeKVCacheOffloadManager:
|
|
||||||
"""Manage decode-side KV cache offloading lifecycle and operations."""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
req_to_token_pool: ReqToTokenPool,
|
|
||||||
token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
|
|
||||||
tp_group: torch.distributed.ProcessGroup,
|
|
||||||
tree_cache: BasePrefixCache,
|
|
||||||
server_args: ServerArgs,
|
|
||||||
) -> None:
|
|
||||||
self.req_to_token_pool = req_to_token_pool
|
|
||||||
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
|
|
||||||
self.page_size = server_args.page_size
|
|
||||||
self.server_args = server_args
|
|
||||||
self.request_counter = 0
|
|
||||||
self.tree_cache = tree_cache
|
|
||||||
kv_cache = self.token_to_kv_pool_allocator.get_kvcache()
|
|
||||||
if isinstance(kv_cache, MHATokenToKVPool):
|
|
||||||
self.decode_host_mem_pool = MHATokenToKVPoolHost(
|
|
||||||
kv_cache,
|
|
||||||
server_args.hicache_ratio,
|
|
||||||
server_args.hicache_size,
|
|
||||||
self.page_size,
|
|
||||||
server_args.hicache_mem_layout,
|
|
||||||
)
|
|
||||||
elif isinstance(kv_cache, MLATokenToKVPool):
|
|
||||||
self.decode_host_mem_pool = MLATokenToKVPoolHost(
|
|
||||||
kv_cache,
|
|
||||||
server_args.hicache_ratio,
|
|
||||||
server_args.hicache_size,
|
|
||||||
self.page_size,
|
|
||||||
server_args.hicache_mem_layout,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise ValueError("Unsupported KV cache type for decode offload")
|
|
||||||
|
|
||||||
self.tp_group = tp_group
|
|
||||||
self.tp_world_size = torch.distributed.get_world_size(group=self.tp_group)
|
|
||||||
self.cache_controller = HiCacheController(
|
|
||||||
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
|
|
||||||
mem_pool_host=self.decode_host_mem_pool,
|
|
||||||
page_size=self.page_size,
|
|
||||||
tp_group=tp_group,
|
|
||||||
io_backend=server_args.hicache_io_backend,
|
|
||||||
load_cache_event=threading.Event(),
|
|
||||||
storage_backend=server_args.hicache_storage_backend,
|
|
||||||
model_name=server_args.served_model_name,
|
|
||||||
storage_backend_extra_config=server_args.hicache_storage_backend_extra_config,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.ongoing_offload = {}
|
|
||||||
self.ongoing_backup = {}
|
|
||||||
logger.info("Enable offload kv cache for decode side")
|
|
||||||
|
|
||||||
def offload_kv_cache(self, req) -> bool:
|
|
||||||
"""Offload a finished request's KV cache to storage."""
|
|
||||||
|
|
||||||
if self.cache_controller is None or self.decode_host_mem_pool is None:
|
|
||||||
return False
|
|
||||||
|
|
||||||
if req.req_pool_idx == -1:
|
|
||||||
return False
|
|
||||||
|
|
||||||
token_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx]
|
|
||||||
if token_indices.dim() == 0 or token_indices.numel() == 0:
|
|
||||||
logger.debug(
|
|
||||||
f"Request {req.rid} has invalid token_indices: {token_indices}"
|
|
||||||
)
|
|
||||||
return False
|
|
||||||
|
|
||||||
tokens = req.origin_input_ids + req.output_ids
|
|
||||||
aligned_len = (len(tokens) // self.page_size) * self.page_size
|
|
||||||
if aligned_len == 0:
|
|
||||||
return False
|
|
||||||
|
|
||||||
token_indices = token_indices[:aligned_len]
|
|
||||||
tokens = tokens[:aligned_len]
|
|
||||||
|
|
||||||
# Asynchronously offload KV cache from device to host by cache controller
|
|
||||||
self.request_counter += 1
|
|
||||||
ack_id = self.request_counter
|
|
||||||
host_indices = self.cache_controller.write(
|
|
||||||
device_indices=token_indices.long(),
|
|
||||||
node_id=ack_id,
|
|
||||||
)
|
|
||||||
if host_indices is None:
|
|
||||||
logger.error(f"Not enough host memory for request {req.rid}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
self.ongoing_offload[ack_id] = (req, host_indices, tokens, time.time())
|
|
||||||
return True
|
|
||||||
|
|
||||||
def check_offload_progress(self):
|
|
||||||
"""Check the progress of offload from device to host and backup from host to storage."""
|
|
||||||
cc = self.cache_controller
|
|
||||||
|
|
||||||
qsizes = torch.tensor(
|
|
||||||
[
|
|
||||||
len(cc.ack_write_queue),
|
|
||||||
cc.ack_backup_queue.qsize(),
|
|
||||||
],
|
|
||||||
dtype=torch.int,
|
|
||||||
)
|
|
||||||
if self.tp_world_size > 1:
|
|
||||||
torch.distributed.all_reduce(
|
|
||||||
qsizes, op=torch.distributed.ReduceOp.MIN, group=self.tp_group
|
|
||||||
)
|
|
||||||
|
|
||||||
n_write, n_backup = map(int, qsizes.tolist())
|
|
||||||
self._check_offload_progress(n_write)
|
|
||||||
self._check_backup_progress(n_backup)
|
|
||||||
|
|
||||||
def _check_offload_progress(self, finish_count):
|
|
||||||
"""Check the progress of offload from device to host."""
|
|
||||||
while finish_count > 0:
|
|
||||||
_, finish_event, ack_list = self.cache_controller.ack_write_queue.pop(0)
|
|
||||||
finish_event.synchronize()
|
|
||||||
for ack_id in ack_list:
|
|
||||||
req, host_indices, tokens, start_time = self.ongoing_offload.pop(ack_id)
|
|
||||||
|
|
||||||
# Release device
|
|
||||||
self.tree_cache.cache_finished_req(req)
|
|
||||||
|
|
||||||
# Trigger async backup from host to storage by cache controller
|
|
||||||
self._trigger_backup(req.rid, host_indices, tokens, start_time)
|
|
||||||
finish_count -= 1
|
|
||||||
|
|
||||||
def _check_backup_progress(self, finish_count):
|
|
||||||
"""Check the progress of backup from host to storage."""
|
|
||||||
for _ in range(finish_count):
|
|
||||||
storage_operation = self.cache_controller.ack_backup_queue.get()
|
|
||||||
ack_id = storage_operation.id
|
|
||||||
req_id, host_indices, start_time = self.ongoing_backup.pop(ack_id)
|
|
||||||
|
|
||||||
# Release host memory
|
|
||||||
self.decode_host_mem_pool.free(host_indices)
|
|
||||||
|
|
||||||
logger.debug(
|
|
||||||
f"Finished backup request {req_id}, free host memory, len:{len(host_indices)}, cost time:{time.time() - start_time:.2f} seconds."
|
|
||||||
)
|
|
||||||
|
|
||||||
def _trigger_backup(self, req_id, host_indices, tokens, start_time):
|
|
||||||
"""Trigger async backup from host to storage by cache controller."""
|
|
||||||
|
|
||||||
# Generate page hashes and write to storage
|
|
||||||
page_hashes = self._compute_prefix_hash(tokens)
|
|
||||||
ack_id = self.cache_controller.write_storage(
|
|
||||||
host_indices,
|
|
||||||
tokens,
|
|
||||||
hash_value=page_hashes,
|
|
||||||
)
|
|
||||||
self.ongoing_backup[ack_id] = (req_id, host_indices, start_time)
|
|
||||||
|
|
||||||
def _compute_prefix_hash(self, tokens):
|
|
||||||
last_hash = ""
|
|
||||||
page_hashes = []
|
|
||||||
for offset in range(0, len(tokens), self.page_size):
|
|
||||||
page_tokens = tokens[offset : offset + self.page_size]
|
|
||||||
last_hash = self.cache_controller.get_hash_str(page_tokens, last_hash)
|
|
||||||
page_hashes.append(last_hash)
|
|
||||||
return page_hashes
|
|
||||||
@@ -125,33 +125,25 @@ class ScheduleBatchDisaggregationDecodeMixin:
|
|||||||
req.grammar.finished = req.finished()
|
req.grammar.finished = req.finished()
|
||||||
self.output_ids = torch.tensor(self.output_ids, device=self.device)
|
self.output_ids = torch.tensor(self.output_ids, device=self.device)
|
||||||
|
|
||||||
# Simulate the eagle run.
|
# Simulate the eagle run. We add mock data to hidden states for the
|
||||||
if self.spec_algorithm.is_eagle():
|
# ease of implementation now meaning the first token will have acc rate
|
||||||
|
# of 0.
|
||||||
|
if not self.spec_algorithm.is_none():
|
||||||
|
|
||||||
b = len(self.reqs)
|
b = len(self.reqs)
|
||||||
topk = server_args.speculative_eagle_topk
|
topk_p = torch.arange(
|
||||||
topk_p = torch.stack(
|
b * server_args.speculative_eagle_topk,
|
||||||
[
|
0,
|
||||||
torch.as_tensor(
|
-1,
|
||||||
req.output_topk_p[:topk],
|
device=self.device,
|
||||||
device=self.device,
|
dtype=torch.float32,
|
||||||
dtype=torch.float32,
|
|
||||||
)
|
|
||||||
for req in self.reqs
|
|
||||||
],
|
|
||||||
dim=0,
|
|
||||||
)
|
)
|
||||||
topk_index = torch.stack(
|
topk_p = topk_p.reshape(b, server_args.speculative_eagle_topk)
|
||||||
[
|
topk_p /= b * server_args.speculative_eagle_topk
|
||||||
torch.as_tensor(
|
topk_index = torch.arange(
|
||||||
req.output_topk_index[:topk],
|
b * server_args.speculative_eagle_topk, device=self.device
|
||||||
device=self.device,
|
|
||||||
dtype=torch.int64,
|
|
||||||
)
|
|
||||||
for req in self.reqs
|
|
||||||
],
|
|
||||||
dim=0,
|
|
||||||
)
|
)
|
||||||
|
topk_index = topk_index.reshape(b, server_args.speculative_eagle_topk)
|
||||||
|
|
||||||
hidden_states_list = [req.hidden_states_tensor for req in self.reqs]
|
hidden_states_list = [req.hidden_states_tensor for req in self.reqs]
|
||||||
hidden_states = torch.stack(hidden_states_list, dim=0).to(self.device)
|
hidden_states = torch.stack(hidden_states_list, dim=0).to(self.device)
|
||||||
|
|||||||
@@ -264,10 +264,12 @@ class MooncakeKVManager(CommonKVManager):
|
|||||||
layers_params = None
|
layers_params = None
|
||||||
|
|
||||||
# pp is not supported on the decode side yet
|
# pp is not supported on the decode side yet
|
||||||
|
start_layer = self.kv_args.prefill_start_layer
|
||||||
|
end_layer = start_layer + len(self.kv_args.kv_data_ptrs)
|
||||||
if self.is_mla_backend:
|
if self.is_mla_backend:
|
||||||
src_kv_ptrs, dst_kv_ptrs, layers_current_pp_stage = (
|
src_kv_ptrs = self.kv_args.kv_data_ptrs
|
||||||
self.get_mla_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs)
|
layers_per_pp_stage = len(src_kv_ptrs)
|
||||||
)
|
dst_kv_ptrs = dst_kv_ptrs[start_layer:end_layer]
|
||||||
kv_item_len = self.kv_args.kv_item_lens[0]
|
kv_item_len = self.kv_args.kv_item_lens[0]
|
||||||
layers_params = [
|
layers_params = [
|
||||||
(
|
(
|
||||||
@@ -275,12 +277,18 @@ class MooncakeKVManager(CommonKVManager):
|
|||||||
dst_kv_ptrs[layer_id],
|
dst_kv_ptrs[layer_id],
|
||||||
kv_item_len,
|
kv_item_len,
|
||||||
)
|
)
|
||||||
for layer_id in range(layers_current_pp_stage)
|
for layer_id in range(layers_per_pp_stage)
|
||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage = (
|
num_kv_layers = len(self.kv_args.kv_data_ptrs) // 2
|
||||||
self.get_mha_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs)
|
dst_num_total_layers = num_kv_layers * self.pp_size
|
||||||
)
|
src_k_ptrs = self.kv_args.kv_data_ptrs[:num_kv_layers]
|
||||||
|
src_v_ptrs = self.kv_args.kv_data_ptrs[num_kv_layers:]
|
||||||
|
layers_per_pp_stage = len(src_k_ptrs)
|
||||||
|
dst_k_ptrs = dst_kv_ptrs[start_layer:end_layer]
|
||||||
|
dst_v_ptrs = dst_kv_ptrs[
|
||||||
|
dst_num_total_layers + start_layer : dst_num_total_layers + end_layer
|
||||||
|
]
|
||||||
kv_item_len = self.kv_args.kv_item_lens[0]
|
kv_item_len = self.kv_args.kv_item_lens[0]
|
||||||
layers_params = [
|
layers_params = [
|
||||||
(
|
(
|
||||||
@@ -288,14 +296,14 @@ class MooncakeKVManager(CommonKVManager):
|
|||||||
dst_k_ptrs[layer_id],
|
dst_k_ptrs[layer_id],
|
||||||
kv_item_len,
|
kv_item_len,
|
||||||
)
|
)
|
||||||
for layer_id in range(layers_current_pp_stage)
|
for layer_id in range(layers_per_pp_stage)
|
||||||
] + [
|
] + [
|
||||||
(
|
(
|
||||||
src_v_ptrs[layer_id],
|
src_v_ptrs[layer_id],
|
||||||
dst_v_ptrs[layer_id],
|
dst_v_ptrs[layer_id],
|
||||||
kv_item_len,
|
kv_item_len,
|
||||||
)
|
)
|
||||||
for layer_id in range(layers_current_pp_stage)
|
for layer_id in range(layers_per_pp_stage)
|
||||||
]
|
]
|
||||||
assert layers_params is not None
|
assert layers_params is not None
|
||||||
|
|
||||||
@@ -393,9 +401,18 @@ class MooncakeKVManager(CommonKVManager):
|
|||||||
num_heads_to_send = dst_heads_per_rank
|
num_heads_to_send = dst_heads_per_rank
|
||||||
dst_head_start_offset = 0
|
dst_head_start_offset = 0
|
||||||
|
|
||||||
src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage = (
|
# pp is not supported on the decode side yet
|
||||||
self.get_mha_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs)
|
num_kv_layers = len(self.kv_args.kv_data_ptrs) // 2
|
||||||
)
|
dst_num_total_layers = num_kv_layers * self.pp_size
|
||||||
|
src_k_ptrs = self.kv_args.kv_data_ptrs[:num_kv_layers]
|
||||||
|
src_v_ptrs = self.kv_args.kv_data_ptrs[num_kv_layers:]
|
||||||
|
layers_per_pp_stage = len(src_k_ptrs)
|
||||||
|
start_layer = self.pp_rank * layers_per_pp_stage
|
||||||
|
end_layer = start_layer + layers_per_pp_stage
|
||||||
|
dst_k_ptrs = dst_kv_ptrs[start_layer:end_layer]
|
||||||
|
dst_v_ptrs = dst_kv_ptrs[
|
||||||
|
dst_num_total_layers + start_layer : dst_num_total_layers + end_layer
|
||||||
|
]
|
||||||
|
|
||||||
# Calculate precise byte offset and length for the sub-slice within the token
|
# Calculate precise byte offset and length for the sub-slice within the token
|
||||||
src_head_slice_offset = src_head_start_offset * bytes_per_head_slice_to_send
|
src_head_slice_offset = src_head_start_offset * bytes_per_head_slice_to_send
|
||||||
@@ -421,7 +438,7 @@ class MooncakeKVManager(CommonKVManager):
|
|||||||
dst_head_slice_offset,
|
dst_head_slice_offset,
|
||||||
heads_bytes_per_token_to_send,
|
heads_bytes_per_token_to_send,
|
||||||
)
|
)
|
||||||
for layer_id in range(layers_current_pp_stage)
|
for layer_id in range(layers_per_pp_stage)
|
||||||
] + [
|
] + [
|
||||||
(
|
(
|
||||||
src_v_ptrs[layer_id],
|
src_v_ptrs[layer_id],
|
||||||
@@ -432,7 +449,7 @@ class MooncakeKVManager(CommonKVManager):
|
|||||||
dst_head_slice_offset,
|
dst_head_slice_offset,
|
||||||
heads_bytes_per_token_to_send,
|
heads_bytes_per_token_to_send,
|
||||||
)
|
)
|
||||||
for layer_id in range(layers_current_pp_stage)
|
for layer_id in range(layers_per_pp_stage)
|
||||||
]
|
]
|
||||||
|
|
||||||
def process_layer_tp_aware(layer_params):
|
def process_layer_tp_aware(layer_params):
|
||||||
|
|||||||
@@ -421,8 +421,6 @@ class SchedulerDisaggregationPrefillMixin:
|
|||||||
last_hidden_index = (
|
last_hidden_index = (
|
||||||
hidden_state_offset + extend_input_len_per_req[i] - 1
|
hidden_state_offset + extend_input_len_per_req[i] - 1
|
||||||
)
|
)
|
||||||
req.output_topk_p = batch.spec_info.topk_p[i]
|
|
||||||
req.output_topk_index = batch.spec_info.topk_index[i]
|
|
||||||
if self.spec_algorithm.is_eagle3():
|
if self.spec_algorithm.is_eagle3():
|
||||||
req.hidden_states_tensor = (
|
req.hidden_states_tensor = (
|
||||||
batch.spec_info.hidden_states[i].cpu().clone()
|
batch.spec_info.hidden_states[i].cpu().clone()
|
||||||
|
|||||||
@@ -85,7 +85,7 @@ class MetadataBuffers:
|
|||||||
self,
|
self,
|
||||||
size: int,
|
size: int,
|
||||||
hidden_size: int,
|
hidden_size: int,
|
||||||
hidden_states_dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
max_top_logprobs_num: int = 128,
|
max_top_logprobs_num: int = 128,
|
||||||
custom_mem_pool: torch.cuda.MemPool = None,
|
custom_mem_pool: torch.cuda.MemPool = None,
|
||||||
):
|
):
|
||||||
@@ -107,9 +107,7 @@ class MetadataBuffers:
|
|||||||
# We transfer the metadata of first output token to decode
|
# We transfer the metadata of first output token to decode
|
||||||
# The minimal size for RDMA is 64Bytes, so we pad it to > 64Bytes
|
# The minimal size for RDMA is 64Bytes, so we pad it to > 64Bytes
|
||||||
self.output_ids = torch.zeros((size, 16), dtype=torch.int32, device=device)
|
self.output_ids = torch.zeros((size, 16), dtype=torch.int32, device=device)
|
||||||
self.cached_tokens = torch.zeros(
|
|
||||||
(size, 16), dtype=torch.int32, device=device
|
|
||||||
)
|
|
||||||
self.output_token_logprobs_val = torch.zeros(
|
self.output_token_logprobs_val = torch.zeros(
|
||||||
(size, 16), dtype=torch.float32, device=device
|
(size, 16), dtype=torch.float32, device=device
|
||||||
)
|
)
|
||||||
@@ -122,49 +120,33 @@ class MetadataBuffers:
|
|||||||
self.output_top_logprobs_idx = torch.zeros(
|
self.output_top_logprobs_idx = torch.zeros(
|
||||||
(size, max_top_logprobs_num), dtype=torch.int32, device=device
|
(size, max_top_logprobs_num), dtype=torch.int32, device=device
|
||||||
)
|
)
|
||||||
# For PD + spec decode
|
|
||||||
self.output_topk_p = torch.zeros(
|
|
||||||
(size, 16), dtype=torch.float32, device=device
|
|
||||||
)
|
|
||||||
self.output_topk_index = torch.zeros(
|
|
||||||
(size, 16), dtype=torch.int64, device=device
|
|
||||||
)
|
|
||||||
self.output_hidden_states = torch.zeros(
|
self.output_hidden_states = torch.zeros(
|
||||||
(size, hidden_size), dtype=hidden_states_dtype, device=device
|
(size, hidden_size), dtype=dtype, device=device
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_buf_infos(self):
|
def get_buf_infos(self):
|
||||||
ptrs = [
|
ptrs = [
|
||||||
self.output_ids.data_ptr(),
|
self.output_ids.data_ptr(),
|
||||||
self.cached_tokens.data_ptr(),
|
|
||||||
self.output_token_logprobs_val.data_ptr(),
|
self.output_token_logprobs_val.data_ptr(),
|
||||||
self.output_token_logprobs_idx.data_ptr(),
|
self.output_token_logprobs_idx.data_ptr(),
|
||||||
self.output_top_logprobs_val.data_ptr(),
|
self.output_top_logprobs_val.data_ptr(),
|
||||||
self.output_top_logprobs_idx.data_ptr(),
|
self.output_top_logprobs_idx.data_ptr(),
|
||||||
self.output_topk_p.data_ptr(),
|
|
||||||
self.output_topk_index.data_ptr(),
|
|
||||||
self.output_hidden_states.data_ptr(),
|
self.output_hidden_states.data_ptr(),
|
||||||
]
|
]
|
||||||
data_lens = [
|
data_lens = [
|
||||||
self.output_ids.nbytes,
|
self.output_ids.nbytes,
|
||||||
self.cached_tokens.nbytes,
|
|
||||||
self.output_token_logprobs_val.nbytes,
|
self.output_token_logprobs_val.nbytes,
|
||||||
self.output_token_logprobs_idx.nbytes,
|
self.output_token_logprobs_idx.nbytes,
|
||||||
self.output_top_logprobs_val.nbytes,
|
self.output_top_logprobs_val.nbytes,
|
||||||
self.output_top_logprobs_idx.nbytes,
|
self.output_top_logprobs_idx.nbytes,
|
||||||
self.output_topk_p.nbytes,
|
|
||||||
self.output_topk_index.nbytes,
|
|
||||||
self.output_hidden_states.nbytes,
|
self.output_hidden_states.nbytes,
|
||||||
]
|
]
|
||||||
item_lens = [
|
item_lens = [
|
||||||
self.output_ids[0].nbytes,
|
self.output_ids[0].nbytes,
|
||||||
self.cached_tokens[0].nbytes,
|
|
||||||
self.output_token_logprobs_val[0].nbytes,
|
self.output_token_logprobs_val[0].nbytes,
|
||||||
self.output_token_logprobs_idx[0].nbytes,
|
self.output_token_logprobs_idx[0].nbytes,
|
||||||
self.output_top_logprobs_val[0].nbytes,
|
self.output_top_logprobs_val[0].nbytes,
|
||||||
self.output_top_logprobs_idx[0].nbytes,
|
self.output_top_logprobs_idx[0].nbytes,
|
||||||
self.output_topk_p[0].nbytes,
|
|
||||||
self.output_topk_index[0].nbytes,
|
|
||||||
self.output_hidden_states[0].nbytes,
|
self.output_hidden_states[0].nbytes,
|
||||||
]
|
]
|
||||||
return ptrs, data_lens, item_lens
|
return ptrs, data_lens, item_lens
|
||||||
@@ -172,20 +154,16 @@ class MetadataBuffers:
|
|||||||
def get_buf(self, idx: int):
|
def get_buf(self, idx: int):
|
||||||
return (
|
return (
|
||||||
self.output_ids[idx],
|
self.output_ids[idx],
|
||||||
self.cached_tokens[idx],
|
|
||||||
self.output_token_logprobs_val[idx],
|
self.output_token_logprobs_val[idx],
|
||||||
self.output_token_logprobs_idx[idx],
|
self.output_token_logprobs_idx[idx],
|
||||||
self.output_top_logprobs_val[idx],
|
self.output_top_logprobs_val[idx],
|
||||||
self.output_top_logprobs_idx[idx],
|
self.output_top_logprobs_idx[idx],
|
||||||
self.output_topk_p[idx],
|
|
||||||
self.output_topk_index[idx],
|
|
||||||
self.output_hidden_states[idx],
|
self.output_hidden_states[idx],
|
||||||
)
|
)
|
||||||
|
|
||||||
def set_buf(self, req: Req):
|
def set_buf(self, req: Req):
|
||||||
|
|
||||||
self.output_ids[req.metadata_buffer_index][0] = req.output_ids[0]
|
self.output_ids[req.metadata_buffer_index][0] = req.output_ids[0]
|
||||||
self.cached_tokens[req.metadata_buffer_index][0] = req.cached_tokens
|
|
||||||
if req.return_logprob:
|
if req.return_logprob:
|
||||||
if req.output_token_logprobs_val: # not none or empty list
|
if req.output_token_logprobs_val: # not none or empty list
|
||||||
self.output_token_logprobs_val[req.metadata_buffer_index][0] = (
|
self.output_token_logprobs_val[req.metadata_buffer_index][0] = (
|
||||||
@@ -208,17 +186,8 @@ class MetadataBuffers:
|
|||||||
] = torch.tensor(
|
] = torch.tensor(
|
||||||
req.output_top_logprobs_idx[0], dtype=torch.int32, device="cpu"
|
req.output_top_logprobs_idx[0], dtype=torch.int32, device="cpu"
|
||||||
)
|
)
|
||||||
# For PD + spec decode
|
# for PD + spec decode
|
||||||
if req.hidden_states_tensor is not None:
|
if req.hidden_states_tensor is not None:
|
||||||
# speculative_eagle_topk should not be greater than 16 currently
|
|
||||||
topk = req.output_topk_p.size(0)
|
|
||||||
|
|
||||||
self.output_topk_p[req.metadata_buffer_index, :topk].copy_(
|
|
||||||
req.output_topk_p
|
|
||||||
)
|
|
||||||
self.output_topk_index[req.metadata_buffer_index, :topk].copy_(
|
|
||||||
req.output_topk_index
|
|
||||||
)
|
|
||||||
self.output_hidden_states[req.metadata_buffer_index].copy_(
|
self.output_hidden_states[req.metadata_buffer_index].copy_(
|
||||||
req.hidden_states_tensor
|
req.hidden_states_tensor
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -711,7 +711,7 @@ def _set_envs_and_config(server_args: ServerArgs):
|
|||||||
if _is_cuda and not get_bool_env_var("SGLANG_SKIP_SGL_KERNEL_VERSION_CHECK"):
|
if _is_cuda and not get_bool_env_var("SGLANG_SKIP_SGL_KERNEL_VERSION_CHECK"):
|
||||||
assert_pkg_version(
|
assert_pkg_version(
|
||||||
"sgl-kernel",
|
"sgl-kernel",
|
||||||
"0.3.12",
|
"0.3.11",
|
||||||
"Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
|
"Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ Mimics TokenizerManager's state management and ZMQ communication patterns.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import copy
|
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
@@ -12,8 +11,7 @@ import signal
|
|||||||
import sys
|
import sys
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
import uuid
|
from typing import Any, Dict, List, Optional, Union
|
||||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
|
|
||||||
|
|
||||||
import grpc
|
import grpc
|
||||||
import zmq
|
import zmq
|
||||||
@@ -81,10 +79,11 @@ class GrpcReqState:
|
|||||||
last_completion_tokens: int = 1
|
last_completion_tokens: int = 1
|
||||||
|
|
||||||
# Streaming state
|
# Streaming state
|
||||||
|
last_output_offset: int = 0
|
||||||
stream_finished: bool = False
|
stream_finished: bool = False
|
||||||
input_logprobs_sent: bool = False # Track if input logprobs were sent in streaming
|
|
||||||
|
|
||||||
# Token accumulation (for non-streaming)
|
# Output accumulation
|
||||||
|
text: str = ""
|
||||||
output_ids: List[int] = dataclasses.field(default_factory=list)
|
output_ids: List[int] = dataclasses.field(default_factory=list)
|
||||||
input_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)
|
input_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)
|
||||||
input_token_logprobs_idx: List[int] = dataclasses.field(default_factory=list)
|
input_token_logprobs_idx: List[int] = dataclasses.field(default_factory=list)
|
||||||
@@ -140,6 +139,8 @@ class GrpcRequestManager:
|
|||||||
self.is_pause_cond = asyncio.Condition()
|
self.is_pause_cond = asyncio.Condition()
|
||||||
|
|
||||||
# Metrics
|
# Metrics
|
||||||
|
self.request_counter = 0
|
||||||
|
self.request_counter_lock = asyncio.Lock()
|
||||||
self.last_receive_tstamp = time.time()
|
self.last_receive_tstamp = time.time()
|
||||||
|
|
||||||
# Crash dump for debugging
|
# Crash dump for debugging
|
||||||
@@ -157,133 +158,22 @@ class GrpcRequestManager:
|
|||||||
obj: TokenizedGenerateReqInput,
|
obj: TokenizedGenerateReqInput,
|
||||||
request_id: Optional[str] = None,
|
request_id: Optional[str] = None,
|
||||||
grpc_context: Optional[grpc.aio.ServicerContext] = None,
|
grpc_context: Optional[grpc.aio.ServicerContext] = None,
|
||||||
) -> AsyncGenerator[Union[Dict, List[Dict]], None]:
|
) -> asyncio.Queue:
|
||||||
"""
|
"""
|
||||||
Submit a generation request to the scheduler with n>1 parallel sampling support.
|
Submit a generation request to the scheduler.
|
||||||
|
Returns a queue for streaming outputs.
|
||||||
This method implements the same two-phase approach as tokenizer_manager.py:
|
|
||||||
1. Phase 1: Send prefix caching request (max_new_tokens=0)
|
|
||||||
2. Phase 2: Send n generation requests that reuse the cached prefix
|
|
||||||
|
|
||||||
Yields individual responses for streaming, or aggregated responses for non-streaming.
|
|
||||||
"""
|
"""
|
||||||
n = getattr(obj.sampling_params, "n", 1)
|
|
||||||
|
|
||||||
if n <= 1:
|
|
||||||
async for response in self._handle_single_request(
|
|
||||||
obj, request_id, grpc_context
|
|
||||||
):
|
|
||||||
yield response
|
|
||||||
return
|
|
||||||
|
|
||||||
# N>1 handling - two-phase approach
|
|
||||||
logger.debug(f"Multiple sampling request (n={n}), using two-phase approach")
|
|
||||||
|
|
||||||
# Generate base request ID if not provided
|
|
||||||
if request_id is None:
|
|
||||||
base_request_id = f"grpc-{uuid.uuid4().hex}"
|
|
||||||
else:
|
|
||||||
base_request_id = request_id
|
|
||||||
|
|
||||||
# Phase 1: Cache the common prefix
|
|
||||||
logger.debug(f"Phase 1: Caching prefix for request {base_request_id}")
|
|
||||||
prefix_obj = copy.copy(obj)
|
|
||||||
prefix_obj.sampling_params = copy.copy(obj.sampling_params)
|
|
||||||
prefix_obj.sampling_params.max_new_tokens = 0 # Prefill-only
|
|
||||||
prefix_obj.sampling_params.n = 1 # Don't replicate prefix request
|
|
||||||
|
|
||||||
# Send prefix caching request and consume response
|
|
||||||
async for _ in self._handle_single_request(
|
|
||||||
prefix_obj, f"{base_request_id}-prefix", grpc_context
|
|
||||||
):
|
|
||||||
# Consume prefix response (usually just one chunk with finish_reason)
|
|
||||||
pass
|
|
||||||
|
|
||||||
logger.debug(f"Phase 1 completed: Prefix cached for {base_request_id}")
|
|
||||||
|
|
||||||
# Phase 2: Generate n parallel requests
|
|
||||||
logger.debug(f"Phase 2: Generating {n} parallel requests")
|
|
||||||
generators = []
|
|
||||||
request_ids = []
|
|
||||||
|
|
||||||
for i in range(n):
|
|
||||||
# Create individual generation request
|
|
||||||
gen_obj = copy.copy(obj)
|
|
||||||
gen_obj.sampling_params = copy.copy(obj.sampling_params)
|
|
||||||
gen_obj.sampling_params.n = 1 # Each request generates 1 response
|
|
||||||
|
|
||||||
gen_request_id = f"{base_request_id}-{i}"
|
|
||||||
request_ids.append(gen_request_id)
|
|
||||||
|
|
||||||
# Start generation request
|
|
||||||
generators.append(
|
|
||||||
self._handle_single_request(gen_obj, gen_request_id, grpc_context)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Handle response aggregation
|
|
||||||
is_stream = getattr(obj, "stream", False)
|
|
||||||
|
|
||||||
if not is_stream:
|
|
||||||
# Non-streaming: collect all responses and return as batch
|
|
||||||
logger.debug(f"Non-streaming mode: collecting {n} responses")
|
|
||||||
responses = []
|
|
||||||
for generator in generators:
|
|
||||||
async for response in generator:
|
|
||||||
responses.append(response)
|
|
||||||
yield responses # Return all responses as a batch
|
|
||||||
else:
|
|
||||||
# Streaming mode: multiplex responses with index for ordering
|
|
||||||
logger.debug(f"Streaming mode: multiplexing {n} streams")
|
|
||||||
rid_to_index = {rid: i for i, rid in enumerate(request_ids)}
|
|
||||||
|
|
||||||
# Create async tasks for all generators
|
|
||||||
task_map = {}
|
|
||||||
for generator in generators:
|
|
||||||
task = asyncio.create_task(generator.__anext__())
|
|
||||||
task_map[task] = generator
|
|
||||||
|
|
||||||
# Process responses as they arrive
|
|
||||||
while task_map:
|
|
||||||
done, _ = await asyncio.wait(
|
|
||||||
task_map.keys(), return_when=asyncio.FIRST_COMPLETED
|
|
||||||
)
|
|
||||||
|
|
||||||
for task in done:
|
|
||||||
generator = task_map.pop(task)
|
|
||||||
try:
|
|
||||||
response = await task
|
|
||||||
|
|
||||||
# Add index for client-side ordering
|
|
||||||
if isinstance(response, dict) and "meta_info" in response:
|
|
||||||
response_rid = response["meta_info"].get("id", "")
|
|
||||||
if response_rid in rid_to_index:
|
|
||||||
response["index"] = rid_to_index[response_rid]
|
|
||||||
|
|
||||||
yield response
|
|
||||||
|
|
||||||
# Create next task for this generator
|
|
||||||
next_task = asyncio.create_task(generator.__anext__())
|
|
||||||
task_map[next_task] = generator
|
|
||||||
|
|
||||||
except StopAsyncIteration:
|
|
||||||
# This generator is finished
|
|
||||||
pass
|
|
||||||
|
|
||||||
async def _handle_single_request(
|
|
||||||
self,
|
|
||||||
obj: TokenizedGenerateReqInput,
|
|
||||||
request_id: Optional[str] = None,
|
|
||||||
grpc_context: Optional[grpc.aio.ServicerContext] = None,
|
|
||||||
):
|
|
||||||
"""Handle a single request - core implementation without n>1 logic."""
|
|
||||||
# Generate request ID if not provided
|
# Generate request ID if not provided
|
||||||
if request_id is None:
|
if request_id is None:
|
||||||
request_id = f"grpc-{uuid.uuid4().hex}"
|
async with self.request_counter_lock:
|
||||||
|
request_id = f"grpc-{self.request_counter}"
|
||||||
|
self.request_counter += 1
|
||||||
|
|
||||||
obj.rid = request_id
|
obj.rid = request_id
|
||||||
|
|
||||||
# Create and register request state
|
|
||||||
# TODO: support log_request
|
# TODO: support log_request
|
||||||
|
|
||||||
|
# Create request state
|
||||||
state = GrpcReqState(
|
state = GrpcReqState(
|
||||||
request_id=request_id,
|
request_id=request_id,
|
||||||
grpc_context=grpc_context,
|
grpc_context=grpc_context,
|
||||||
@@ -299,51 +189,19 @@ class GrpcRequestManager:
|
|||||||
state.session_id = obj.session_params.session_id
|
state.session_id = obj.session_params.session_id
|
||||||
state.is_session_request = True
|
state.is_session_request = True
|
||||||
|
|
||||||
|
# Register state
|
||||||
self.rid_to_state[request_id] = state
|
self.rid_to_state[request_id] = state
|
||||||
self.record_request_for_crash_dump(obj)
|
self.record_request_for_crash_dump(obj)
|
||||||
|
|
||||||
|
# Send to scheduler via ZMQ
|
||||||
try:
|
try:
|
||||||
# Send to scheduler - let exceptions bubble up to grpc_server.py
|
|
||||||
await self._send_to_scheduler(obj)
|
await self._send_to_scheduler(obj)
|
||||||
|
except Exception as e:
|
||||||
is_stream = getattr(obj, "stream", False)
|
# Clean up on failure
|
||||||
|
|
||||||
while True:
|
|
||||||
# Client cancelled - notify scheduler and exit
|
|
||||||
if grpc_context and grpc_context.cancelled():
|
|
||||||
await self.abort_request(request_id)
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
response = await asyncio.wait_for(state.out_queue.get(), timeout=4)
|
|
||||||
|
|
||||||
if is_stream:
|
|
||||||
yield response
|
|
||||||
|
|
||||||
# Non-streaming: yield final response with accumulated tokens from state
|
|
||||||
if isinstance(response, dict) and response.get("finished", False):
|
|
||||||
if not is_stream:
|
|
||||||
final_response = response.copy()
|
|
||||||
final_response["token_ids"] = state.output_ids
|
|
||||||
yield final_response
|
|
||||||
break
|
|
||||||
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
# Timeout waiting for response - abort and cleanup
|
|
||||||
logger.warning(
|
|
||||||
f"Timeout waiting for response for request {request_id}"
|
|
||||||
)
|
|
||||||
await self.abort_request(request_id)
|
|
||||||
return
|
|
||||||
|
|
||||||
finally:
|
|
||||||
# Always clean up request state when exiting
|
|
||||||
self._cleanup_request_state(request_id)
|
|
||||||
|
|
||||||
def _cleanup_request_state(self, request_id: str):
|
|
||||||
"""Clean up local request state (does not notify scheduler)."""
|
|
||||||
if request_id in self.rid_to_state:
|
|
||||||
del self.rid_to_state[request_id]
|
del self.rid_to_state[request_id]
|
||||||
|
raise RuntimeError(f"Failed to send request to scheduler: {e}")
|
||||||
|
|
||||||
|
return state.out_queue
|
||||||
|
|
||||||
async def embedding_request(
|
async def embedding_request(
|
||||||
self,
|
self,
|
||||||
@@ -356,7 +214,9 @@ class GrpcRequestManager:
|
|||||||
"""
|
"""
|
||||||
# Generate request ID if not provided
|
# Generate request ID if not provided
|
||||||
if request_id is None:
|
if request_id is None:
|
||||||
request_id = f"grpc-embed-{uuid.uuid4().hex}"
|
async with self.request_counter_lock:
|
||||||
|
request_id = f"grpc-embed-{self.request_counter}"
|
||||||
|
self.request_counter += 1
|
||||||
|
|
||||||
obj.rid = request_id
|
obj.rid = request_id
|
||||||
|
|
||||||
@@ -495,6 +355,7 @@ class GrpcRequestManager:
|
|||||||
# Extract output for this request
|
# Extract output for this request
|
||||||
output_data = {
|
output_data = {
|
||||||
"request_id": rid,
|
"request_id": rid,
|
||||||
|
"text": batch_out.decoded_texts[i] if batch_out.decoded_texts else "",
|
||||||
"token_ids": batch_out.output_ids[i] if batch_out.output_ids else [],
|
"token_ids": batch_out.output_ids[i] if batch_out.output_ids else [],
|
||||||
"finished": batch_out.finished_reasons[i] is not None,
|
"finished": batch_out.finished_reasons[i] is not None,
|
||||||
"meta_info": {
|
"meta_info": {
|
||||||
@@ -506,9 +367,6 @@ class GrpcRequestManager:
|
|||||||
if batch_out.completion_tokens
|
if batch_out.completion_tokens
|
||||||
else 0
|
else 0
|
||||||
),
|
),
|
||||||
"cached_tokens": (
|
|
||||||
batch_out.cached_tokens[i] if batch_out.cached_tokens else 0
|
|
||||||
),
|
|
||||||
"finish_reason": (
|
"finish_reason": (
|
||||||
str(batch_out.finished_reasons[i])
|
str(batch_out.finished_reasons[i])
|
||||||
if batch_out.finished_reasons[i]
|
if batch_out.finished_reasons[i]
|
||||||
@@ -517,110 +375,29 @@ class GrpcRequestManager:
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
# Accumulate input logprobs (only once, usually in first chunk)
|
# Add logprobs if available
|
||||||
if batch_out.input_token_logprobs_val and i < len(
|
|
||||||
batch_out.input_token_logprobs_val
|
|
||||||
):
|
|
||||||
if not state.input_token_logprobs_val:
|
|
||||||
state.input_token_logprobs_val.extend(
|
|
||||||
batch_out.input_token_logprobs_val[i]
|
|
||||||
)
|
|
||||||
if batch_out.input_token_logprobs_idx and i < len(
|
|
||||||
batch_out.input_token_logprobs_idx
|
|
||||||
):
|
|
||||||
state.input_token_logprobs_idx.extend(
|
|
||||||
batch_out.input_token_logprobs_idx[i]
|
|
||||||
)
|
|
||||||
if batch_out.input_top_logprobs_val and i < len(
|
|
||||||
batch_out.input_top_logprobs_val
|
|
||||||
):
|
|
||||||
state.input_top_logprobs_val.extend(
|
|
||||||
batch_out.input_top_logprobs_val[i]
|
|
||||||
)
|
|
||||||
if batch_out.input_top_logprobs_idx and i < len(
|
|
||||||
batch_out.input_top_logprobs_idx
|
|
||||||
):
|
|
||||||
state.input_top_logprobs_idx.extend(
|
|
||||||
batch_out.input_top_logprobs_idx[i]
|
|
||||||
)
|
|
||||||
|
|
||||||
# Send input logprobs based on mode
|
|
||||||
if state.input_token_logprobs_val:
|
|
||||||
if state.obj.stream and not state.input_logprobs_sent:
|
|
||||||
# Streaming: send input logprobs once in first chunk that has them
|
|
||||||
output_data["input_logprobs"] = {
|
|
||||||
"token_logprobs_val": state.input_token_logprobs_val,
|
|
||||||
"token_logprobs_idx": state.input_token_logprobs_idx,
|
|
||||||
"top_logprobs_val": state.input_top_logprobs_val,
|
|
||||||
"top_logprobs_idx": state.input_top_logprobs_idx,
|
|
||||||
}
|
|
||||||
state.input_logprobs_sent = True
|
|
||||||
elif not state.obj.stream and output_data["finished"]:
|
|
||||||
# Non-streaming: send input logprobs in final chunk
|
|
||||||
output_data["input_logprobs"] = {
|
|
||||||
"token_logprobs_val": state.input_token_logprobs_val,
|
|
||||||
"token_logprobs_idx": state.input_token_logprobs_idx,
|
|
||||||
"top_logprobs_val": state.input_top_logprobs_val,
|
|
||||||
"top_logprobs_idx": state.input_top_logprobs_idx,
|
|
||||||
}
|
|
||||||
|
|
||||||
# Add output logprobs if available (RAW - no detokenization!)
|
|
||||||
if batch_out.output_token_logprobs_val and i < len(
|
if batch_out.output_token_logprobs_val and i < len(
|
||||||
batch_out.output_token_logprobs_val
|
batch_out.output_token_logprobs_val
|
||||||
):
|
):
|
||||||
# Accumulate in state first
|
output_data["logprobs"] = {
|
||||||
state.output_token_logprobs_val.extend(
|
"tokens": batch_out.output_token_logprobs_val[i],
|
||||||
batch_out.output_token_logprobs_val[i]
|
"top_logprobs": (
|
||||||
)
|
|
||||||
if batch_out.output_token_logprobs_idx and i < len(
|
|
||||||
batch_out.output_token_logprobs_idx
|
|
||||||
):
|
|
||||||
state.output_token_logprobs_idx.extend(
|
|
||||||
batch_out.output_token_logprobs_idx[i]
|
|
||||||
)
|
|
||||||
if batch_out.output_top_logprobs_val and i < len(
|
|
||||||
batch_out.output_top_logprobs_val
|
|
||||||
):
|
|
||||||
state.output_top_logprobs_val.extend(
|
|
||||||
batch_out.output_top_logprobs_val[i]
|
batch_out.output_top_logprobs_val[i]
|
||||||
)
|
if batch_out.output_top_logprobs_val
|
||||||
if batch_out.output_top_logprobs_idx and i < len(
|
and i < len(batch_out.output_top_logprobs_val)
|
||||||
batch_out.output_top_logprobs_idx
|
else None
|
||||||
):
|
),
|
||||||
state.output_top_logprobs_idx.extend(
|
}
|
||||||
batch_out.output_top_logprobs_idx[i]
|
|
||||||
)
|
|
||||||
|
|
||||||
if state.obj.stream:
|
# Update state
|
||||||
# For streaming: send incremental logprobs (only new tokens in this chunk)
|
if output_data["text"]:
|
||||||
# NOTE: this is different than TokenizerManager, which always accumulates
|
state.text += output_data["text"][state.last_output_offset :]
|
||||||
def get_part(attr_name):
|
state.last_output_offset = len(output_data["text"])
|
||||||
source_list = getattr(batch_out, attr_name, None)
|
|
||||||
return (
|
|
||||||
source_list[i]
|
|
||||||
if source_list and i < len(source_list)
|
|
||||||
else []
|
|
||||||
)
|
|
||||||
|
|
||||||
output_data["output_logprobs"] = {
|
|
||||||
"token_logprobs_val": batch_out.output_token_logprobs_val[i],
|
|
||||||
"token_logprobs_idx": get_part("output_token_logprobs_idx"),
|
|
||||||
"top_logprobs_val": get_part("output_top_logprobs_val"),
|
|
||||||
"top_logprobs_idx": get_part("output_top_logprobs_idx"),
|
|
||||||
}
|
|
||||||
elif output_data["finished"]:
|
|
||||||
# Non-streaming: send cumulative output logprobs in final chunk
|
|
||||||
output_data["output_logprobs"] = {
|
|
||||||
"token_logprobs_val": state.output_token_logprobs_val,
|
|
||||||
"token_logprobs_idx": state.output_token_logprobs_idx,
|
|
||||||
"top_logprobs_val": state.output_top_logprobs_val,
|
|
||||||
"top_logprobs_idx": state.output_top_logprobs_idx,
|
|
||||||
}
|
|
||||||
|
|
||||||
# Update state for accumulation
|
|
||||||
if output_data["token_ids"]:
|
if output_data["token_ids"]:
|
||||||
state.output_ids.extend(output_data["token_ids"])
|
state.output_ids.extend(output_data["token_ids"])
|
||||||
|
|
||||||
|
# Send to output queue
|
||||||
await state.out_queue.put(output_data)
|
await state.out_queue.put(output_data)
|
||||||
|
|
||||||
# Handle completion
|
# Handle completion
|
||||||
|
|||||||
@@ -181,34 +181,20 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
|
|||||||
# Convert gRPC request to internal format
|
# Convert gRPC request to internal format
|
||||||
tokenized_req = self._convert_generate_request(request)
|
tokenized_req = self._convert_generate_request(request)
|
||||||
|
|
||||||
# Submit to request manager (automatically handles n>1)
|
# Submit to request manager
|
||||||
response_generator = self.request_manager.generate_request(
|
output_queue = await self.request_manager.generate_request(
|
||||||
obj=tokenized_req,
|
obj=tokenized_req,
|
||||||
request_id=request.request_id,
|
request_id=request.request_id,
|
||||||
grpc_context=context,
|
grpc_context=context,
|
||||||
)
|
)
|
||||||
|
|
||||||
async for output in response_generator:
|
# Stream outputs
|
||||||
# Handle batch responses (for n>1 non-streaming)
|
while True:
|
||||||
if isinstance(output, list):
|
try:
|
||||||
for batch_output in output:
|
# Get output with timeout
|
||||||
if "error" in batch_output:
|
output = await asyncio.wait_for(output_queue.get(), timeout=4)
|
||||||
yield sglang_scheduler_pb2.GenerateResponse(
|
|
||||||
request_id=request.request_id,
|
# Check for errors
|
||||||
error=sglang_scheduler_pb2.GenerateError(
|
|
||||||
message=batch_output["error"],
|
|
||||||
http_status_code=(
|
|
||||||
"500" if "abort" not in batch_output else "499"
|
|
||||||
),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# All non-error batch outputs are final responses
|
|
||||||
yield self._create_completion_response(
|
|
||||||
request.request_id, batch_output
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# Handle single response (for streaming or n=1 non-streaming)
|
|
||||||
if "error" in output:
|
if "error" in output:
|
||||||
yield sglang_scheduler_pb2.GenerateResponse(
|
yield sglang_scheduler_pb2.GenerateResponse(
|
||||||
request_id=request.request_id,
|
request_id=request.request_id,
|
||||||
@@ -219,13 +205,27 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
|
|||||||
),
|
),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
elif output.get("finished", False):
|
break
|
||||||
|
|
||||||
|
# Check if finished
|
||||||
|
if output.get("finished", False):
|
||||||
|
# Send completion
|
||||||
yield self._create_completion_response(
|
yield self._create_completion_response(
|
||||||
request.request_id, output
|
request.request_id, output
|
||||||
)
|
)
|
||||||
|
break
|
||||||
else:
|
else:
|
||||||
|
# Send chunk
|
||||||
yield self._create_chunk_response(request.request_id, output)
|
yield self._create_chunk_response(request.request_id, output)
|
||||||
|
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
# Check if context is still active
|
||||||
|
if context.cancelled():
|
||||||
|
# Abort the request
|
||||||
|
await self.request_manager.abort_request(request.request_id)
|
||||||
|
break
|
||||||
|
continue
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Generate failed: {e}\n{get_exception_traceback()}")
|
logger.error(f"Generate failed: {e}\n{get_exception_traceback()}")
|
||||||
yield sglang_scheduler_pb2.GenerateResponse(
|
yield sglang_scheduler_pb2.GenerateResponse(
|
||||||
@@ -266,6 +266,7 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
|
|||||||
prompt_tokens=result.get("prompt_tokens", 0),
|
prompt_tokens=result.get("prompt_tokens", 0),
|
||||||
cached_tokens=0,
|
cached_tokens=0,
|
||||||
embedding_dim=len(result["embedding"]),
|
embedding_dim=len(result["embedding"]),
|
||||||
|
generation_time=time.time() - self.start_time,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -321,14 +322,14 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
|
|||||||
logger.info(f"Sending health check request to request manager...")
|
logger.info(f"Sending health check request to request manager...")
|
||||||
|
|
||||||
# Submit and wait for response
|
# Submit and wait for response
|
||||||
output_generator = self.request_manager.generate_request(
|
output_queue = await self.request_manager.generate_request(
|
||||||
health_request, request_id=rid
|
health_request, request_id=rid
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Get first response with timeout
|
# Wait for response with configurable timeout
|
||||||
response = await asyncio.wait_for(
|
response = await asyncio.wait_for(
|
||||||
output_generator.__anext__(), timeout=HEALTH_CHECK_TIMEOUT
|
output_queue.get(), timeout=HEALTH_CHECK_TIMEOUT
|
||||||
)
|
)
|
||||||
|
|
||||||
# Clean up
|
# Clean up
|
||||||
@@ -403,8 +404,8 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
|
|||||||
return_logprob=grpc_req.return_logprob,
|
return_logprob=grpc_req.return_logprob,
|
||||||
logprob_start_len=grpc_req.logprob_start_len or -1,
|
logprob_start_len=grpc_req.logprob_start_len or -1,
|
||||||
top_logprobs_num=grpc_req.top_logprobs_num or 0,
|
top_logprobs_num=grpc_req.top_logprobs_num or 0,
|
||||||
stream=grpc_req.stream or False,
|
stream=True, # Always stream for gRPC
|
||||||
lora_id=grpc_req.lora_id if grpc_req.lora_id else None,
|
lora_path=grpc_req.lora_id if grpc_req.lora_id else None,
|
||||||
token_ids_logprob=(
|
token_ids_logprob=(
|
||||||
list(grpc_req.token_ids_logprob) if grpc_req.token_ids_logprob else None
|
list(grpc_req.token_ids_logprob) if grpc_req.token_ids_logprob else None
|
||||||
),
|
),
|
||||||
@@ -437,7 +438,6 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
|
|||||||
regex = None
|
regex = None
|
||||||
json_schema = None
|
json_schema = None
|
||||||
ebnf_grammar = None
|
ebnf_grammar = None
|
||||||
structural_tag = None
|
|
||||||
|
|
||||||
if grpc_params.HasField("regex"):
|
if grpc_params.HasField("regex"):
|
||||||
regex = grpc_params.regex
|
regex = grpc_params.regex
|
||||||
@@ -445,8 +445,6 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
|
|||||||
json_schema = grpc_params.json_schema
|
json_schema = grpc_params.json_schema
|
||||||
elif grpc_params.HasField("ebnf_grammar"):
|
elif grpc_params.HasField("ebnf_grammar"):
|
||||||
ebnf_grammar = grpc_params.ebnf_grammar
|
ebnf_grammar = grpc_params.ebnf_grammar
|
||||||
elif grpc_params.HasField("structural_tag"):
|
|
||||||
structural_tag = grpc_params.structural_tag
|
|
||||||
|
|
||||||
return SGLSamplingParams(
|
return SGLSamplingParams(
|
||||||
temperature=grpc_params.temperature or 1.0,
|
temperature=grpc_params.temperature or 1.0,
|
||||||
@@ -458,74 +456,33 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
|
|||||||
repetition_penalty=grpc_params.repetition_penalty or 1.0,
|
repetition_penalty=grpc_params.repetition_penalty or 1.0,
|
||||||
max_new_tokens=grpc_params.max_new_tokens or 128,
|
max_new_tokens=grpc_params.max_new_tokens or 128,
|
||||||
min_new_tokens=grpc_params.min_new_tokens or 0,
|
min_new_tokens=grpc_params.min_new_tokens or 0,
|
||||||
stop=list(grpc_params.stop) if grpc_params.stop else [],
|
stop=list(grpc_params.stop) if grpc_params.stop else None,
|
||||||
stop_token_ids=(
|
stop_token_ids=(
|
||||||
list(grpc_params.stop_token_ids) if grpc_params.stop_token_ids else []
|
list(grpc_params.stop_token_ids) if grpc_params.stop_token_ids else None
|
||||||
),
|
),
|
||||||
skip_special_tokens=grpc_params.skip_special_tokens,
|
skip_special_tokens=grpc_params.skip_special_tokens,
|
||||||
spaces_between_special_tokens=grpc_params.spaces_between_special_tokens,
|
spaces_between_special_tokens=grpc_params.spaces_between_special_tokens,
|
||||||
regex=regex,
|
regex=regex,
|
||||||
json_schema=json_schema,
|
json_schema=json_schema,
|
||||||
ebnf=ebnf_grammar,
|
ebnf=ebnf_grammar,
|
||||||
structural_tag=structural_tag,
|
|
||||||
n=grpc_params.n or 1,
|
n=grpc_params.n or 1,
|
||||||
ignore_eos=grpc_params.ignore_eos,
|
ignore_eos=grpc_params.ignore_eos,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _convert_logprobs_to_proto(
|
|
||||||
self, logprobs_data: Dict
|
|
||||||
) -> Optional[sglang_scheduler_pb2.LogProbs]:
|
|
||||||
"""Convert logprobs dict to proto LogProbs format (transport RAW data only)."""
|
|
||||||
if not logprobs_data:
|
|
||||||
return None
|
|
||||||
|
|
||||||
token_logprobs_val = logprobs_data.get("token_logprobs_val", [])
|
|
||||||
token_logprobs_idx = logprobs_data.get("token_logprobs_idx", [])
|
|
||||||
top_logprobs_val = logprobs_data.get("top_logprobs_val", [])
|
|
||||||
top_logprobs_idx = logprobs_data.get("top_logprobs_idx", [])
|
|
||||||
|
|
||||||
# Build TopLogProbs entries
|
|
||||||
top_logprobs_proto = []
|
|
||||||
if top_logprobs_val and top_logprobs_idx:
|
|
||||||
for val_list, idx_list in zip(top_logprobs_val, top_logprobs_idx):
|
|
||||||
top_logprobs_proto.append(
|
|
||||||
sglang_scheduler_pb2.TopLogProbs(
|
|
||||||
values=val_list,
|
|
||||||
token_ids=idx_list,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return sglang_scheduler_pb2.LogProbs(
|
|
||||||
token_logprobs=token_logprobs_val,
|
|
||||||
token_ids=token_logprobs_idx,
|
|
||||||
top_logprobs=top_logprobs_proto,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _create_chunk_response(
|
def _create_chunk_response(
|
||||||
self, request_id: str, output: Dict
|
self, request_id: str, output: Dict
|
||||||
) -> sglang_scheduler_pb2.GenerateResponse:
|
) -> sglang_scheduler_pb2.GenerateResponse:
|
||||||
"""Create a streaming chunk response."""
|
"""Create a streaming chunk response."""
|
||||||
meta_info = output.get("meta_info", {})
|
|
||||||
|
|
||||||
# Convert output logprobs if present
|
|
||||||
output_logprobs_proto = self._convert_logprobs_to_proto(
|
|
||||||
output.get("output_logprobs")
|
|
||||||
)
|
|
||||||
|
|
||||||
# Convert input logprobs if present (only in first chunk)
|
|
||||||
input_logprobs_proto = self._convert_logprobs_to_proto(
|
|
||||||
output.get("input_logprobs")
|
|
||||||
)
|
|
||||||
|
|
||||||
return sglang_scheduler_pb2.GenerateResponse(
|
return sglang_scheduler_pb2.GenerateResponse(
|
||||||
request_id=request_id,
|
request_id=request_id,
|
||||||
chunk=sglang_scheduler_pb2.GenerateStreamChunk(
|
chunk=sglang_scheduler_pb2.GenerateStreamChunk(
|
||||||
token_ids=output.get("token_ids", []),
|
token_id=output["token_ids"][-1] if output.get("token_ids") else 0,
|
||||||
prompt_tokens=meta_info.get("prompt_tokens", 0),
|
text=output.get("text", ""),
|
||||||
completion_tokens=meta_info.get("completion_tokens", 0),
|
prompt_tokens=0,
|
||||||
cached_tokens=meta_info.get("cached_tokens", 0),
|
completion_tokens=len(output.get("token_ids", [])),
|
||||||
output_logprobs=output_logprobs_proto,
|
cached_tokens=0,
|
||||||
input_logprobs=input_logprobs_proto,
|
generation_time=time.time() - self.start_time,
|
||||||
|
queue_time=0.0,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -534,56 +491,20 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
|
|||||||
) -> sglang_scheduler_pb2.GenerateResponse:
|
) -> sglang_scheduler_pb2.GenerateResponse:
|
||||||
"""Create a completion response."""
|
"""Create a completion response."""
|
||||||
|
|
||||||
# Extract meta info and finish reason details
|
# Determine finish reason
|
||||||
|
finish_reason = sglang_scheduler_pb2.GenerateComplete.STOP
|
||||||
meta_info = output.get("meta_info", {})
|
meta_info = output.get("meta_info", {})
|
||||||
finish_reason_data = meta_info.get("finish_reason")
|
if meta_info.get("finish_reason") == "length":
|
||||||
|
finish_reason = sglang_scheduler_pb2.GenerateComplete.LENGTH
|
||||||
# Determine finish reason, default is stop
|
elif meta_info.get("finish_reason") == "eos_token":
|
||||||
finish_reason = "stop"
|
finish_reason = sglang_scheduler_pb2.GenerateComplete.EOS_TOKEN
|
||||||
if finish_reason_data:
|
|
||||||
if isinstance(finish_reason_data, dict):
|
|
||||||
finish_reason_type = finish_reason_data.get("type")
|
|
||||||
else:
|
|
||||||
# Handle legacy string format
|
|
||||||
finish_reason_type = finish_reason_data
|
|
||||||
|
|
||||||
if finish_reason_type == "length":
|
|
||||||
finish_reason = "length"
|
|
||||||
elif finish_reason_type == "abort":
|
|
||||||
finish_reason = "abort"
|
|
||||||
|
|
||||||
# Extract matched_stop information
|
|
||||||
matched_stop_kwargs = {}
|
|
||||||
if isinstance(finish_reason_data, dict) and "matched" in finish_reason_data:
|
|
||||||
matched = finish_reason_data["matched"]
|
|
||||||
if isinstance(matched, int):
|
|
||||||
matched_stop_kwargs["matched_token_id"] = matched
|
|
||||||
elif isinstance(matched, str):
|
|
||||||
matched_stop_kwargs["matched_stop_str"] = matched
|
|
||||||
|
|
||||||
# Convert output logprobs if present
|
|
||||||
output_logprobs_proto = self._convert_logprobs_to_proto(
|
|
||||||
output.get("output_logprobs")
|
|
||||||
)
|
|
||||||
|
|
||||||
# Convert input logprobs if present
|
|
||||||
input_logprobs_proto = self._convert_logprobs_to_proto(
|
|
||||||
output.get("input_logprobs")
|
|
||||||
)
|
|
||||||
|
|
||||||
return sglang_scheduler_pb2.GenerateResponse(
|
return sglang_scheduler_pb2.GenerateResponse(
|
||||||
request_id=request_id,
|
request_id=request_id,
|
||||||
complete=sglang_scheduler_pb2.GenerateComplete(
|
complete=sglang_scheduler_pb2.GenerateComplete(
|
||||||
output_ids=output.get("token_ids", []),
|
output_ids=output.get("token_ids", []),
|
||||||
|
output_text=output.get("text", ""),
|
||||||
finish_reason=finish_reason,
|
finish_reason=finish_reason,
|
||||||
prompt_tokens=meta_info.get("prompt_tokens", 0),
|
|
||||||
completion_tokens=meta_info.get(
|
|
||||||
"completion_tokens", len(output.get("token_ids", []))
|
|
||||||
),
|
|
||||||
cached_tokens=meta_info.get("cached_tokens", 0),
|
|
||||||
output_logprobs=output_logprobs_proto,
|
|
||||||
input_logprobs=input_logprobs_proto,
|
|
||||||
**matched_stop_kwargs,
|
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -16,7 +16,7 @@
|
|||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Dict, List, NamedTuple, Optional, TypeAlias, Union
|
from typing import Any, Dict, List, Optional, TypeAlias, Union
|
||||||
|
|
||||||
from openai.types.responses import (
|
from openai.types.responses import (
|
||||||
ResponseFunctionToolCall,
|
ResponseFunctionToolCall,
|
||||||
@@ -228,15 +228,11 @@ class CompletionRequest(BaseModel):
|
|||||||
|
|
||||||
# For request id
|
# For request id
|
||||||
rid: Optional[Union[List[str], str]] = None
|
rid: Optional[Union[List[str], str]] = None
|
||||||
# Extra key for classifying the request (e.g. cache_salt)
|
|
||||||
extra_key: Optional[Union[List[str], str]] = None
|
|
||||||
# Cache salt for request caching
|
|
||||||
cache_salt: Optional[Union[List[str], str]] = None
|
|
||||||
# Priority for the request
|
# Priority for the request
|
||||||
priority: Optional[int] = None
|
priority: Optional[int] = None
|
||||||
|
|
||||||
# For custom metric labels
|
# For customer metric labels
|
||||||
custom_labels: Optional[Dict[str, str]] = None
|
customer_labels: Optional[Dict[str, str]] = None
|
||||||
|
|
||||||
@field_validator("max_tokens")
|
@field_validator("max_tokens")
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -343,7 +339,7 @@ class FunctionResponse(BaseModel):
|
|||||||
"""Function response."""
|
"""Function response."""
|
||||||
|
|
||||||
name: Optional[str] = None
|
name: Optional[str] = None
|
||||||
arguments: Optional[str | Dict[str, Any]] = None
|
arguments: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
class ToolCall(BaseModel):
|
class ToolCall(BaseModel):
|
||||||
@@ -392,7 +388,7 @@ class Function(BaseModel):
|
|||||||
"""Function descriptions."""
|
"""Function descriptions."""
|
||||||
|
|
||||||
description: Optional[str] = Field(default=None, examples=[None])
|
description: Optional[str] = Field(default=None, examples=[None])
|
||||||
name: str
|
name: Optional[str] = None
|
||||||
parameters: Optional[object] = None
|
parameters: Optional[object] = None
|
||||||
strict: bool = False
|
strict: bool = False
|
||||||
|
|
||||||
@@ -549,10 +545,6 @@ class ChatCompletionRequest(BaseModel):
|
|||||||
|
|
||||||
# For request id
|
# For request id
|
||||||
rid: Optional[Union[List[str], str]] = None
|
rid: Optional[Union[List[str], str]] = None
|
||||||
# Extra key for classifying the request (e.g. cache_salt)
|
|
||||||
extra_key: Optional[Union[List[str], str]] = None
|
|
||||||
# Cache salt for request caching
|
|
||||||
cache_salt: Optional[Union[List[str], str]] = None
|
|
||||||
# Priority for the request
|
# Priority for the request
|
||||||
priority: Optional[int] = None
|
priority: Optional[int] = None
|
||||||
|
|
||||||
@@ -786,13 +778,6 @@ class ResponsesRequest(BaseModel):
|
|||||||
description="The request_id related to this request. If the caller does not set it, a random uuid will be generated.",
|
description="The request_id related to this request. If the caller does not set it, a random uuid will be generated.",
|
||||||
)
|
)
|
||||||
priority: int = Field(default=0, description="Request priority")
|
priority: int = Field(default=0, description="Request priority")
|
||||||
extra_key: Optional[str] = Field(
|
|
||||||
default=None,
|
|
||||||
description="Extra key for classifying the request (e.g. cache_salt)",
|
|
||||||
)
|
|
||||||
cache_salt: Optional[str] = Field(
|
|
||||||
default=None, description="Cache salt for request caching"
|
|
||||||
)
|
|
||||||
|
|
||||||
# SGLang-specific sampling parameters
|
# SGLang-specific sampling parameters
|
||||||
frequency_penalty: float = 0.0
|
frequency_penalty: float = 0.0
|
||||||
@@ -943,16 +928,6 @@ class MessageProcessingResult:
|
|||||||
tool_call_constraint: Optional[Any] = None
|
tool_call_constraint: Optional[Any] = None
|
||||||
|
|
||||||
|
|
||||||
class ToolCallProcessingResult(NamedTuple):
|
|
||||||
"""Result of processing tool calls in a response."""
|
|
||||||
|
|
||||||
tool_calls: Optional[
|
|
||||||
List[Any]
|
|
||||||
] # List of ToolCall objects or None if parsing failed
|
|
||||||
remaining_text: str # Text remaining after parsing tool calls
|
|
||||||
finish_reason: Dict[str, Any] # Updated finish reason dictionary
|
|
||||||
|
|
||||||
|
|
||||||
class ResponseReasoningTextContent(BaseModel):
|
class ResponseReasoningTextContent(BaseModel):
|
||||||
text: str
|
text: str
|
||||||
type: Literal["reasoning_text"] = "reasoning_text"
|
type: Literal["reasoning_text"] = "reasoning_text"
|
||||||
|
|||||||
@@ -27,10 +27,10 @@ class OpenAIServingBase(ABC):
|
|||||||
self.tokenizer_manager = tokenizer_manager
|
self.tokenizer_manager = tokenizer_manager
|
||||||
self.allowed_custom_labels = (
|
self.allowed_custom_labels = (
|
||||||
set(
|
set(
|
||||||
self.tokenizer_manager.server_args.tokenizer_metrics_allowed_custom_labels
|
self.tokenizer_manager.server_args.tokenizer_metrics_allowed_customer_labels
|
||||||
)
|
)
|
||||||
if isinstance(self.tokenizer_manager.server_args, ServerArgs)
|
if isinstance(self.tokenizer_manager.server_args, ServerArgs)
|
||||||
and self.tokenizer_manager.server_args.tokenizer_metrics_allowed_custom_labels
|
and self.tokenizer_manager.server_args.tokenizer_metrics_allowed_customer_labels
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -62,12 +62,6 @@ class OpenAIServingBase(ABC):
|
|||||||
return self.create_error_response(
|
return self.create_error_response(
|
||||||
message=e.detail, err_type=str(e.status_code), status_code=e.status_code
|
message=e.detail, err_type=str(e.status_code), status_code=e.status_code
|
||||||
)
|
)
|
||||||
except ValueError as e:
|
|
||||||
return self.create_error_response(
|
|
||||||
message=str(e),
|
|
||||||
err_type="BadRequest",
|
|
||||||
status_code=400,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception(f"Error in request: {e}")
|
logger.exception(f"Error in request: {e}")
|
||||||
return self.create_error_response(
|
return self.create_error_response(
|
||||||
@@ -92,19 +86,6 @@ class OpenAIServingBase(ABC):
|
|||||||
|
|
||||||
return f"{self._request_id_prefix()}{uuid.uuid4().hex}"
|
return f"{self._request_id_prefix()}{uuid.uuid4().hex}"
|
||||||
|
|
||||||
def _compute_extra_key(self, request: OpenAIServingRequest) -> Optional[str]:
|
|
||||||
"""Compute the final extra_key by concatenating cache_salt and extra_key if both are provided."""
|
|
||||||
parts = []
|
|
||||||
for key in ["cache_salt", "extra_key"]:
|
|
||||||
value = getattr(request, key, None)
|
|
||||||
if value:
|
|
||||||
if not isinstance(value, str):
|
|
||||||
raise TypeError(
|
|
||||||
f"Value of {key} must be a string, but got {type(value).__name__}"
|
|
||||||
)
|
|
||||||
parts.append(value)
|
|
||||||
return "".join(parts) if parts else None
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _convert_to_internal_request(
|
def _convert_to_internal_request(
|
||||||
self,
|
self,
|
||||||
@@ -184,14 +165,14 @@ class OpenAIServingBase(ABC):
|
|||||||
)
|
)
|
||||||
return json.dumps({"error": error.model_dump()})
|
return json.dumps({"error": error.model_dump()})
|
||||||
|
|
||||||
def extract_custom_labels(self, raw_request):
|
def extract_customer_labels(self, raw_request):
|
||||||
if (
|
if (
|
||||||
not self.allowed_custom_labels
|
not self.allowed_custom_labels
|
||||||
or not self.tokenizer_manager.server_args.tokenizer_metrics_custom_labels_header
|
or not self.tokenizer_manager.server_args.tokenizer_metrics_custom_labels_header
|
||||||
):
|
):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
custom_labels = None
|
customer_labels = None
|
||||||
header = (
|
header = (
|
||||||
self.tokenizer_manager.server_args.tokenizer_metrics_custom_labels_header
|
self.tokenizer_manager.server_args.tokenizer_metrics_custom_labels_header
|
||||||
)
|
)
|
||||||
@@ -206,9 +187,9 @@ class OpenAIServingBase(ABC):
|
|||||||
raw_labels = None
|
raw_labels = None
|
||||||
|
|
||||||
if isinstance(raw_labels, dict):
|
if isinstance(raw_labels, dict):
|
||||||
custom_labels = {
|
customer_labels = {
|
||||||
label: value
|
label: value
|
||||||
for label, value in raw_labels.items()
|
for label, value in raw_labels.items()
|
||||||
if label in self.allowed_custom_labels
|
if label in self.allowed_custom_labels
|
||||||
}
|
}
|
||||||
return custom_labels
|
return customer_labels
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Optional, Uni
|
|||||||
|
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
from fastapi.responses import ORJSONResponse, StreamingResponse
|
from fastapi.responses import ORJSONResponse, StreamingResponse
|
||||||
from jsonschema import Draft202012Validator, SchemaError
|
|
||||||
|
|
||||||
from sglang.srt.entrypoints.openai.protocol import (
|
from sglang.srt.entrypoints.openai.protocol import (
|
||||||
ChatCompletionRequest,
|
ChatCompletionRequest,
|
||||||
@@ -26,8 +25,6 @@ from sglang.srt.entrypoints.openai.protocol import (
|
|||||||
LogProbs,
|
LogProbs,
|
||||||
MessageProcessingResult,
|
MessageProcessingResult,
|
||||||
ToolCall,
|
ToolCall,
|
||||||
ToolCallProcessingResult,
|
|
||||||
ToolChoice,
|
|
||||||
TopLogprob,
|
TopLogprob,
|
||||||
)
|
)
|
||||||
from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
|
from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
|
||||||
@@ -36,10 +33,7 @@ from sglang.srt.entrypoints.openai.utils import (
|
|||||||
process_hidden_states_from_ret,
|
process_hidden_states_from_ret,
|
||||||
to_openai_style_logprobs,
|
to_openai_style_logprobs,
|
||||||
)
|
)
|
||||||
from sglang.srt.function_call.core_types import ToolCallItem
|
|
||||||
from sglang.srt.function_call.function_call_parser import FunctionCallParser
|
from sglang.srt.function_call.function_call_parser import FunctionCallParser
|
||||||
from sglang.srt.function_call.json_array_parser import JsonArrayParser
|
|
||||||
from sglang.srt.function_call.utils import get_json_schema_constraint
|
|
||||||
from sglang.srt.managers.io_struct import GenerateReqInput
|
from sglang.srt.managers.io_struct import GenerateReqInput
|
||||||
from sglang.srt.parser.conversation import generate_chat_conv
|
from sglang.srt.parser.conversation import generate_chat_conv
|
||||||
from sglang.srt.parser.jinja_template_utils import process_content_for_template_format
|
from sglang.srt.parser.jinja_template_utils import process_content_for_template_format
|
||||||
@@ -64,7 +58,6 @@ class OpenAIServingChat(OpenAIServingBase):
|
|||||||
super().__init__(tokenizer_manager)
|
super().__init__(tokenizer_manager)
|
||||||
self.template_manager = template_manager
|
self.template_manager = template_manager
|
||||||
self.tool_call_parser = self.tokenizer_manager.server_args.tool_call_parser
|
self.tool_call_parser = self.tokenizer_manager.server_args.tool_call_parser
|
||||||
self.reasoning_parser = self.tokenizer_manager.server_args.reasoning_parser
|
|
||||||
|
|
||||||
def _request_id_prefix(self) -> str:
|
def _request_id_prefix(self) -> str:
|
||||||
return "chatcmpl-"
|
return "chatcmpl-"
|
||||||
@@ -81,23 +74,6 @@ class OpenAIServingChat(OpenAIServingBase):
|
|||||||
):
|
):
|
||||||
return "Tools cannot be empty if tool choice is set to required."
|
return "Tools cannot be empty if tool choice is set to required."
|
||||||
|
|
||||||
if request.tool_choice is not None and not isinstance(request.tool_choice, str):
|
|
||||||
if not request.tools:
|
|
||||||
return "Tools cannot be empty if tool choice is set to a specific tool."
|
|
||||||
tool_name = request.tool_choice.function.name
|
|
||||||
tool_exists = any(tool.function.name == tool_name for tool in request.tools)
|
|
||||||
if not tool_exists:
|
|
||||||
return f"Tool '{tool_name}' not found in tools list."
|
|
||||||
|
|
||||||
# Validate tool definitions
|
|
||||||
for i, tool in enumerate(request.tools or []):
|
|
||||||
if tool.function.parameters is None:
|
|
||||||
continue
|
|
||||||
try:
|
|
||||||
Draft202012Validator.check_schema(tool.function.parameters)
|
|
||||||
except SchemaError as e:
|
|
||||||
return f"Tool {i} function has invalid 'parameters' schema: {str(e)}"
|
|
||||||
|
|
||||||
max_output_tokens = request.max_completion_tokens or request.max_tokens
|
max_output_tokens = request.max_completion_tokens or request.max_tokens
|
||||||
server_context_length = self.tokenizer_manager.server_args.context_length
|
server_context_length = self.tokenizer_manager.server_args.context_length
|
||||||
if (
|
if (
|
||||||
@@ -152,8 +128,8 @@ class OpenAIServingChat(OpenAIServingBase):
|
|||||||
else:
|
else:
|
||||||
prompt_kwargs = {"input_ids": processed_messages.prompt_ids}
|
prompt_kwargs = {"input_ids": processed_messages.prompt_ids}
|
||||||
|
|
||||||
# Extract custom labels from raw request headers
|
# Extract customer labels from raw request headers
|
||||||
custom_labels = self.extract_custom_labels(raw_request)
|
customer_labels = self.extract_customer_labels(raw_request)
|
||||||
|
|
||||||
adapted_request = GenerateReqInput(
|
adapted_request = GenerateReqInput(
|
||||||
**prompt_kwargs,
|
**prompt_kwargs,
|
||||||
@@ -173,9 +149,8 @@ class OpenAIServingChat(OpenAIServingBase):
|
|||||||
bootstrap_room=request.bootstrap_room,
|
bootstrap_room=request.bootstrap_room,
|
||||||
return_hidden_states=request.return_hidden_states,
|
return_hidden_states=request.return_hidden_states,
|
||||||
rid=request.rid,
|
rid=request.rid,
|
||||||
extra_key=self._compute_extra_key(request),
|
|
||||||
priority=request.priority,
|
priority=request.priority,
|
||||||
custom_labels=custom_labels,
|
customer_labels=customer_labels,
|
||||||
)
|
)
|
||||||
|
|
||||||
return adapted_request, request
|
return adapted_request, request
|
||||||
@@ -213,14 +188,6 @@ class OpenAIServingChat(OpenAIServingBase):
|
|||||||
tool_call_constraint = parser.get_structure_constraint(
|
tool_call_constraint = parser.get_structure_constraint(
|
||||||
request.tool_choice
|
request.tool_choice
|
||||||
)
|
)
|
||||||
# Handle JSON schema constraint directly for required or named tool choice
|
|
||||||
if request.tool_choice == "required" or isinstance(
|
|
||||||
request.tool_choice, ToolChoice
|
|
||||||
):
|
|
||||||
json_schema = get_json_schema_constraint(
|
|
||||||
request.tools, request.tool_choice
|
|
||||||
)
|
|
||||||
tool_call_constraint = ("json_schema", json_schema)
|
|
||||||
|
|
||||||
# Use chat template
|
# Use chat template
|
||||||
if self.template_manager.chat_template_name is None:
|
if self.template_manager.chat_template_name is None:
|
||||||
@@ -468,10 +435,6 @@ class OpenAIServingChat(OpenAIServingBase):
|
|||||||
sampling_params[constraint_type] = convert_json_schema_to_str(
|
sampling_params[constraint_type] = convert_json_schema_to_str(
|
||||||
constraint_value.model_dump(by_alias=True)
|
constraint_value.model_dump(by_alias=True)
|
||||||
)
|
)
|
||||||
elif constraint_type == "json_schema":
|
|
||||||
sampling_params[constraint_type] = convert_json_schema_to_str(
|
|
||||||
constraint_value
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
sampling_params[constraint_type] = constraint_value
|
sampling_params[constraint_type] = constraint_value
|
||||||
return sampling_params
|
return sampling_params
|
||||||
@@ -564,7 +527,10 @@ class OpenAIServingChat(OpenAIServingBase):
|
|||||||
stream_buffers[index] = stream_buffer + delta
|
stream_buffers[index] = stream_buffer + delta
|
||||||
|
|
||||||
# Handle reasoning content
|
# Handle reasoning content
|
||||||
if self.reasoning_parser and request.separate_reasoning:
|
if (
|
||||||
|
self.tokenizer_manager.server_args.reasoning_parser
|
||||||
|
and request.separate_reasoning
|
||||||
|
):
|
||||||
reasoning_text, delta = self._process_reasoning_stream(
|
reasoning_text, delta = self._process_reasoning_stream(
|
||||||
index, delta, reasoning_parser_dict, content, request
|
index, delta, reasoning_parser_dict, content, request
|
||||||
)
|
)
|
||||||
@@ -754,7 +720,7 @@ class OpenAIServingChat(OpenAIServingBase):
|
|||||||
|
|
||||||
# Handle reasoning content
|
# Handle reasoning content
|
||||||
reasoning_text = None
|
reasoning_text = None
|
||||||
reasoning_parser = self.reasoning_parser
|
reasoning_parser = self.tokenizer_manager.server_args.reasoning_parser
|
||||||
if reasoning_parser and request.separate_reasoning:
|
if reasoning_parser and request.separate_reasoning:
|
||||||
is_force_reasoning = (
|
is_force_reasoning = (
|
||||||
self.template_manager.force_reasoning
|
self.template_manager.force_reasoning
|
||||||
@@ -782,13 +748,8 @@ class OpenAIServingChat(OpenAIServingBase):
|
|||||||
and request.tools
|
and request.tools
|
||||||
and self.tool_call_parser
|
and self.tool_call_parser
|
||||||
):
|
):
|
||||||
history_tool_calls_cnt = self._get_history_tool_calls_cnt(request)
|
|
||||||
tool_calls, text, finish_reason = self._process_tool_calls(
|
tool_calls, text, finish_reason = self._process_tool_calls(
|
||||||
text,
|
text, request.tools, finish_reason
|
||||||
request.tools,
|
|
||||||
finish_reason,
|
|
||||||
request.tool_choice,
|
|
||||||
history_tool_calls_cnt,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
choice_data = ChatCompletionResponseChoice(
|
choice_data = ChatCompletionResponseChoice(
|
||||||
@@ -878,76 +839,13 @@ class OpenAIServingChat(OpenAIServingBase):
|
|||||||
token_logprobs = self._process_logprobs_tokens(logprobs, use_token_index=True)
|
token_logprobs = self._process_logprobs_tokens(logprobs, use_token_index=True)
|
||||||
return ChoiceLogprobs(content=token_logprobs)
|
return ChoiceLogprobs(content=token_logprobs)
|
||||||
|
|
||||||
def _process_tool_call_id(
|
|
||||||
self,
|
|
||||||
call_item: ToolCallItem,
|
|
||||||
history_tool_calls_cnt: int,
|
|
||||||
) -> str:
|
|
||||||
"""Process for generating a new and unique `tool_call_id`"""
|
|
||||||
if self.tool_call_parser != "kimi_k2":
|
|
||||||
# A simple uuid is sufficient for all models except for Kimi-K2.
|
|
||||||
tool_call_id = f"call_{uuid.uuid4().hex[:24]}"
|
|
||||||
return tool_call_id
|
|
||||||
else:
|
|
||||||
# Align with Kimi-K2 format: functions.{name}:{index}
|
|
||||||
# Kimi-K2 allows multiple tool_calls in one message; SGLang sets call_item.tool_index to the *local* position inside that message.
|
|
||||||
# Therefore, the index must be corrected by using `history_tool_calls_cnt + call_item.tool_index` to ensure globally unique and properly ordered.
|
|
||||||
tool_call_id = f"functions.{call_item.name}:{history_tool_calls_cnt+call_item.tool_index}"
|
|
||||||
logger.debug(
|
|
||||||
f"Process tool call idx, parser: {self.tool_call_parser}, tool_call_id: {tool_call_id}, history_cnt: {history_tool_calls_cnt}"
|
|
||||||
)
|
|
||||||
return tool_call_id
|
|
||||||
|
|
||||||
def _process_tool_calls(
|
def _process_tool_calls(
|
||||||
self,
|
self,
|
||||||
text: str,
|
text: str,
|
||||||
tools: List[Any],
|
tools: List[Any],
|
||||||
finish_reason: Dict[str, Any],
|
finish_reason: Dict[str, Any],
|
||||||
tool_choice: Optional[Union[str, ToolChoice]] = None,
|
) -> tuple[Optional[List[ToolCall]], str, Dict[str, Any]]:
|
||||||
history_tool_calls_cnt: int = 0,
|
|
||||||
) -> ToolCallProcessingResult:
|
|
||||||
"""Process tool calls in the response"""
|
"""Process tool calls in the response"""
|
||||||
|
|
||||||
# Handle required or named tool choice
|
|
||||||
if tool_choice == "required" or (
|
|
||||||
isinstance(tool_choice, ToolChoice) and tool_choice.type == "function"
|
|
||||||
):
|
|
||||||
# Set finish reason to tool_calls since we're processing tool calls
|
|
||||||
if finish_reason["type"] == "stop":
|
|
||||||
finish_reason["type"] = "tool_calls"
|
|
||||||
finish_reason["matched"] = None
|
|
||||||
try:
|
|
||||||
# For required tool choice, we expect a JSON array of tool calls
|
|
||||||
tool_call_data = json.loads(text)
|
|
||||||
tool_calls = []
|
|
||||||
for i, tool in enumerate(tool_call_data):
|
|
||||||
# Create a ToolCallItem from the JSON data
|
|
||||||
call_info = ToolCallItem(
|
|
||||||
tool_index=i, # Use the loop index as tool_index
|
|
||||||
name=tool["name"],
|
|
||||||
parameters=json.dumps(tool["parameters"], ensure_ascii=False),
|
|
||||||
)
|
|
||||||
tool_id = self._process_tool_call_id(
|
|
||||||
call_info, history_tool_calls_cnt
|
|
||||||
)
|
|
||||||
tool_calls.append(
|
|
||||||
ToolCall(
|
|
||||||
id=tool_id,
|
|
||||||
index=i,
|
|
||||||
function=FunctionResponse(
|
|
||||||
name=tool["name"],
|
|
||||||
arguments=json.dumps(
|
|
||||||
tool["parameters"], ensure_ascii=False
|
|
||||||
),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return ToolCallProcessingResult(tool_calls, "", finish_reason)
|
|
||||||
except json.JSONDecodeError as e:
|
|
||||||
logger.error(f"Tool call parsing error: {e}")
|
|
||||||
return ToolCallProcessingResult(None, text, finish_reason)
|
|
||||||
|
|
||||||
# Use parser since output is not constrained by JSON schema
|
|
||||||
parser = FunctionCallParser(tools, self.tool_call_parser)
|
parser = FunctionCallParser(tools, self.tool_call_parser)
|
||||||
if parser.has_tool_call(text):
|
if parser.has_tool_call(text):
|
||||||
if finish_reason["type"] == "stop":
|
if finish_reason["type"] == "stop":
|
||||||
@@ -957,9 +855,15 @@ class OpenAIServingChat(OpenAIServingBase):
|
|||||||
text, call_info_list = parser.parse_non_stream(text)
|
text, call_info_list = parser.parse_non_stream(text)
|
||||||
tool_calls = []
|
tool_calls = []
|
||||||
for call_info in call_info_list:
|
for call_info in call_info_list:
|
||||||
tool_id = self._process_tool_call_id(
|
# For Kimi-K2, align tool_call_id with the model format: functions.{name}:{index}
|
||||||
call_info, history_tool_calls_cnt
|
if (
|
||||||
)
|
self.tool_call_parser == "kimi_k2"
|
||||||
|
and call_info.name is not None
|
||||||
|
):
|
||||||
|
tool_id = f"functions.{call_info.name}:{call_info.tool_index}"
|
||||||
|
else:
|
||||||
|
tool_id = f"call_{uuid.uuid4().hex[:24]}"
|
||||||
|
|
||||||
tool_calls.append(
|
tool_calls.append(
|
||||||
ToolCall(
|
ToolCall(
|
||||||
id=tool_id,
|
id=tool_id,
|
||||||
@@ -969,13 +873,13 @@ class OpenAIServingChat(OpenAIServingBase):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return ToolCallProcessingResult(tool_calls, text, finish_reason)
|
return tool_calls, text, finish_reason
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Tool call parsing error: {e}")
|
logger.error(f"Tool call parsing error: {e}")
|
||||||
# Return error but don't fail the whole request
|
# Return error but don't fail the whole request
|
||||||
return ToolCallProcessingResult(None, text, finish_reason)
|
return None, text, finish_reason
|
||||||
|
|
||||||
return ToolCallProcessingResult(None, text, finish_reason)
|
return None, text, finish_reason
|
||||||
|
|
||||||
def _process_streaming_logprobs(
|
def _process_streaming_logprobs(
|
||||||
self, content: Dict[str, Any], n_prev_token: int
|
self, content: Dict[str, Any], n_prev_token: int
|
||||||
@@ -1008,33 +912,13 @@ class OpenAIServingChat(OpenAIServingBase):
|
|||||||
or self._get_enable_thinking_from_request(request)
|
or self._get_enable_thinking_from_request(request)
|
||||||
)
|
)
|
||||||
reasoning_parser_dict[index] = ReasoningParser(
|
reasoning_parser_dict[index] = ReasoningParser(
|
||||||
self.reasoning_parser,
|
self.tokenizer_manager.server_args.reasoning_parser,
|
||||||
request.stream_reasoning,
|
request.stream_reasoning,
|
||||||
is_force_reasoning,
|
is_force_reasoning,
|
||||||
)
|
)
|
||||||
reasoning_parser = reasoning_parser_dict[index]
|
reasoning_parser = reasoning_parser_dict[index]
|
||||||
return reasoning_parser.parse_stream_chunk(delta)
|
return reasoning_parser.parse_stream_chunk(delta)
|
||||||
|
|
||||||
def _get_history_tool_calls_cnt(self, request: ChatCompletionRequest) -> int:
|
|
||||||
"""Counts the number of tool calls in the request's message history.
|
|
||||||
|
|
||||||
NOTE: This method is only useful for models that include self-increasing
|
|
||||||
history tool call idx in tool calls id, such as kimi-k2
|
|
||||||
|
|
||||||
Args:
|
|
||||||
request: The chat completion request object.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The total number of tool calls in the history, or 0 if not applicable.
|
|
||||||
"""
|
|
||||||
messages = getattr(request, "messages", [])
|
|
||||||
idx = 0
|
|
||||||
for msg in messages:
|
|
||||||
if msg.role == "assistant":
|
|
||||||
tool_calls = getattr(msg, "tool_calls", None)
|
|
||||||
idx += len(list(tool_calls)) if tool_calls is not None else 0 # noqa
|
|
||||||
return idx
|
|
||||||
|
|
||||||
def _get_enable_thinking_from_request(self, request: ChatCompletionRequest) -> bool:
|
def _get_enable_thinking_from_request(self, request: ChatCompletionRequest) -> bool:
|
||||||
"""Extracts the 'enable_thinking' flag from request chat_template_kwargs.
|
"""Extracts the 'enable_thinking' flag from request chat_template_kwargs.
|
||||||
|
|
||||||
@@ -1048,11 +932,11 @@ class OpenAIServingChat(OpenAIServingBase):
|
|||||||
"""
|
"""
|
||||||
if hasattr(request, "chat_template_kwargs") and request.chat_template_kwargs:
|
if hasattr(request, "chat_template_kwargs") and request.chat_template_kwargs:
|
||||||
# For Qwen3 models, `enable_thinking` is supported.
|
# For Qwen3 models, `enable_thinking` is supported.
|
||||||
if self.reasoning_parser in ["qwen3", "glm45"]:
|
if request.chat_template_kwargs.get("enable_thinking") is not None:
|
||||||
return request.chat_template_kwargs.get("enable_thinking", False)
|
return request.chat_template_kwargs.get("enable_thinking")
|
||||||
# For DeepSeek-V3.1 models, `thinking` is supported.
|
# For DeepSeek-V3.1 models, `thinking` is supported.
|
||||||
elif self.reasoning_parser in ["deepseek-v3"]:
|
elif request.chat_template_kwargs.get("thinking") is not None:
|
||||||
return request.chat_template_kwargs.get("thinking", False)
|
return request.chat_template_kwargs.get("thinking")
|
||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
return False
|
return False
|
||||||
@@ -1068,25 +952,13 @@ class OpenAIServingChat(OpenAIServingBase):
|
|||||||
):
|
):
|
||||||
"""Process tool calls in streaming response"""
|
"""Process tool calls in streaming response"""
|
||||||
if index not in parser_dict:
|
if index not in parser_dict:
|
||||||
# Use JSON detector directly for required or named tool choice
|
parser_dict[index] = FunctionCallParser(
|
||||||
if request.tool_choice == "required" or isinstance(
|
tools=request.tools,
|
||||||
request.tool_choice, ToolChoice
|
tool_call_parser=self.tool_call_parser,
|
||||||
):
|
)
|
||||||
parser_dict[index] = JsonArrayParser()
|
|
||||||
else:
|
|
||||||
parser_dict[index] = FunctionCallParser(
|
|
||||||
tools=request.tools,
|
|
||||||
tool_call_parser=self.tool_call_parser,
|
|
||||||
)
|
|
||||||
|
|
||||||
parser = parser_dict[index]
|
parser = parser_dict[index]
|
||||||
|
|
||||||
# Handle both FunctionCallParser and JsonArrayParser
|
normal_text, calls = parser.parse_stream_chunk(delta)
|
||||||
if isinstance(parser, JsonArrayParser):
|
|
||||||
result = parser.parse_streaming_increment(delta, request.tools)
|
|
||||||
normal_text, calls = result.normal_text, result.calls
|
|
||||||
else:
|
|
||||||
normal_text, calls = parser.parse_stream_chunk(delta)
|
|
||||||
|
|
||||||
# Yield normal text
|
# Yield normal text
|
||||||
if normal_text:
|
if normal_text:
|
||||||
@@ -1104,7 +976,6 @@ class OpenAIServingChat(OpenAIServingBase):
|
|||||||
yield f"data: {chunk.model_dump_json()}\n\n"
|
yield f"data: {chunk.model_dump_json()}\n\n"
|
||||||
|
|
||||||
# Yield tool calls
|
# Yield tool calls
|
||||||
history_tool_calls_cnt = self._get_history_tool_calls_cnt(request)
|
|
||||||
for call_item in calls:
|
for call_item in calls:
|
||||||
# Mark that this choice has tool calls
|
# Mark that this choice has tool calls
|
||||||
has_tool_calls[index] = True
|
has_tool_calls[index] = True
|
||||||
@@ -1112,9 +983,11 @@ class OpenAIServingChat(OpenAIServingBase):
|
|||||||
# Tool call ID should be generated only once per tool call
|
# Tool call ID should be generated only once per tool call
|
||||||
if call_item.name:
|
if call_item.name:
|
||||||
# First chunk: include ID and function name
|
# First chunk: include ID and function name
|
||||||
tool_call_id = self._process_tool_call_id(
|
if self.tool_call_parser == "kimi_k2":
|
||||||
call_item, history_tool_calls_cnt
|
# Align with Kimi-K2 format: functions.{name}:{index}
|
||||||
)
|
tool_call_id = f"functions.{call_item.name}:{call_item.tool_index}"
|
||||||
|
else:
|
||||||
|
tool_call_id = f"call_{uuid.uuid4().hex[:24]}"
|
||||||
function_name = call_item.name
|
function_name = call_item.name
|
||||||
else:
|
else:
|
||||||
# Subsequent chunks: null ID and name for argument deltas
|
# Subsequent chunks: null ID and name for argument deltas
|
||||||
@@ -1145,7 +1018,7 @@ class OpenAIServingChat(OpenAIServingBase):
|
|||||||
|
|
||||||
def _check_for_unstreamed_tool_args(
|
def _check_for_unstreamed_tool_args(
|
||||||
self,
|
self,
|
||||||
parser: Union[FunctionCallParser, JsonArrayParser],
|
parser: FunctionCallParser,
|
||||||
content: Dict[str, Any],
|
content: Dict[str, Any],
|
||||||
request: ChatCompletionRequest,
|
request: ChatCompletionRequest,
|
||||||
index: int,
|
index: int,
|
||||||
@@ -1155,31 +1028,30 @@ class OpenAIServingChat(OpenAIServingBase):
|
|||||||
when generation finishes. This ensures tool calls are properly completed
|
when generation finishes. This ensures tool calls are properly completed
|
||||||
even if the model generates the final arguments in the last chunk.
|
even if the model generates the final arguments in the last chunk.
|
||||||
"""
|
"""
|
||||||
# Get the detector - either from FunctionCallParser or directly if json detector
|
# Only check if we have tool calls and the parser has tracked data
|
||||||
detector = parser.detector if hasattr(parser, "detector") else parser
|
|
||||||
|
|
||||||
# Only check if we have tool calls and the detector has tracked data
|
|
||||||
if (
|
if (
|
||||||
not hasattr(detector, "prev_tool_call_arr")
|
not hasattr(parser.detector, "prev_tool_call_arr")
|
||||||
or not detector.prev_tool_call_arr
|
or not parser.detector.prev_tool_call_arr
|
||||||
):
|
):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if (
|
if (
|
||||||
not hasattr(detector, "streamed_args_for_tool")
|
not hasattr(parser.detector, "streamed_args_for_tool")
|
||||||
or not detector.streamed_args_for_tool
|
or not parser.detector.streamed_args_for_tool
|
||||||
):
|
):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Get the last tool call that was being processed
|
# Get the last tool call that was being processed
|
||||||
tool_index = len(detector.prev_tool_call_arr) - 1
|
tool_index = len(parser.detector.prev_tool_call_arr) - 1
|
||||||
if tool_index < 0 or tool_index >= len(detector.streamed_args_for_tool):
|
if tool_index < 0 or tool_index >= len(parser.detector.streamed_args_for_tool):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Get expected vs actual arguments
|
# Get expected vs actual arguments
|
||||||
expected_args = detector.prev_tool_call_arr[tool_index].get("arguments", {})
|
expected_args = parser.detector.prev_tool_call_arr[tool_index].get(
|
||||||
|
"arguments", {}
|
||||||
|
)
|
||||||
expected_call = json.dumps(expected_args, ensure_ascii=False)
|
expected_call = json.dumps(expected_args, ensure_ascii=False)
|
||||||
actual_call = detector.streamed_args_for_tool[tool_index]
|
actual_call = parser.detector.streamed_args_for_tool[tool_index]
|
||||||
|
|
||||||
# Check if there are remaining arguments to send
|
# Check if there are remaining arguments to send
|
||||||
remaining_call = (
|
remaining_call = (
|
||||||
|
|||||||
@@ -90,8 +90,8 @@ class OpenAIServingCompletion(OpenAIServingBase):
|
|||||||
else:
|
else:
|
||||||
prompt_kwargs = {"input_ids": prompt}
|
prompt_kwargs = {"input_ids": prompt}
|
||||||
|
|
||||||
# Extract custom labels from raw request headers
|
# Extract customer labels from raw request headers
|
||||||
custom_labels = self.extract_custom_labels(raw_request)
|
customer_labels = self.extract_customer_labels(raw_request)
|
||||||
|
|
||||||
adapted_request = GenerateReqInput(
|
adapted_request = GenerateReqInput(
|
||||||
**prompt_kwargs,
|
**prompt_kwargs,
|
||||||
@@ -107,9 +107,8 @@ class OpenAIServingCompletion(OpenAIServingBase):
|
|||||||
bootstrap_room=request.bootstrap_room,
|
bootstrap_room=request.bootstrap_room,
|
||||||
return_hidden_states=request.return_hidden_states,
|
return_hidden_states=request.return_hidden_states,
|
||||||
rid=request.rid,
|
rid=request.rid,
|
||||||
extra_key=self._compute_extra_key(request),
|
|
||||||
priority=request.priority,
|
priority=request.priority,
|
||||||
custom_labels=custom_labels,
|
customer_labels=customer_labels,
|
||||||
)
|
)
|
||||||
|
|
||||||
return adapted_request, request
|
return adapted_request, request
|
||||||
|
|||||||
@@ -245,7 +245,6 @@ class OpenAIServingResponses(OpenAIServingChat):
|
|||||||
sampling_params=sampling_params,
|
sampling_params=sampling_params,
|
||||||
stream=request.stream,
|
stream=request.stream,
|
||||||
rid=request.request_id,
|
rid=request.request_id,
|
||||||
extra_key=self._compute_extra_key(request),
|
|
||||||
background=request.background,
|
background=request.background,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1251,7 +1250,6 @@ class OpenAIServingResponses(OpenAIServingChat):
|
|||||||
sampling_params=sampling_params,
|
sampling_params=sampling_params,
|
||||||
stream=adapted_request.stream,
|
stream=adapted_request.stream,
|
||||||
rid=request_id,
|
rid=request_id,
|
||||||
extra_key=adapted_request.extra_key,
|
|
||||||
return_logprob=adapted_request.return_logprob,
|
return_logprob=adapted_request.return_logprob,
|
||||||
logprob_start_len=adapted_request.logprob_start_len,
|
logprob_start_len=adapted_request.logprob_start_len,
|
||||||
top_logprobs_num=adapted_request.top_logprobs_num,
|
top_logprobs_num=adapted_request.top_logprobs_num,
|
||||||
|
|||||||
@@ -231,7 +231,6 @@ class ExpertLocationMetadata:
|
|||||||
logical_to_all_physical_map_num_valid=logical_to_all_physical_map_num_valid,
|
logical_to_all_physical_map_num_valid=logical_to_all_physical_map_num_valid,
|
||||||
logical_to_rank_dispatch_physical_map=(
|
logical_to_rank_dispatch_physical_map=(
|
||||||
compute_logical_to_rank_dispatch_physical_map(
|
compute_logical_to_rank_dispatch_physical_map(
|
||||||
server_args=server_args,
|
|
||||||
logical_to_all_physical_map=logical_to_all_physical_map,
|
logical_to_all_physical_map=logical_to_all_physical_map,
|
||||||
num_gpus=ep_size,
|
num_gpus=ep_size,
|
||||||
num_physical_experts=num_physical_experts,
|
num_physical_experts=num_physical_experts,
|
||||||
@@ -341,7 +340,6 @@ def _pad_nested_array(arr, pad_value):
|
|||||||
|
|
||||||
# TODO optimize performance (rewrite and/or run in separate process with overlap)
|
# TODO optimize performance (rewrite and/or run in separate process with overlap)
|
||||||
def compute_logical_to_rank_dispatch_physical_map(
|
def compute_logical_to_rank_dispatch_physical_map(
|
||||||
server_args: ServerArgs,
|
|
||||||
logical_to_all_physical_map: torch.Tensor,
|
logical_to_all_physical_map: torch.Tensor,
|
||||||
num_gpus: int,
|
num_gpus: int,
|
||||||
num_physical_experts: int,
|
num_physical_experts: int,
|
||||||
@@ -350,9 +348,7 @@ def compute_logical_to_rank_dispatch_physical_map(
|
|||||||
):
|
):
|
||||||
r = random.Random(seed)
|
r = random.Random(seed)
|
||||||
|
|
||||||
num_local_gpu_physical_experts = num_physical_experts // num_gpus
|
num_local_physical_experts = num_physical_experts // num_gpus
|
||||||
num_gpus_per_node = server_args.ep_size // server_args.nnodes
|
|
||||||
num_local_node_physical_experts = num_local_gpu_physical_experts * num_gpus_per_node
|
|
||||||
num_layers, num_logical_experts, _ = logical_to_all_physical_map.shape
|
num_layers, num_logical_experts, _ = logical_to_all_physical_map.shape
|
||||||
dtype = logical_to_all_physical_map.dtype
|
dtype = logical_to_all_physical_map.dtype
|
||||||
|
|
||||||
@@ -376,28 +372,13 @@ def compute_logical_to_rank_dispatch_physical_map(
|
|||||||
physical_expert_id
|
physical_expert_id
|
||||||
for physical_expert_id in candidate_physical_expert_ids
|
for physical_expert_id in candidate_physical_expert_ids
|
||||||
if _compute_gpu_id_of_physical_expert(
|
if _compute_gpu_id_of_physical_expert(
|
||||||
physical_expert_id, num_local_gpu_physical_experts
|
physical_expert_id, num_local_physical_experts
|
||||||
)
|
)
|
||||||
== gpu_id
|
== gpu_id
|
||||||
]
|
]
|
||||||
if len(same_gpu_physical_expert_ids) > 0:
|
if len(same_gpu_physical_expert_ids) > 0:
|
||||||
# 1. Prefer same-GPU experts
|
|
||||||
output_partial[gpu_id] = same_gpu_physical_expert_ids[0]
|
output_partial[gpu_id] = same_gpu_physical_expert_ids[0]
|
||||||
else:
|
|
||||||
# 2. Otherwise, prefer same-node experts
|
|
||||||
node_id = gpu_id // num_gpus_per_node
|
|
||||||
same_node_physical_expert_ids = [
|
|
||||||
physical_expert_id
|
|
||||||
for physical_expert_id in candidate_physical_expert_ids
|
|
||||||
if _compute_node_id_of_physical_expert(
|
|
||||||
physical_expert_id, num_local_node_physical_experts
|
|
||||||
)
|
|
||||||
== node_id
|
|
||||||
]
|
|
||||||
if len(same_node_physical_expert_ids) > 0:
|
|
||||||
output_partial[gpu_id] = same_node_physical_expert_ids[0]
|
|
||||||
|
|
||||||
# 3. Fill remaining slots with fair random choices
|
|
||||||
num_remain = torch.sum(output_partial == -1).item()
|
num_remain = torch.sum(output_partial == -1).item()
|
||||||
output_partial[output_partial == -1] = torch.tensor(
|
output_partial[output_partial == -1] = torch.tensor(
|
||||||
_fair_choices(candidate_physical_expert_ids, k=num_remain, r=r),
|
_fair_choices(candidate_physical_expert_ids, k=num_remain, r=r),
|
||||||
@@ -423,15 +404,9 @@ def _logical_to_all_physical_raw(
|
|||||||
|
|
||||||
|
|
||||||
def _compute_gpu_id_of_physical_expert(
|
def _compute_gpu_id_of_physical_expert(
|
||||||
physical_expert_id: int, num_local_gpu_physical_experts: int
|
physical_expert_id: int, num_local_physical_experts: int
|
||||||
) -> int:
|
) -> int:
|
||||||
return physical_expert_id // num_local_gpu_physical_experts
|
return physical_expert_id // num_local_physical_experts
|
||||||
|
|
||||||
|
|
||||||
def _compute_node_id_of_physical_expert(
|
|
||||||
physical_expert_id: int, num_local_host_physical_experts: int
|
|
||||||
) -> int:
|
|
||||||
return physical_expert_id // num_local_host_physical_experts
|
|
||||||
|
|
||||||
|
|
||||||
def _fair_choices(arr: List, k: int, r: random.Random) -> List:
|
def _fair_choices(arr: List, k: int, r: random.Random) -> List:
|
||||||
|
|||||||
@@ -20,7 +20,6 @@ from sglang.srt.function_call.pythonic_detector import PythonicDetector
|
|||||||
from sglang.srt.function_call.qwen3_coder_detector import Qwen3CoderDetector
|
from sglang.srt.function_call.qwen3_coder_detector import Qwen3CoderDetector
|
||||||
from sglang.srt.function_call.qwen25_detector import Qwen25Detector
|
from sglang.srt.function_call.qwen25_detector import Qwen25Detector
|
||||||
from sglang.srt.function_call.step3_detector import Step3Detector
|
from sglang.srt.function_call.step3_detector import Step3Detector
|
||||||
from sglang.srt.function_call.utils import get_json_schema_constraint
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -179,8 +178,8 @@ class FunctionCallParser:
|
|||||||
strict_tag = self.get_structure_tag()
|
strict_tag = self.get_structure_tag()
|
||||||
return ("structural_tag", strict_tag)
|
return ("structural_tag", strict_tag)
|
||||||
elif tool_choice == "required" or isinstance(tool_choice, ToolChoice):
|
elif tool_choice == "required" or isinstance(tool_choice, ToolChoice):
|
||||||
json_schema = get_json_schema_constraint(self.tools, tool_choice)
|
ebnf = self.get_ebnf(tool_choice)
|
||||||
return ("json_schema", json_schema)
|
return ("ebnf", ebnf) if ebnf is not None else None
|
||||||
|
|
||||||
def get_ebnf(
|
def get_ebnf(
|
||||||
self, tool_choice: Union[ToolChoice, Literal["required"]]
|
self, tool_choice: Union[ToolChoice, Literal["required"]]
|
||||||
|
|||||||
@@ -39,7 +39,7 @@ def parse_arguments(json_value):
|
|||||||
|
|
||||||
class Glm4MoeDetector(BaseFormatDetector):
|
class Glm4MoeDetector(BaseFormatDetector):
|
||||||
"""
|
"""
|
||||||
Detector for GLM-4.5 and GLM-4.6 models.
|
Detector for GLM-4.5 models.
|
||||||
Assumes function call format:
|
Assumes function call format:
|
||||||
<tool_call>get_weather\n<arg_key>city</arg_key>\n<arg_value>北京</arg_value>\n<arg_key>date</arg_key>\n<arg_value>2024-06-27</arg_value>\n</tool_call>\n<tool_call>get_weather\n<arg_key>city</arg_key>\n<arg_value>上海</arg_value>\n<arg_key>date</arg_key>\n<arg_value>2024-06-27</arg_value>\n</tool_call>
|
<tool_call>get_weather\n<arg_key>city</arg_key>\n<arg_value>北京</arg_value>\n<arg_key>date</arg_key>\n<arg_value>2024-06-27</arg_value>\n</tool_call>\n<tool_call>get_weather\n<arg_key>city</arg_key>\n<arg_value>上海</arg_value>\n<arg_key>date</arg_key>\n<arg_value>2024-06-27</arg_value>\n</tool_call>
|
||||||
"""
|
"""
|
||||||
@@ -53,7 +53,7 @@ class Glm4MoeDetector(BaseFormatDetector):
|
|||||||
self.func_arg_regex = r"<arg_key>(.*?)</arg_key>\s*<arg_value>(.*?)</arg_value>"
|
self.func_arg_regex = r"<arg_key>(.*?)</arg_key>\s*<arg_value>(.*?)</arg_value>"
|
||||||
|
|
||||||
def has_tool_call(self, text: str) -> bool:
|
def has_tool_call(self, text: str) -> bool:
|
||||||
"""Check if the text contains a glm-4.5 / glm-4.6 format tool call."""
|
"""Check if the text contains a glm-4.5 format tool call."""
|
||||||
return self.bot_token in text
|
return self.bot_token in text
|
||||||
|
|
||||||
def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
|
def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
|
||||||
@@ -102,7 +102,7 @@ class Glm4MoeDetector(BaseFormatDetector):
|
|||||||
self, new_text: str, tools: List[Tool]
|
self, new_text: str, tools: List[Tool]
|
||||||
) -> StreamingParseResult:
|
) -> StreamingParseResult:
|
||||||
"""
|
"""
|
||||||
Streaming incremental parsing tool calls for GLM-4.5 and GLM-4.6 format.
|
Streaming incremental parsing tool calls for GLM-4.5 format.
|
||||||
"""
|
"""
|
||||||
self._buffer += new_text
|
self._buffer += new_text
|
||||||
current_text = self._buffer
|
current_text = self._buffer
|
||||||
|
|||||||
@@ -1,63 +0,0 @@
|
|||||||
import json
|
|
||||||
import re
|
|
||||||
from typing import List
|
|
||||||
|
|
||||||
from sglang.srt.entrypoints.openai.protocol import Tool
|
|
||||||
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
|
|
||||||
from sglang.srt.function_call.core_types import StreamingParseResult
|
|
||||||
|
|
||||||
|
|
||||||
class JsonArrayParser(BaseFormatDetector):
|
|
||||||
"""
|
|
||||||
Parser for JSON array tool calls when JSON schema constraints are active.
|
|
||||||
|
|
||||||
This parser is used when tool_choice="required" or a specific tool is named,
|
|
||||||
bypassing model-specific parsers in favor of direct JSON array parsing.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
# Configure for JSON array parsing
|
|
||||||
self.bot_token = "["
|
|
||||||
self.eot_token = "]"
|
|
||||||
self.tool_call_separator = ","
|
|
||||||
|
|
||||||
def has_tool_call(self, text: str) -> bool:
|
|
||||||
"""
|
|
||||||
Check if the given text contains a JSON tool call (array or single object).
|
|
||||||
"""
|
|
||||||
return "[" in text or "{" in text
|
|
||||||
|
|
||||||
def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
|
|
||||||
"""
|
|
||||||
Parse JSON tool calls using the base class implementation.
|
|
||||||
"""
|
|
||||||
raise NotImplementedError(
|
|
||||||
"Detect and parse not supported for JSON schema constraints."
|
|
||||||
)
|
|
||||||
|
|
||||||
def build_ebnf(self, tools: List[Tool]) -> str:
|
|
||||||
"""
|
|
||||||
Build an EBNF grammar for constrained generation.
|
|
||||||
This is not used for JSON schema constraints as they are handled
|
|
||||||
by the constraint backends directly.
|
|
||||||
"""
|
|
||||||
raise NotImplementedError(
|
|
||||||
"EBNF generation is not supported for JSON schema constraints."
|
|
||||||
)
|
|
||||||
|
|
||||||
def parse_streaming_increment(
|
|
||||||
self, new_text: str, tools: List[Tool]
|
|
||||||
) -> StreamingParseResult:
|
|
||||||
"""
|
|
||||||
Streaming incremental parsing with tool validation.
|
|
||||||
"""
|
|
||||||
return super().parse_streaming_increment(new_text, tools)
|
|
||||||
|
|
||||||
def structure_info(self) -> callable:
|
|
||||||
"""
|
|
||||||
Return a function that creates StructureInfo for constrained generation.
|
|
||||||
This is not used for JSON schema constraints as they are handled
|
|
||||||
by the constraint backends directly.
|
|
||||||
"""
|
|
||||||
raise NotImplementedError("structure_info not used for JSON schema constraints")
|
|
||||||
@@ -1,13 +1,10 @@
|
|||||||
import json
|
import json
|
||||||
from json import JSONDecodeError, JSONDecoder
|
from json import JSONDecodeError, JSONDecoder
|
||||||
from json.decoder import WHITESPACE
|
from typing import Any, Tuple
|
||||||
from typing import Any, List, Literal, Optional, Tuple, Union
|
|
||||||
|
|
||||||
import partial_json_parser
|
import partial_json_parser
|
||||||
from partial_json_parser.core.options import Allow
|
from partial_json_parser.core.options import Allow
|
||||||
|
|
||||||
from sglang.srt.entrypoints.openai.protocol import Tool, ToolChoice
|
|
||||||
|
|
||||||
|
|
||||||
def _find_common_prefix(s1: str, s2: str) -> str:
|
def _find_common_prefix(s1: str, s2: str) -> str:
|
||||||
prefix = ""
|
prefix = ""
|
||||||
@@ -40,12 +37,10 @@ def _partial_json_loads(input_str: str, flags: Allow) -> Tuple[Any, int]:
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
return (partial_json_parser.loads(input_str, flags), len(input_str))
|
return (partial_json_parser.loads(input_str, flags), len(input_str))
|
||||||
except (JSONDecodeError, IndexError) as e:
|
except JSONDecodeError as e:
|
||||||
msg = getattr(e, "msg", str(e))
|
if "Extra data" in e.msg:
|
||||||
if "Extra data" in msg or "pop from empty list" in msg:
|
dec = JSONDecoder()
|
||||||
start = WHITESPACE.match(input_str, 0).end()
|
return dec.raw_decode(input_str)
|
||||||
obj, end = JSONDecoder().raw_decode(input_str, start)
|
|
||||||
return obj, end
|
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
@@ -55,89 +50,3 @@ def _is_complete_json(input_str: str) -> bool:
|
|||||||
return True
|
return True
|
||||||
except JSONDecodeError:
|
except JSONDecodeError:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def _get_tool_schema_defs(tools: List[Tool]) -> dict:
|
|
||||||
"""
|
|
||||||
Get consolidated $defs from all tools, validating for conflicts.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tools: List of tools to process
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dictionary of consolidated $defs from all tools
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If conflicting $defs are found
|
|
||||||
"""
|
|
||||||
all_defs = {}
|
|
||||||
for tool in tools:
|
|
||||||
if tool.function.parameters is None:
|
|
||||||
continue
|
|
||||||
defs = tool.function.parameters.get("$defs", {})
|
|
||||||
for def_name, def_schema in defs.items():
|
|
||||||
if def_name in all_defs and all_defs[def_name] != def_schema:
|
|
||||||
raise ValueError(
|
|
||||||
f"Tool definition '{def_name}' has "
|
|
||||||
"multiple schemas, which is not "
|
|
||||||
"supported."
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
all_defs[def_name] = def_schema
|
|
||||||
return all_defs
|
|
||||||
|
|
||||||
|
|
||||||
def _get_tool_schema(tool: Tool) -> dict:
|
|
||||||
return {
|
|
||||||
"properties": {
|
|
||||||
"name": {"type": "string", "enum": [tool.function.name]},
|
|
||||||
"parameters": (
|
|
||||||
tool.function.parameters
|
|
||||||
if tool.function.parameters
|
|
||||||
else {"type": "object", "properties": {}}
|
|
||||||
),
|
|
||||||
},
|
|
||||||
"required": ["name", "parameters"],
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def get_json_schema_constraint(
|
|
||||||
tools: List[Tool], tool_choice: Union[ToolChoice, Literal["required"]]
|
|
||||||
) -> Optional[dict]:
|
|
||||||
"""
|
|
||||||
Get the JSON schema constraint for the specified tool choice.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tool_choice: The tool choice specification
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
JSON schema dict, or None if no valid tools found
|
|
||||||
"""
|
|
||||||
|
|
||||||
if isinstance(tool_choice, ToolChoice):
|
|
||||||
# For specific function choice, return the user's parameters schema directly
|
|
||||||
fn_name = tool_choice.function.name
|
|
||||||
for tool in tools:
|
|
||||||
if tool.function.name == fn_name:
|
|
||||||
return {
|
|
||||||
"type": "array",
|
|
||||||
"minItems": 1,
|
|
||||||
"maxItems": 1,
|
|
||||||
"items": _get_tool_schema(tool),
|
|
||||||
}
|
|
||||||
return None
|
|
||||||
elif tool_choice == "required":
|
|
||||||
json_schema = {
|
|
||||||
"type": "array",
|
|
||||||
"minItems": 1,
|
|
||||||
"items": {
|
|
||||||
"type": "object",
|
|
||||||
"anyOf": [_get_tool_schema(tool) for tool in tools],
|
|
||||||
},
|
|
||||||
}
|
|
||||||
json_schema_defs = _get_tool_schema_defs(tools)
|
|
||||||
if json_schema_defs:
|
|
||||||
json_schema["$defs"] = json_schema_defs
|
|
||||||
return json_schema
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|||||||
@@ -36,9 +36,9 @@ message SamplingParams {
|
|||||||
float presence_penalty = 6;
|
float presence_penalty = 6;
|
||||||
float repetition_penalty = 7;
|
float repetition_penalty = 7;
|
||||||
|
|
||||||
optional int32 max_new_tokens = 8;
|
int32 max_new_tokens = 8;
|
||||||
repeated string stop = 9;
|
repeated string stop = 9;
|
||||||
repeated uint32 stop_token_ids = 10;
|
repeated int32 stop_token_ids = 10;
|
||||||
bool skip_special_tokens = 11;
|
bool skip_special_tokens = 11;
|
||||||
bool spaces_between_special_tokens = 12;
|
bool spaces_between_special_tokens = 12;
|
||||||
|
|
||||||
@@ -47,24 +47,24 @@ message SamplingParams {
|
|||||||
string regex = 13;
|
string regex = 13;
|
||||||
string json_schema = 14;
|
string json_schema = 14;
|
||||||
string ebnf_grammar = 15;
|
string ebnf_grammar = 15;
|
||||||
string structural_tag = 16;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// LoRA adapter
|
// LoRA adapter
|
||||||
string lora_path = 17;
|
string lora_path = 16;
|
||||||
|
|
||||||
// Speculative decoding
|
// Speculative decoding
|
||||||
int32 n = 18; // Number of samples
|
int32 n = 17; // Number of samples
|
||||||
|
|
||||||
// Token healing
|
// Token healing
|
||||||
bool token_healing = 19;
|
bool token_healing = 18;
|
||||||
|
|
||||||
// Additional parameters
|
// Additional parameters
|
||||||
int32 min_new_tokens = 20;
|
int32 min_new_tokens = 19;
|
||||||
bool ignore_eos = 21;
|
bool ignore_eos = 20;
|
||||||
bool no_stop_trim = 22;
|
bool no_stop_trim = 21;
|
||||||
int32 stream_interval = 23;
|
int32 stream_interval = 22;
|
||||||
map<string, float> logit_bias = 24;
|
map<string, float> logit_bias = 23;
|
||||||
|
string structural_tag = 24;
|
||||||
|
|
||||||
// Custom parameters for extensibility
|
// Custom parameters for extensibility
|
||||||
google.protobuf.Struct custom_params = 25;
|
google.protobuf.Struct custom_params = 25;
|
||||||
@@ -98,7 +98,7 @@ message GenerateRequest {
|
|||||||
bool return_logprob = 5;
|
bool return_logprob = 5;
|
||||||
int32 logprob_start_len = 6;
|
int32 logprob_start_len = 6;
|
||||||
int32 top_logprobs_num = 7;
|
int32 top_logprobs_num = 7;
|
||||||
repeated uint32 token_ids_logprob = 8;
|
repeated int32 token_ids_logprob = 8;
|
||||||
bool return_hidden_states = 9;
|
bool return_hidden_states = 9;
|
||||||
|
|
||||||
// For disaggregated serving
|
// For disaggregated serving
|
||||||
@@ -122,14 +122,11 @@ message GenerateRequest {
|
|||||||
|
|
||||||
// For load balancing
|
// For load balancing
|
||||||
int32 dp_balance_id = 17;
|
int32 dp_balance_id = 17;
|
||||||
|
|
||||||
// Whether client wants streaming response
|
|
||||||
bool stream = 18;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
message TokenizedInput {
|
message TokenizedInput {
|
||||||
string original_text = 1; // For reference
|
string original_text = 1; // For reference
|
||||||
repeated uint32 input_ids = 2;
|
repeated int32 input_ids = 2;
|
||||||
}
|
}
|
||||||
|
|
||||||
message MultimodalInputs {
|
message MultimodalInputs {
|
||||||
@@ -166,50 +163,51 @@ message GenerateResponse {
|
|||||||
}
|
}
|
||||||
|
|
||||||
message GenerateStreamChunk {
|
message GenerateStreamChunk {
|
||||||
// Generated tokens (incremental chunk)
|
// Generated token
|
||||||
repeated uint32 token_ids = 1;
|
int32 token_id = 1;
|
||||||
|
string text = 2;
|
||||||
|
|
||||||
// Cumulative counts
|
// Cumulative counts
|
||||||
int32 prompt_tokens = 2;
|
|
||||||
int32 completion_tokens = 3;
|
|
||||||
int32 cached_tokens = 4;
|
|
||||||
|
|
||||||
// Output logprobs (if requested) - incremental for streaming
|
|
||||||
LogProbs output_logprobs = 5;
|
|
||||||
|
|
||||||
// Hidden states (if requested)
|
|
||||||
repeated float hidden_states = 6;
|
|
||||||
|
|
||||||
// Input logprobs (if requested) - only in first chunk
|
|
||||||
LogProbs input_logprobs = 7;
|
|
||||||
}
|
|
||||||
|
|
||||||
message GenerateComplete {
|
|
||||||
// Final output
|
|
||||||
repeated uint32 output_ids = 1;
|
|
||||||
|
|
||||||
// Finish reason as OpenAI-compatible string ("stop", "length", "abort")
|
|
||||||
string finish_reason = 2;
|
|
||||||
|
|
||||||
// Token usage counts
|
|
||||||
int32 prompt_tokens = 3;
|
int32 prompt_tokens = 3;
|
||||||
int32 completion_tokens = 4;
|
int32 completion_tokens = 4;
|
||||||
int32 cached_tokens = 5;
|
int32 cached_tokens = 5;
|
||||||
|
|
||||||
// Output logprobs if requested (cumulative)
|
// Logprobs (if requested)
|
||||||
LogProbs output_logprobs = 6;
|
LogProbs logprobs = 6;
|
||||||
|
|
||||||
|
// Hidden states (if requested)
|
||||||
|
repeated float hidden_states = 7;
|
||||||
|
|
||||||
|
// Metadata
|
||||||
|
float generation_time = 8; // Time to generate this token
|
||||||
|
int32 queue_time = 9; // Time spent in queue
|
||||||
|
}
|
||||||
|
|
||||||
|
message GenerateComplete {
|
||||||
|
// Final output
|
||||||
|
repeated int32 output_ids = 1;
|
||||||
|
string output_text = 2;
|
||||||
|
|
||||||
|
// Finish reason
|
||||||
|
enum FinishReason {
|
||||||
|
// The model generated a stop sequence.
|
||||||
|
STOP = 0;
|
||||||
|
// The model reached the maximum generation length.
|
||||||
|
LENGTH = 1;
|
||||||
|
// The model generated an end-of-sequence (EOS) token.
|
||||||
|
EOS_TOKEN = 2;
|
||||||
|
// The model generated a user-provided stop string.
|
||||||
|
STOP_STR = 3;
|
||||||
|
// The request was aborted by the user or system.
|
||||||
|
ABORT = 4;
|
||||||
|
}
|
||||||
|
FinishReason finish_reason = 3;
|
||||||
|
|
||||||
|
// All logprobs if requested
|
||||||
|
repeated LogProbs all_logprobs = 11;
|
||||||
|
|
||||||
// All hidden states if requested
|
// All hidden states if requested
|
||||||
repeated HiddenStates all_hidden_states = 7;
|
repeated HiddenStates all_hidden_states = 12;
|
||||||
|
|
||||||
// Matched stop information (for stop sequences)
|
|
||||||
oneof matched_stop {
|
|
||||||
uint32 matched_token_id = 8;
|
|
||||||
string matched_stop_str = 9;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Input logprobs if requested (for prompt tokens)
|
|
||||||
LogProbs input_logprobs = 10;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
message GenerateError {
|
message GenerateError {
|
||||||
@@ -224,11 +222,15 @@ message LogProbs {
|
|||||||
|
|
||||||
// Top logprobs at each position
|
// Top logprobs at each position
|
||||||
repeated TopLogProbs top_logprobs = 3;
|
repeated TopLogProbs top_logprobs = 3;
|
||||||
|
|
||||||
|
// Decoded text for tokens
|
||||||
|
repeated string token_texts = 4;
|
||||||
}
|
}
|
||||||
|
|
||||||
message TopLogProbs {
|
message TopLogProbs {
|
||||||
repeated float values = 1;
|
repeated float values = 1;
|
||||||
repeated int32 token_ids = 2;
|
repeated int32 token_ids = 2;
|
||||||
|
repeated string token_texts = 3;
|
||||||
}
|
}
|
||||||
|
|
||||||
message HiddenStates {
|
message HiddenStates {
|
||||||
@@ -283,9 +285,10 @@ message EmbedComplete {
|
|||||||
|
|
||||||
// Additional metadata
|
// Additional metadata
|
||||||
int32 embedding_dim = 4;
|
int32 embedding_dim = 4;
|
||||||
|
float generation_time = 5;
|
||||||
|
|
||||||
// For batch embeddings
|
// For batch embeddings
|
||||||
repeated Embedding batch_embeddings = 5;
|
repeated Embedding batch_embeddings = 6;
|
||||||
}
|
}
|
||||||
|
|
||||||
message Embedding {
|
message Embedding {
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
@@ -3,6 +3,7 @@ import datetime
|
|||||||
from google.protobuf import timestamp_pb2 as _timestamp_pb2
|
from google.protobuf import timestamp_pb2 as _timestamp_pb2
|
||||||
from google.protobuf import struct_pb2 as _struct_pb2
|
from google.protobuf import struct_pb2 as _struct_pb2
|
||||||
from google.protobuf.internal import containers as _containers
|
from google.protobuf.internal import containers as _containers
|
||||||
|
from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper
|
||||||
from google.protobuf import descriptor as _descriptor
|
from google.protobuf import descriptor as _descriptor
|
||||||
from google.protobuf import message as _message
|
from google.protobuf import message as _message
|
||||||
from collections.abc import Iterable as _Iterable, Mapping as _Mapping
|
from collections.abc import Iterable as _Iterable, Mapping as _Mapping
|
||||||
@@ -11,7 +12,7 @@ from typing import ClassVar as _ClassVar, Optional as _Optional, Union as _Union
|
|||||||
DESCRIPTOR: _descriptor.FileDescriptor
|
DESCRIPTOR: _descriptor.FileDescriptor
|
||||||
|
|
||||||
class SamplingParams(_message.Message):
|
class SamplingParams(_message.Message):
|
||||||
__slots__ = ("temperature", "top_p", "top_k", "min_p", "frequency_penalty", "presence_penalty", "repetition_penalty", "max_new_tokens", "stop", "stop_token_ids", "skip_special_tokens", "spaces_between_special_tokens", "regex", "json_schema", "ebnf_grammar", "structural_tag", "lora_path", "n", "token_healing", "min_new_tokens", "ignore_eos", "no_stop_trim", "stream_interval", "logit_bias", "custom_params")
|
__slots__ = ("temperature", "top_p", "top_k", "min_p", "frequency_penalty", "presence_penalty", "repetition_penalty", "max_new_tokens", "stop", "stop_token_ids", "skip_special_tokens", "spaces_between_special_tokens", "regex", "json_schema", "ebnf_grammar", "lora_path", "n", "token_healing", "min_new_tokens", "ignore_eos", "no_stop_trim", "stream_interval", "logit_bias", "structural_tag", "custom_params")
|
||||||
class LogitBiasEntry(_message.Message):
|
class LogitBiasEntry(_message.Message):
|
||||||
__slots__ = ("key", "value")
|
__slots__ = ("key", "value")
|
||||||
KEY_FIELD_NUMBER: _ClassVar[int]
|
KEY_FIELD_NUMBER: _ClassVar[int]
|
||||||
@@ -34,7 +35,6 @@ class SamplingParams(_message.Message):
|
|||||||
REGEX_FIELD_NUMBER: _ClassVar[int]
|
REGEX_FIELD_NUMBER: _ClassVar[int]
|
||||||
JSON_SCHEMA_FIELD_NUMBER: _ClassVar[int]
|
JSON_SCHEMA_FIELD_NUMBER: _ClassVar[int]
|
||||||
EBNF_GRAMMAR_FIELD_NUMBER: _ClassVar[int]
|
EBNF_GRAMMAR_FIELD_NUMBER: _ClassVar[int]
|
||||||
STRUCTURAL_TAG_FIELD_NUMBER: _ClassVar[int]
|
|
||||||
LORA_PATH_FIELD_NUMBER: _ClassVar[int]
|
LORA_PATH_FIELD_NUMBER: _ClassVar[int]
|
||||||
N_FIELD_NUMBER: _ClassVar[int]
|
N_FIELD_NUMBER: _ClassVar[int]
|
||||||
TOKEN_HEALING_FIELD_NUMBER: _ClassVar[int]
|
TOKEN_HEALING_FIELD_NUMBER: _ClassVar[int]
|
||||||
@@ -43,6 +43,7 @@ class SamplingParams(_message.Message):
|
|||||||
NO_STOP_TRIM_FIELD_NUMBER: _ClassVar[int]
|
NO_STOP_TRIM_FIELD_NUMBER: _ClassVar[int]
|
||||||
STREAM_INTERVAL_FIELD_NUMBER: _ClassVar[int]
|
STREAM_INTERVAL_FIELD_NUMBER: _ClassVar[int]
|
||||||
LOGIT_BIAS_FIELD_NUMBER: _ClassVar[int]
|
LOGIT_BIAS_FIELD_NUMBER: _ClassVar[int]
|
||||||
|
STRUCTURAL_TAG_FIELD_NUMBER: _ClassVar[int]
|
||||||
CUSTOM_PARAMS_FIELD_NUMBER: _ClassVar[int]
|
CUSTOM_PARAMS_FIELD_NUMBER: _ClassVar[int]
|
||||||
temperature: float
|
temperature: float
|
||||||
top_p: float
|
top_p: float
|
||||||
@@ -59,7 +60,6 @@ class SamplingParams(_message.Message):
|
|||||||
regex: str
|
regex: str
|
||||||
json_schema: str
|
json_schema: str
|
||||||
ebnf_grammar: str
|
ebnf_grammar: str
|
||||||
structural_tag: str
|
|
||||||
lora_path: str
|
lora_path: str
|
||||||
n: int
|
n: int
|
||||||
token_healing: bool
|
token_healing: bool
|
||||||
@@ -68,8 +68,9 @@ class SamplingParams(_message.Message):
|
|||||||
no_stop_trim: bool
|
no_stop_trim: bool
|
||||||
stream_interval: int
|
stream_interval: int
|
||||||
logit_bias: _containers.ScalarMap[str, float]
|
logit_bias: _containers.ScalarMap[str, float]
|
||||||
|
structural_tag: str
|
||||||
custom_params: _struct_pb2.Struct
|
custom_params: _struct_pb2.Struct
|
||||||
def __init__(self, temperature: _Optional[float] = ..., top_p: _Optional[float] = ..., top_k: _Optional[int] = ..., min_p: _Optional[float] = ..., frequency_penalty: _Optional[float] = ..., presence_penalty: _Optional[float] = ..., repetition_penalty: _Optional[float] = ..., max_new_tokens: _Optional[int] = ..., stop: _Optional[_Iterable[str]] = ..., stop_token_ids: _Optional[_Iterable[int]] = ..., skip_special_tokens: bool = ..., spaces_between_special_tokens: bool = ..., regex: _Optional[str] = ..., json_schema: _Optional[str] = ..., ebnf_grammar: _Optional[str] = ..., structural_tag: _Optional[str] = ..., lora_path: _Optional[str] = ..., n: _Optional[int] = ..., token_healing: bool = ..., min_new_tokens: _Optional[int] = ..., ignore_eos: bool = ..., no_stop_trim: bool = ..., stream_interval: _Optional[int] = ..., logit_bias: _Optional[_Mapping[str, float]] = ..., custom_params: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ...
|
def __init__(self, temperature: _Optional[float] = ..., top_p: _Optional[float] = ..., top_k: _Optional[int] = ..., min_p: _Optional[float] = ..., frequency_penalty: _Optional[float] = ..., presence_penalty: _Optional[float] = ..., repetition_penalty: _Optional[float] = ..., max_new_tokens: _Optional[int] = ..., stop: _Optional[_Iterable[str]] = ..., stop_token_ids: _Optional[_Iterable[int]] = ..., skip_special_tokens: bool = ..., spaces_between_special_tokens: bool = ..., regex: _Optional[str] = ..., json_schema: _Optional[str] = ..., ebnf_grammar: _Optional[str] = ..., lora_path: _Optional[str] = ..., n: _Optional[int] = ..., token_healing: bool = ..., min_new_tokens: _Optional[int] = ..., ignore_eos: bool = ..., no_stop_trim: bool = ..., stream_interval: _Optional[int] = ..., logit_bias: _Optional[_Mapping[str, float]] = ..., structural_tag: _Optional[str] = ..., custom_params: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ...
|
||||||
|
|
||||||
class DisaggregatedParams(_message.Message):
|
class DisaggregatedParams(_message.Message):
|
||||||
__slots__ = ("bootstrap_host", "bootstrap_port", "bootstrap_room")
|
__slots__ = ("bootstrap_host", "bootstrap_port", "bootstrap_room")
|
||||||
@@ -82,7 +83,7 @@ class DisaggregatedParams(_message.Message):
|
|||||||
def __init__(self, bootstrap_host: _Optional[str] = ..., bootstrap_port: _Optional[int] = ..., bootstrap_room: _Optional[int] = ...) -> None: ...
|
def __init__(self, bootstrap_host: _Optional[str] = ..., bootstrap_port: _Optional[int] = ..., bootstrap_room: _Optional[int] = ...) -> None: ...
|
||||||
|
|
||||||
class GenerateRequest(_message.Message):
|
class GenerateRequest(_message.Message):
|
||||||
__slots__ = ("request_id", "tokenized", "mm_inputs", "sampling_params", "return_logprob", "logprob_start_len", "top_logprobs_num", "token_ids_logprob", "return_hidden_states", "disaggregated_params", "custom_logit_processor", "timestamp", "log_metrics", "input_embeds", "lora_id", "data_parallel_rank", "dp_balance_id", "stream")
|
__slots__ = ("request_id", "tokenized", "mm_inputs", "sampling_params", "return_logprob", "logprob_start_len", "top_logprobs_num", "token_ids_logprob", "return_hidden_states", "disaggregated_params", "custom_logit_processor", "timestamp", "log_metrics", "input_embeds", "lora_id", "data_parallel_rank", "dp_balance_id")
|
||||||
REQUEST_ID_FIELD_NUMBER: _ClassVar[int]
|
REQUEST_ID_FIELD_NUMBER: _ClassVar[int]
|
||||||
TOKENIZED_FIELD_NUMBER: _ClassVar[int]
|
TOKENIZED_FIELD_NUMBER: _ClassVar[int]
|
||||||
MM_INPUTS_FIELD_NUMBER: _ClassVar[int]
|
MM_INPUTS_FIELD_NUMBER: _ClassVar[int]
|
||||||
@@ -100,7 +101,6 @@ class GenerateRequest(_message.Message):
|
|||||||
LORA_ID_FIELD_NUMBER: _ClassVar[int]
|
LORA_ID_FIELD_NUMBER: _ClassVar[int]
|
||||||
DATA_PARALLEL_RANK_FIELD_NUMBER: _ClassVar[int]
|
DATA_PARALLEL_RANK_FIELD_NUMBER: _ClassVar[int]
|
||||||
DP_BALANCE_ID_FIELD_NUMBER: _ClassVar[int]
|
DP_BALANCE_ID_FIELD_NUMBER: _ClassVar[int]
|
||||||
STREAM_FIELD_NUMBER: _ClassVar[int]
|
|
||||||
request_id: str
|
request_id: str
|
||||||
tokenized: TokenizedInput
|
tokenized: TokenizedInput
|
||||||
mm_inputs: MultimodalInputs
|
mm_inputs: MultimodalInputs
|
||||||
@@ -118,8 +118,7 @@ class GenerateRequest(_message.Message):
|
|||||||
lora_id: str
|
lora_id: str
|
||||||
data_parallel_rank: int
|
data_parallel_rank: int
|
||||||
dp_balance_id: int
|
dp_balance_id: int
|
||||||
stream: bool
|
def __init__(self, request_id: _Optional[str] = ..., tokenized: _Optional[_Union[TokenizedInput, _Mapping]] = ..., mm_inputs: _Optional[_Union[MultimodalInputs, _Mapping]] = ..., sampling_params: _Optional[_Union[SamplingParams, _Mapping]] = ..., return_logprob: bool = ..., logprob_start_len: _Optional[int] = ..., top_logprobs_num: _Optional[int] = ..., token_ids_logprob: _Optional[_Iterable[int]] = ..., return_hidden_states: bool = ..., disaggregated_params: _Optional[_Union[DisaggregatedParams, _Mapping]] = ..., custom_logit_processor: _Optional[str] = ..., timestamp: _Optional[_Union[datetime.datetime, _timestamp_pb2.Timestamp, _Mapping]] = ..., log_metrics: bool = ..., input_embeds: _Optional[_Iterable[float]] = ..., lora_id: _Optional[str] = ..., data_parallel_rank: _Optional[int] = ..., dp_balance_id: _Optional[int] = ...) -> None: ...
|
||||||
def __init__(self, request_id: _Optional[str] = ..., tokenized: _Optional[_Union[TokenizedInput, _Mapping]] = ..., mm_inputs: _Optional[_Union[MultimodalInputs, _Mapping]] = ..., sampling_params: _Optional[_Union[SamplingParams, _Mapping]] = ..., return_logprob: bool = ..., logprob_start_len: _Optional[int] = ..., top_logprobs_num: _Optional[int] = ..., token_ids_logprob: _Optional[_Iterable[int]] = ..., return_hidden_states: bool = ..., disaggregated_params: _Optional[_Union[DisaggregatedParams, _Mapping]] = ..., custom_logit_processor: _Optional[str] = ..., timestamp: _Optional[_Union[datetime.datetime, _timestamp_pb2.Timestamp, _Mapping]] = ..., log_metrics: bool = ..., input_embeds: _Optional[_Iterable[float]] = ..., lora_id: _Optional[str] = ..., data_parallel_rank: _Optional[int] = ..., dp_balance_id: _Optional[int] = ..., stream: bool = ...) -> None: ...
|
|
||||||
|
|
||||||
class TokenizedInput(_message.Message):
|
class TokenizedInput(_message.Message):
|
||||||
__slots__ = ("original_text", "input_ids")
|
__slots__ = ("original_text", "input_ids")
|
||||||
@@ -162,46 +161,52 @@ class GenerateResponse(_message.Message):
|
|||||||
def __init__(self, request_id: _Optional[str] = ..., chunk: _Optional[_Union[GenerateStreamChunk, _Mapping]] = ..., complete: _Optional[_Union[GenerateComplete, _Mapping]] = ..., error: _Optional[_Union[GenerateError, _Mapping]] = ...) -> None: ...
|
def __init__(self, request_id: _Optional[str] = ..., chunk: _Optional[_Union[GenerateStreamChunk, _Mapping]] = ..., complete: _Optional[_Union[GenerateComplete, _Mapping]] = ..., error: _Optional[_Union[GenerateError, _Mapping]] = ...) -> None: ...
|
||||||
|
|
||||||
class GenerateStreamChunk(_message.Message):
|
class GenerateStreamChunk(_message.Message):
|
||||||
__slots__ = ("token_ids", "prompt_tokens", "completion_tokens", "cached_tokens", "output_logprobs", "hidden_states", "input_logprobs")
|
__slots__ = ("token_id", "text", "prompt_tokens", "completion_tokens", "cached_tokens", "logprobs", "hidden_states", "generation_time", "queue_time")
|
||||||
TOKEN_IDS_FIELD_NUMBER: _ClassVar[int]
|
TOKEN_ID_FIELD_NUMBER: _ClassVar[int]
|
||||||
|
TEXT_FIELD_NUMBER: _ClassVar[int]
|
||||||
PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int]
|
PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int]
|
||||||
COMPLETION_TOKENS_FIELD_NUMBER: _ClassVar[int]
|
COMPLETION_TOKENS_FIELD_NUMBER: _ClassVar[int]
|
||||||
CACHED_TOKENS_FIELD_NUMBER: _ClassVar[int]
|
CACHED_TOKENS_FIELD_NUMBER: _ClassVar[int]
|
||||||
OUTPUT_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
|
LOGPROBS_FIELD_NUMBER: _ClassVar[int]
|
||||||
HIDDEN_STATES_FIELD_NUMBER: _ClassVar[int]
|
HIDDEN_STATES_FIELD_NUMBER: _ClassVar[int]
|
||||||
INPUT_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
|
GENERATION_TIME_FIELD_NUMBER: _ClassVar[int]
|
||||||
token_ids: _containers.RepeatedScalarFieldContainer[int]
|
QUEUE_TIME_FIELD_NUMBER: _ClassVar[int]
|
||||||
|
token_id: int
|
||||||
|
text: str
|
||||||
prompt_tokens: int
|
prompt_tokens: int
|
||||||
completion_tokens: int
|
completion_tokens: int
|
||||||
cached_tokens: int
|
cached_tokens: int
|
||||||
output_logprobs: LogProbs
|
logprobs: LogProbs
|
||||||
hidden_states: _containers.RepeatedScalarFieldContainer[float]
|
hidden_states: _containers.RepeatedScalarFieldContainer[float]
|
||||||
input_logprobs: LogProbs
|
generation_time: float
|
||||||
def __init__(self, token_ids: _Optional[_Iterable[int]] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., output_logprobs: _Optional[_Union[LogProbs, _Mapping]] = ..., hidden_states: _Optional[_Iterable[float]] = ..., input_logprobs: _Optional[_Union[LogProbs, _Mapping]] = ...) -> None: ...
|
queue_time: int
|
||||||
|
def __init__(self, token_id: _Optional[int] = ..., text: _Optional[str] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., logprobs: _Optional[_Union[LogProbs, _Mapping]] = ..., hidden_states: _Optional[_Iterable[float]] = ..., generation_time: _Optional[float] = ..., queue_time: _Optional[int] = ...) -> None: ...
|
||||||
|
|
||||||
class GenerateComplete(_message.Message):
|
class GenerateComplete(_message.Message):
|
||||||
__slots__ = ("output_ids", "finish_reason", "prompt_tokens", "completion_tokens", "cached_tokens", "output_logprobs", "all_hidden_states", "matched_token_id", "matched_stop_str", "input_logprobs")
|
__slots__ = ("output_ids", "output_text", "finish_reason", "all_logprobs", "all_hidden_states")
|
||||||
|
class FinishReason(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
|
||||||
|
__slots__ = ()
|
||||||
|
STOP: _ClassVar[GenerateComplete.FinishReason]
|
||||||
|
LENGTH: _ClassVar[GenerateComplete.FinishReason]
|
||||||
|
EOS_TOKEN: _ClassVar[GenerateComplete.FinishReason]
|
||||||
|
STOP_STR: _ClassVar[GenerateComplete.FinishReason]
|
||||||
|
ABORT: _ClassVar[GenerateComplete.FinishReason]
|
||||||
|
STOP: GenerateComplete.FinishReason
|
||||||
|
LENGTH: GenerateComplete.FinishReason
|
||||||
|
EOS_TOKEN: GenerateComplete.FinishReason
|
||||||
|
STOP_STR: GenerateComplete.FinishReason
|
||||||
|
ABORT: GenerateComplete.FinishReason
|
||||||
OUTPUT_IDS_FIELD_NUMBER: _ClassVar[int]
|
OUTPUT_IDS_FIELD_NUMBER: _ClassVar[int]
|
||||||
|
OUTPUT_TEXT_FIELD_NUMBER: _ClassVar[int]
|
||||||
FINISH_REASON_FIELD_NUMBER: _ClassVar[int]
|
FINISH_REASON_FIELD_NUMBER: _ClassVar[int]
|
||||||
PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int]
|
ALL_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
|
||||||
COMPLETION_TOKENS_FIELD_NUMBER: _ClassVar[int]
|
|
||||||
CACHED_TOKENS_FIELD_NUMBER: _ClassVar[int]
|
|
||||||
OUTPUT_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
|
|
||||||
ALL_HIDDEN_STATES_FIELD_NUMBER: _ClassVar[int]
|
ALL_HIDDEN_STATES_FIELD_NUMBER: _ClassVar[int]
|
||||||
MATCHED_TOKEN_ID_FIELD_NUMBER: _ClassVar[int]
|
|
||||||
MATCHED_STOP_STR_FIELD_NUMBER: _ClassVar[int]
|
|
||||||
INPUT_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
|
|
||||||
output_ids: _containers.RepeatedScalarFieldContainer[int]
|
output_ids: _containers.RepeatedScalarFieldContainer[int]
|
||||||
finish_reason: str
|
output_text: str
|
||||||
prompt_tokens: int
|
finish_reason: GenerateComplete.FinishReason
|
||||||
completion_tokens: int
|
all_logprobs: _containers.RepeatedCompositeFieldContainer[LogProbs]
|
||||||
cached_tokens: int
|
|
||||||
output_logprobs: LogProbs
|
|
||||||
all_hidden_states: _containers.RepeatedCompositeFieldContainer[HiddenStates]
|
all_hidden_states: _containers.RepeatedCompositeFieldContainer[HiddenStates]
|
||||||
matched_token_id: int
|
def __init__(self, output_ids: _Optional[_Iterable[int]] = ..., output_text: _Optional[str] = ..., finish_reason: _Optional[_Union[GenerateComplete.FinishReason, str]] = ..., all_logprobs: _Optional[_Iterable[_Union[LogProbs, _Mapping]]] = ..., all_hidden_states: _Optional[_Iterable[_Union[HiddenStates, _Mapping]]] = ...) -> None: ...
|
||||||
matched_stop_str: str
|
|
||||||
input_logprobs: LogProbs
|
|
||||||
def __init__(self, output_ids: _Optional[_Iterable[int]] = ..., finish_reason: _Optional[str] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., output_logprobs: _Optional[_Union[LogProbs, _Mapping]] = ..., all_hidden_states: _Optional[_Iterable[_Union[HiddenStates, _Mapping]]] = ..., matched_token_id: _Optional[int] = ..., matched_stop_str: _Optional[str] = ..., input_logprobs: _Optional[_Union[LogProbs, _Mapping]] = ...) -> None: ...
|
|
||||||
|
|
||||||
class GenerateError(_message.Message):
|
class GenerateError(_message.Message):
|
||||||
__slots__ = ("message", "http_status_code", "details")
|
__slots__ = ("message", "http_status_code", "details")
|
||||||
@@ -214,22 +219,26 @@ class GenerateError(_message.Message):
|
|||||||
def __init__(self, message: _Optional[str] = ..., http_status_code: _Optional[str] = ..., details: _Optional[str] = ...) -> None: ...
|
def __init__(self, message: _Optional[str] = ..., http_status_code: _Optional[str] = ..., details: _Optional[str] = ...) -> None: ...
|
||||||
|
|
||||||
class LogProbs(_message.Message):
|
class LogProbs(_message.Message):
|
||||||
__slots__ = ("token_logprobs", "token_ids", "top_logprobs")
|
__slots__ = ("token_logprobs", "token_ids", "top_logprobs", "token_texts")
|
||||||
TOKEN_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
|
TOKEN_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
|
||||||
TOKEN_IDS_FIELD_NUMBER: _ClassVar[int]
|
TOKEN_IDS_FIELD_NUMBER: _ClassVar[int]
|
||||||
TOP_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
|
TOP_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
|
||||||
|
TOKEN_TEXTS_FIELD_NUMBER: _ClassVar[int]
|
||||||
token_logprobs: _containers.RepeatedScalarFieldContainer[float]
|
token_logprobs: _containers.RepeatedScalarFieldContainer[float]
|
||||||
token_ids: _containers.RepeatedScalarFieldContainer[int]
|
token_ids: _containers.RepeatedScalarFieldContainer[int]
|
||||||
top_logprobs: _containers.RepeatedCompositeFieldContainer[TopLogProbs]
|
top_logprobs: _containers.RepeatedCompositeFieldContainer[TopLogProbs]
|
||||||
def __init__(self, token_logprobs: _Optional[_Iterable[float]] = ..., token_ids: _Optional[_Iterable[int]] = ..., top_logprobs: _Optional[_Iterable[_Union[TopLogProbs, _Mapping]]] = ...) -> None: ...
|
token_texts: _containers.RepeatedScalarFieldContainer[str]
|
||||||
|
def __init__(self, token_logprobs: _Optional[_Iterable[float]] = ..., token_ids: _Optional[_Iterable[int]] = ..., top_logprobs: _Optional[_Iterable[_Union[TopLogProbs, _Mapping]]] = ..., token_texts: _Optional[_Iterable[str]] = ...) -> None: ...
|
||||||
|
|
||||||
class TopLogProbs(_message.Message):
|
class TopLogProbs(_message.Message):
|
||||||
__slots__ = ("values", "token_ids")
|
__slots__ = ("values", "token_ids", "token_texts")
|
||||||
VALUES_FIELD_NUMBER: _ClassVar[int]
|
VALUES_FIELD_NUMBER: _ClassVar[int]
|
||||||
TOKEN_IDS_FIELD_NUMBER: _ClassVar[int]
|
TOKEN_IDS_FIELD_NUMBER: _ClassVar[int]
|
||||||
|
TOKEN_TEXTS_FIELD_NUMBER: _ClassVar[int]
|
||||||
values: _containers.RepeatedScalarFieldContainer[float]
|
values: _containers.RepeatedScalarFieldContainer[float]
|
||||||
token_ids: _containers.RepeatedScalarFieldContainer[int]
|
token_ids: _containers.RepeatedScalarFieldContainer[int]
|
||||||
def __init__(self, values: _Optional[_Iterable[float]] = ..., token_ids: _Optional[_Iterable[int]] = ...) -> None: ...
|
token_texts: _containers.RepeatedScalarFieldContainer[str]
|
||||||
|
def __init__(self, values: _Optional[_Iterable[float]] = ..., token_ids: _Optional[_Iterable[int]] = ..., token_texts: _Optional[_Iterable[str]] = ...) -> None: ...
|
||||||
|
|
||||||
class HiddenStates(_message.Message):
|
class HiddenStates(_message.Message):
|
||||||
__slots__ = ("values", "layer", "position")
|
__slots__ = ("values", "layer", "position")
|
||||||
@@ -274,18 +283,20 @@ class EmbedResponse(_message.Message):
|
|||||||
def __init__(self, request_id: _Optional[str] = ..., complete: _Optional[_Union[EmbedComplete, _Mapping]] = ..., error: _Optional[_Union[EmbedError, _Mapping]] = ...) -> None: ...
|
def __init__(self, request_id: _Optional[str] = ..., complete: _Optional[_Union[EmbedComplete, _Mapping]] = ..., error: _Optional[_Union[EmbedError, _Mapping]] = ...) -> None: ...
|
||||||
|
|
||||||
class EmbedComplete(_message.Message):
|
class EmbedComplete(_message.Message):
|
||||||
__slots__ = ("embedding", "prompt_tokens", "cached_tokens", "embedding_dim", "batch_embeddings")
|
__slots__ = ("embedding", "prompt_tokens", "cached_tokens", "embedding_dim", "generation_time", "batch_embeddings")
|
||||||
EMBEDDING_FIELD_NUMBER: _ClassVar[int]
|
EMBEDDING_FIELD_NUMBER: _ClassVar[int]
|
||||||
PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int]
|
PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int]
|
||||||
CACHED_TOKENS_FIELD_NUMBER: _ClassVar[int]
|
CACHED_TOKENS_FIELD_NUMBER: _ClassVar[int]
|
||||||
EMBEDDING_DIM_FIELD_NUMBER: _ClassVar[int]
|
EMBEDDING_DIM_FIELD_NUMBER: _ClassVar[int]
|
||||||
|
GENERATION_TIME_FIELD_NUMBER: _ClassVar[int]
|
||||||
BATCH_EMBEDDINGS_FIELD_NUMBER: _ClassVar[int]
|
BATCH_EMBEDDINGS_FIELD_NUMBER: _ClassVar[int]
|
||||||
embedding: _containers.RepeatedScalarFieldContainer[float]
|
embedding: _containers.RepeatedScalarFieldContainer[float]
|
||||||
prompt_tokens: int
|
prompt_tokens: int
|
||||||
cached_tokens: int
|
cached_tokens: int
|
||||||
embedding_dim: int
|
embedding_dim: int
|
||||||
|
generation_time: float
|
||||||
batch_embeddings: _containers.RepeatedCompositeFieldContainer[Embedding]
|
batch_embeddings: _containers.RepeatedCompositeFieldContainer[Embedding]
|
||||||
def __init__(self, embedding: _Optional[_Iterable[float]] = ..., prompt_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., embedding_dim: _Optional[int] = ..., batch_embeddings: _Optional[_Iterable[_Union[Embedding, _Mapping]]] = ...) -> None: ...
|
def __init__(self, embedding: _Optional[_Iterable[float]] = ..., prompt_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., embedding_dim: _Optional[int] = ..., generation_time: _Optional[float] = ..., batch_embeddings: _Optional[_Iterable[_Union[Embedding, _Mapping]]] = ...) -> None: ...
|
||||||
|
|
||||||
class Embedding(_message.Message):
|
class Embedding(_message.Message):
|
||||||
__slots__ = ("values", "index")
|
__slots__ = ("values", "index")
|
||||||
|
|||||||
@@ -1,6 +1,3 @@
|
|||||||
# This file is auto-generated. Do not edit manually.
|
|
||||||
# Regenerate with: python compile_proto.py
|
|
||||||
|
|
||||||
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
|
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
|
||||||
"""Client and server classes corresponding to protobuf-defined services."""
|
"""Client and server classes corresponding to protobuf-defined services."""
|
||||||
import grpc
|
import grpc
|
||||||
|
|||||||
@@ -119,6 +119,37 @@ def get_hf_text_config(config: PretrainedConfig):
|
|||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
def _load_deepseek_v32_model(
|
||||||
|
model_path: str,
|
||||||
|
trust_remote_code: bool = False,
|
||||||
|
revision: Optional[str] = None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
# first get the local path
|
||||||
|
local_path = download_from_hf(model_path)
|
||||||
|
# then load the config file in json
|
||||||
|
config_file = os.path.join(local_path, "config.json")
|
||||||
|
if not os.path.exists(config_file):
|
||||||
|
raise RuntimeError(f"Can't find config file in {local_path}.")
|
||||||
|
|
||||||
|
with open(config_file, "r") as f:
|
||||||
|
config_json = json.load(f)
|
||||||
|
|
||||||
|
config_json["architectures"] = ["DeepseekV3ForCausalLM"]
|
||||||
|
config_json["model_type"] = "deepseek_v3"
|
||||||
|
|
||||||
|
tmp_path = os.path.join(local_path, "_tmp_config_folder")
|
||||||
|
os.makedirs(tmp_path, exist_ok=True)
|
||||||
|
|
||||||
|
unique_path = os.path.join(tmp_path, f"deepseek_v32_{os.getpid()}")
|
||||||
|
with open(unique_path, "w") as f:
|
||||||
|
json.dump(config_json, f)
|
||||||
|
|
||||||
|
return AutoConfig.from_pretrained(
|
||||||
|
unique_path, trust_remote_code=trust_remote_code, revision=revision, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@lru_cache_frozenset(maxsize=32)
|
@lru_cache_frozenset(maxsize=32)
|
||||||
def get_config(
|
def get_config(
|
||||||
model: str,
|
model: str,
|
||||||
@@ -140,9 +171,17 @@ def get_config(
|
|||||||
client.pull_files(ignore_pattern=["*.pt", "*.safetensors", "*.bin"])
|
client.pull_files(ignore_pattern=["*.pt", "*.safetensors", "*.bin"])
|
||||||
model = client.get_local_dir()
|
model = client.get_local_dir()
|
||||||
|
|
||||||
config = AutoConfig.from_pretrained(
|
try:
|
||||||
model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
|
config = AutoConfig.from_pretrained(
|
||||||
)
|
model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
|
||||||
|
)
|
||||||
|
except ValueError as e:
|
||||||
|
if not "deepseek_v32" in str(e):
|
||||||
|
raise e
|
||||||
|
config = _load_deepseek_v32_model(
|
||||||
|
model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
config.architectures is not None
|
config.architectures is not None
|
||||||
and config.architectures[0] == "Phi4MMForCausalLM"
|
and config.architectures[0] == "Phi4MMForCausalLM"
|
||||||
|
|||||||
@@ -619,11 +619,7 @@ class AiterAttnBackend(AttentionBackend):
|
|||||||
assert len(k.shape) == 3
|
assert len(k.shape) == 3
|
||||||
assert len(v.shape) == 3
|
assert len(v.shape) == 3
|
||||||
|
|
||||||
if (
|
if forward_batch.forward_mode.is_extend():
|
||||||
forward_batch.forward_mode.is_extend()
|
|
||||||
and not forward_batch.forward_mode.is_target_verify()
|
|
||||||
and not forward_batch.forward_mode.is_draft_extend()
|
|
||||||
):
|
|
||||||
if kv_indices.shape[0] == 0:
|
if kv_indices.shape[0] == 0:
|
||||||
o = flash_attn_varlen_func(
|
o = flash_attn_varlen_func(
|
||||||
q,
|
q,
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ from __future__ import annotations
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import TYPE_CHECKING, List, Optional
|
from typing import TYPE_CHECKING, List, Optional
|
||||||
|
|
||||||
|
import custom_ops
|
||||||
import torch
|
import torch
|
||||||
import torch_npu
|
import torch_npu
|
||||||
from torch.nn.functional import scaled_dot_product_attention
|
from torch.nn.functional import scaled_dot_product_attention
|
||||||
@@ -36,6 +37,8 @@ class ForwardMetadata:
|
|||||||
seq_lens_cpu_int: Optional[torch.Tensor] = None
|
seq_lens_cpu_int: Optional[torch.Tensor] = None
|
||||||
seq_lens_cpu_list: Optional[List[int]] = None
|
seq_lens_cpu_list: Optional[List[int]] = None
|
||||||
seq_lens_list_cumsum: Optional[List[int]] = None
|
seq_lens_list_cumsum: Optional[List[int]] = None
|
||||||
|
seq_lens: Optional[torch.Tensor] = None
|
||||||
|
actual_seq_lengths_q: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
|
|
||||||
class AscendAttnBackend(AttentionBackend):
|
class AscendAttnBackend(AttentionBackend):
|
||||||
@@ -67,6 +70,9 @@ class AscendAttnBackend(AttentionBackend):
|
|||||||
if self.use_mla:
|
if self.use_mla:
|
||||||
self.kv_lora_rank = model_runner.model_config.kv_lora_rank
|
self.kv_lora_rank = model_runner.model_config.kv_lora_rank
|
||||||
self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
|
self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
|
||||||
|
self.q_head_dim = (
|
||||||
|
self.qk_rope_head_dim + model_runner.model_config.qk_nope_head_dim
|
||||||
|
)
|
||||||
self.native_attn = TorchNativeAttnBackend(model_runner)
|
self.native_attn = TorchNativeAttnBackend(model_runner)
|
||||||
self.graph_metadata = {}
|
self.graph_metadata = {}
|
||||||
self.max_context_len = model_runner.model_config.context_len
|
self.max_context_len = model_runner.model_config.context_len
|
||||||
@@ -102,10 +108,6 @@ class AscendAttnBackend(AttentionBackend):
|
|||||||
self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int()
|
self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int()
|
||||||
|
|
||||||
seq_lens_list_cumsum = np.cumsum(forward_batch.extend_seq_lens_cpu)
|
seq_lens_list_cumsum = np.cumsum(forward_batch.extend_seq_lens_cpu)
|
||||||
if forward_batch.is_extend_in_batch:
|
|
||||||
seq_lens_list_cumsum[-1] = (
|
|
||||||
(seq_lens_list_cumsum[-1] - 1) // tp_size + 1
|
|
||||||
) * tp_size
|
|
||||||
self.forward_metadata.seq_lens_list_cumsum = seq_lens_list_cumsum
|
self.forward_metadata.seq_lens_list_cumsum = seq_lens_list_cumsum
|
||||||
|
|
||||||
self.graph_mode = False
|
self.graph_mode = False
|
||||||
@@ -133,6 +135,10 @@ class AscendAttnBackend(AttentionBackend):
|
|||||||
|
|
||||||
metadata.block_tables = self.graph_metadata["block_tables"][:bs, :]
|
metadata.block_tables = self.graph_metadata["block_tables"][:bs, :]
|
||||||
metadata.seq_lens_cpu_list = seq_lens.cpu().int().tolist()
|
metadata.seq_lens_cpu_list = seq_lens.cpu().int().tolist()
|
||||||
|
metadata.seq_lens = seq_lens
|
||||||
|
metadata.actual_seq_lengths_q = torch.tensor(
|
||||||
|
[1 + i * 1 for i in range(bs)], dtype=torch.int32, device=seq_lens.device
|
||||||
|
)
|
||||||
|
|
||||||
self.graph_metadata[bs] = metadata
|
self.graph_metadata[bs] = metadata
|
||||||
self.forward_metadata = metadata
|
self.forward_metadata = metadata
|
||||||
@@ -161,6 +167,8 @@ class AscendAttnBackend(AttentionBackend):
|
|||||||
metadata.block_tables[:bs, max_seq_pages:].fill_(0)
|
metadata.block_tables[:bs, max_seq_pages:].fill_(0)
|
||||||
metadata.block_tables[bs:, :].fill_(0)
|
metadata.block_tables[bs:, :].fill_(0)
|
||||||
|
|
||||||
|
metadata.seq_lens[:bs].copy_(seq_lens[:bs])
|
||||||
|
|
||||||
self.forward_metadata = metadata
|
self.forward_metadata = metadata
|
||||||
|
|
||||||
self.graph_mode = True
|
self.graph_mode = True
|
||||||
@@ -168,6 +176,64 @@ class AscendAttnBackend(AttentionBackend):
|
|||||||
def get_cuda_graph_seq_len_fill_value(self):
|
def get_cuda_graph_seq_len_fill_value(self):
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
def forward_sparse(
|
||||||
|
self,
|
||||||
|
q: torch.Tensor,
|
||||||
|
k: torch.Tensor,
|
||||||
|
v: torch.Tensor,
|
||||||
|
layer: RadixAttention,
|
||||||
|
forward_batch: ForwardBatch,
|
||||||
|
save_kv_cache: bool = True,
|
||||||
|
# For multi_head latent attention
|
||||||
|
q_rope: Optional[torch.Tensor] = None,
|
||||||
|
k_rope: Optional[torch.Tensor] = None,
|
||||||
|
topk_indices: torch.Tensor = None,
|
||||||
|
):
|
||||||
|
|
||||||
|
is_prefill = forward_batch.forward_mode.is_extend()
|
||||||
|
|
||||||
|
if save_kv_cache:
|
||||||
|
k = k.view(-1, layer.tp_k_head_num, self.kv_lora_rank)
|
||||||
|
k_rope = k_rope.view(-1, layer.tp_k_head_num, self.qk_rope_head_dim)
|
||||||
|
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||||
|
layer, forward_batch.out_cache_loc, k, k_rope
|
||||||
|
)
|
||||||
|
q_nope, q_pe = q, q_rope
|
||||||
|
k_nope, k_pe = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
|
||||||
|
block_table = self.forward_metadata.block_tables
|
||||||
|
if is_prefill:
|
||||||
|
actual_seq_qlen = torch.cumsum(forward_batch.seq_lens, dim=0)
|
||||||
|
else:
|
||||||
|
if self.forward_metadata.actual_seq_lengths_q is None:
|
||||||
|
actual_seq_qlen = (
|
||||||
|
torch.arange(1, q.shape[0] + 1).to(q.device).to(torch.int32)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
actual_seq_qlen = self.forward_metadata.actual_seq_lengths_q
|
||||||
|
if self.forward_metadata.seq_lens_cpu_int is None:
|
||||||
|
actual_seq_lengths_kv = self.forward_metadata.seq_lens
|
||||||
|
else:
|
||||||
|
actual_seq_lengths_kv = self.forward_metadata.seq_lens_cpu_int
|
||||||
|
|
||||||
|
attn_out = torch.ops.custom.npu_sparse_flash_attention(
|
||||||
|
query=q_nope,
|
||||||
|
key=k_nope,
|
||||||
|
value=k_nope,
|
||||||
|
query_rope=q_pe,
|
||||||
|
key_rope=k_pe,
|
||||||
|
sparse_indices=topk_indices,
|
||||||
|
scale_value=layer.scaling,
|
||||||
|
actual_seq_lengths_query=actual_seq_qlen.to(torch.int32),
|
||||||
|
actual_seq_lengths_kv=actual_seq_lengths_kv.to(q.device),
|
||||||
|
block_table=block_table,
|
||||||
|
sparse_block_size=1,
|
||||||
|
layout_query="TND",
|
||||||
|
layout_kv="PA_BSND",
|
||||||
|
sparse_mode=3,
|
||||||
|
)
|
||||||
|
|
||||||
|
return attn_out
|
||||||
|
|
||||||
def forward_extend(
|
def forward_extend(
|
||||||
self,
|
self,
|
||||||
q,
|
q,
|
||||||
@@ -176,7 +242,23 @@ class AscendAttnBackend(AttentionBackend):
|
|||||||
layer: RadixAttention,
|
layer: RadixAttention,
|
||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
save_kv_cache: bool = True,
|
save_kv_cache: bool = True,
|
||||||
|
# For multi_head latent attention
|
||||||
|
q_rope: Optional[torch.Tensor] = None,
|
||||||
|
k_rope: Optional[torch.Tensor] = None,
|
||||||
|
topk_indices: Optional[torch.Tensor] = None,
|
||||||
):
|
):
|
||||||
|
if topk_indices is not None:
|
||||||
|
return self.forward_sparse(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
layer,
|
||||||
|
forward_batch,
|
||||||
|
save_kv_cache,
|
||||||
|
q_rope,
|
||||||
|
k_rope,
|
||||||
|
topk_indices,
|
||||||
|
)
|
||||||
if not self.use_mla:
|
if not self.use_mla:
|
||||||
if save_kv_cache:
|
if save_kv_cache:
|
||||||
forward_batch.token_to_kv_pool.set_kv_buffer(
|
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||||
@@ -437,10 +519,23 @@ class AscendAttnBackend(AttentionBackend):
|
|||||||
# For multi-head latent attention
|
# For multi-head latent attention
|
||||||
q_rope: Optional[torch.Tensor] = None,
|
q_rope: Optional[torch.Tensor] = None,
|
||||||
k_rope: Optional[torch.Tensor] = None,
|
k_rope: Optional[torch.Tensor] = None,
|
||||||
|
topk_indices: Optional[torch.Tensor] = None,
|
||||||
):
|
):
|
||||||
if is_mla_preprocess_enabled():
|
if is_mla_preprocess_enabled():
|
||||||
# MLAPO does saving kv_cache
|
# MLAPO does saving kv_cache
|
||||||
save_kv_cache = False
|
save_kv_cache = False
|
||||||
|
if topk_indices is not None:
|
||||||
|
return self.forward_sparse(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
layer,
|
||||||
|
forward_batch,
|
||||||
|
save_kv_cache,
|
||||||
|
q_rope,
|
||||||
|
k_rope,
|
||||||
|
topk_indices,
|
||||||
|
)
|
||||||
|
|
||||||
if self.graph_mode:
|
if self.graph_mode:
|
||||||
return self.forward_decode_graph(
|
return self.forward_decode_graph(
|
||||||
|
|||||||
@@ -1,7 +1,3 @@
|
|||||||
import logging
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
ATTENTION_BACKENDS = {}
|
ATTENTION_BACKENDS = {}
|
||||||
|
|
||||||
|
|
||||||
@@ -66,6 +62,13 @@ def create_ascend_backend(runner):
|
|||||||
return AscendAttnBackend(runner)
|
return AscendAttnBackend(runner)
|
||||||
|
|
||||||
|
|
||||||
|
@register_attention_backend("nsa")
|
||||||
|
def create_nsa_backend(runner):
|
||||||
|
from sglang.srt.layers.attention.nsa_backend import NativeSparseAttnBackend
|
||||||
|
|
||||||
|
return NativeSparseAttnBackend(runner)
|
||||||
|
|
||||||
|
|
||||||
@register_attention_backend("triton")
|
@register_attention_backend("triton")
|
||||||
def create_triton_backend(runner):
|
def create_triton_backend(runner):
|
||||||
assert not runner.model_config.is_encoder_decoder, (
|
assert not runner.model_config.is_encoder_decoder, (
|
||||||
@@ -162,37 +165,35 @@ def create_dual_chunk_flash_attn_backend(runner):
|
|||||||
return DualChunkFlashAttentionBackend(runner)
|
return DualChunkFlashAttentionBackend(runner)
|
||||||
|
|
||||||
|
|
||||||
def attn_backend_wrapper(runner, full_attn_backend):
|
@register_attention_backend("hybrid_linear_attn")
|
||||||
"""
|
def create_hybrid_linear_attn_backend(runner):
|
||||||
Wrapper for special models like hybrid GDN, so we don't
|
assert (
|
||||||
need to change the code of the original attention backend.
|
runner.is_hybrid_gdn
|
||||||
"""
|
), "hybrid_linear_attn backend can only be used with hybrid GDN models."
|
||||||
assert not (
|
from sglang.srt.layers.attention.hybrid_linear_attn_backend import (
|
||||||
runner.is_hybrid_gdn and runner.use_mla_backend
|
HybridLinearAttnBackend,
|
||||||
), "hybrid_gdn can only be used with non-MLA models."
|
MambaAttnBackend,
|
||||||
|
)
|
||||||
|
from sglang.srt.utils import is_blackwell, is_npu
|
||||||
|
|
||||||
# wrap for hybrid GDN models
|
if is_npu():
|
||||||
if runner.is_hybrid_gdn:
|
from sglang.srt.layers.attention.ascend_backend import AscendAttnBackend
|
||||||
from sglang.srt.utils import is_blackwell, is_npu
|
|
||||||
|
|
||||||
if is_blackwell():
|
full_attn_backend = AscendAttnBackend(runner)
|
||||||
assert (
|
elif is_blackwell():
|
||||||
runner.server_args.attention_backend == "triton"
|
from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
|
||||||
), "triton backend is the only supported backend on Blackwell GPUs for hybrid GDN models, use --attention-backend triton to specify the backend."
|
|
||||||
if is_npu():
|
full_attn_backend = TritonAttnBackend(runner)
|
||||||
assert (
|
else:
|
||||||
runner.server_args.attention_backend == "ascend"
|
from sglang.srt.layers.attention.flashattention_backend import (
|
||||||
), "ascend backend is the only supported backend on NPU for hybrid GDN models, use --attention-backend ascend to specify the backend."
|
FlashAttentionBackend,
|
||||||
logger.info(f"Using hybrid linear attention backend for hybrid GDN models.")
|
|
||||||
from sglang.srt.layers.attention.hybrid_linear_attn_backend import (
|
|
||||||
HybridLinearAttnBackend,
|
|
||||||
MambaAttnBackend,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
linear_attn_backend = MambaAttnBackend(runner)
|
full_attn_backend = FlashAttentionBackend(runner)
|
||||||
full_attn_layers = runner.model_config.hf_config.full_attention_layer_ids
|
|
||||||
return HybridLinearAttnBackend(
|
|
||||||
full_attn_backend, linear_attn_backend, full_attn_layers
|
|
||||||
)
|
|
||||||
|
|
||||||
return full_attn_backend
|
linear_attn_backend = MambaAttnBackend(runner)
|
||||||
|
full_attn_layers = runner.model_config.hf_config.full_attention_layer_ids
|
||||||
|
|
||||||
|
return HybridLinearAttnBackend(
|
||||||
|
full_attn_backend, linear_attn_backend, full_attn_layers
|
||||||
|
)
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Optional, Union
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from sglang.srt.layers.attention.nsa.nsa_indexer import BaseIndexerMetadata
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||||
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
||||||
@@ -115,3 +116,11 @@ class AttentionBackend(ABC):
|
|||||||
def support_triton(self):
|
def support_triton(self):
|
||||||
"""Check if the current backend supports triton."""
|
"""Check if the current backend supports triton."""
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
def get_indexer_metadata(
|
||||||
|
self,
|
||||||
|
layer_id: int,
|
||||||
|
forward_batch: ForwardBatch,
|
||||||
|
) -> Optional[BaseIndexerMetadata]:
|
||||||
|
"""Get the indexer metadata. None means don't support indexer."""
|
||||||
|
return None
|
||||||
|
|||||||
@@ -692,13 +692,8 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
k_descale, v_descale = None, None
|
k_descale, v_descale = None, None
|
||||||
# only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
|
# only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
|
||||||
# has corresponding quantization method so that layer.k_scale is not None,
|
# has corresponding quantization method so that layer.k_scale is not None,
|
||||||
# 3) layer.head_dim <= 256 since fa3 kernel require fp16 and bf16 data type in this case,
|
# 3) layer.head_dim <= 256 since fa3 kernel require fp16 and bf16 data type in this case.
|
||||||
# 4) fa_impl_ver != 4 since fa4 does not currently support fp8 queries and keys.
|
if self.kv_cache_dtype_str != "auto" and layer.head_dim <= 256:
|
||||||
if (
|
|
||||||
self.kv_cache_dtype_str != "auto"
|
|
||||||
and layer.head_dim <= 256
|
|
||||||
and self.fa_impl_ver != 4
|
|
||||||
):
|
|
||||||
if layer.k_scale is not None:
|
if layer.k_scale is not None:
|
||||||
descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
|
descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
|
||||||
k_descale = layer.k_scale.expand(descale_shape)
|
k_descale = layer.k_scale.expand(descale_shape)
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ from sglang.srt.layers.radix_attention import AttentionType
|
|||||||
from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
|
from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||||
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
||||||
from sglang.srt.speculative.ngram_utils import NgramVerifyInput
|
from sglang.srt.speculative.lookahead_utils import LookaheadVerifyInput
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
get_int_env_var,
|
get_int_env_var,
|
||||||
is_flashinfer_available,
|
is_flashinfer_available,
|
||||||
@@ -344,7 +344,9 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
seq_lens: torch.Tensor,
|
seq_lens: torch.Tensor,
|
||||||
encoder_lens: Optional[torch.Tensor],
|
encoder_lens: Optional[torch.Tensor],
|
||||||
forward_mode: ForwardMode,
|
forward_mode: ForwardMode,
|
||||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
|
spec_info: Optional[
|
||||||
|
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
|
||||||
|
],
|
||||||
):
|
):
|
||||||
if forward_mode.is_decode_or_idle():
|
if forward_mode.is_decode_or_idle():
|
||||||
decode_wrappers = []
|
decode_wrappers = []
|
||||||
@@ -451,7 +453,9 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
seq_lens_sum: int,
|
seq_lens_sum: int,
|
||||||
encoder_lens: Optional[torch.Tensor],
|
encoder_lens: Optional[torch.Tensor],
|
||||||
forward_mode: ForwardMode,
|
forward_mode: ForwardMode,
|
||||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
|
spec_info: Optional[
|
||||||
|
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
|
||||||
|
],
|
||||||
seq_lens_cpu: Optional[torch.Tensor],
|
seq_lens_cpu: Optional[torch.Tensor],
|
||||||
):
|
):
|
||||||
if forward_mode.is_decode_or_idle():
|
if forward_mode.is_decode_or_idle():
|
||||||
@@ -669,7 +673,9 @@ class FlashInferIndicesUpdaterDecode:
|
|||||||
seq_lens_sum: int,
|
seq_lens_sum: int,
|
||||||
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
|
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
|
||||||
encoder_lens: Optional[torch.Tensor],
|
encoder_lens: Optional[torch.Tensor],
|
||||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
|
spec_info: Optional[
|
||||||
|
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
|
||||||
|
],
|
||||||
fixed_split_size: Optional[int] = None,
|
fixed_split_size: Optional[int] = None,
|
||||||
disable_split_kv: Optional[bool] = None,
|
disable_split_kv: Optional[bool] = None,
|
||||||
):
|
):
|
||||||
@@ -684,7 +690,9 @@ class FlashInferIndicesUpdaterDecode:
|
|||||||
seq_lens_sum: int,
|
seq_lens_sum: int,
|
||||||
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
|
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
|
||||||
encoder_lens: Optional[torch.Tensor],
|
encoder_lens: Optional[torch.Tensor],
|
||||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
|
spec_info: Optional[
|
||||||
|
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
|
||||||
|
],
|
||||||
fixed_split_size: Optional[int] = None,
|
fixed_split_size: Optional[int] = None,
|
||||||
disable_split_kv: Optional[bool] = None,
|
disable_split_kv: Optional[bool] = None,
|
||||||
):
|
):
|
||||||
@@ -710,7 +718,9 @@ class FlashInferIndicesUpdaterDecode:
|
|||||||
seq_lens_sum: int,
|
seq_lens_sum: int,
|
||||||
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
|
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
|
||||||
encoder_lens: Optional[torch.Tensor],
|
encoder_lens: Optional[torch.Tensor],
|
||||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
|
spec_info: Optional[
|
||||||
|
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
|
||||||
|
],
|
||||||
fixed_split_size: Optional[int] = None,
|
fixed_split_size: Optional[int] = None,
|
||||||
disable_split_kv: Optional[bool] = None,
|
disable_split_kv: Optional[bool] = None,
|
||||||
):
|
):
|
||||||
@@ -760,7 +770,9 @@ class FlashInferIndicesUpdaterDecode:
|
|||||||
seq_lens_sum: int,
|
seq_lens_sum: int,
|
||||||
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
|
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
|
||||||
encoder_lens: Optional[torch.Tensor],
|
encoder_lens: Optional[torch.Tensor],
|
||||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
|
spec_info: Optional[
|
||||||
|
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
|
||||||
|
],
|
||||||
fixed_split_size: Optional[int] = None,
|
fixed_split_size: Optional[int] = None,
|
||||||
disable_split_kv: Optional[bool] = None,
|
disable_split_kv: Optional[bool] = None,
|
||||||
):
|
):
|
||||||
@@ -794,7 +806,9 @@ class FlashInferIndicesUpdaterDecode:
|
|||||||
paged_kernel_lens_sum: int,
|
paged_kernel_lens_sum: int,
|
||||||
kv_indptr: torch.Tensor,
|
kv_indptr: torch.Tensor,
|
||||||
kv_start_idx: torch.Tensor,
|
kv_start_idx: torch.Tensor,
|
||||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
|
spec_info: Optional[
|
||||||
|
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
|
||||||
|
],
|
||||||
seq_lens_cpu: Optional[torch.Tensor],
|
seq_lens_cpu: Optional[torch.Tensor],
|
||||||
use_sliding_window_kv_pool: bool = False,
|
use_sliding_window_kv_pool: bool = False,
|
||||||
fixed_split_size: Optional[int] = None,
|
fixed_split_size: Optional[int] = None,
|
||||||
@@ -905,7 +919,9 @@ class FlashInferIndicesUpdaterPrefill:
|
|||||||
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
|
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
|
||||||
use_ragged: bool,
|
use_ragged: bool,
|
||||||
encoder_lens: Optional[torch.Tensor],
|
encoder_lens: Optional[torch.Tensor],
|
||||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
|
spec_info: Optional[
|
||||||
|
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
|
||||||
|
],
|
||||||
fixed_split_size: Optional[int] = None,
|
fixed_split_size: Optional[int] = None,
|
||||||
):
|
):
|
||||||
# Keep the signature for type checking. It will be assigned during runtime.
|
# Keep the signature for type checking. It will be assigned during runtime.
|
||||||
@@ -921,7 +937,9 @@ class FlashInferIndicesUpdaterPrefill:
|
|||||||
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
|
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
|
||||||
use_ragged: bool,
|
use_ragged: bool,
|
||||||
encoder_lens: Optional[torch.Tensor],
|
encoder_lens: Optional[torch.Tensor],
|
||||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
|
spec_info: Optional[
|
||||||
|
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
|
||||||
|
],
|
||||||
fixed_split_size: Optional[int] = None,
|
fixed_split_size: Optional[int] = None,
|
||||||
):
|
):
|
||||||
if use_ragged:
|
if use_ragged:
|
||||||
@@ -959,7 +977,9 @@ class FlashInferIndicesUpdaterPrefill:
|
|||||||
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
|
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
|
||||||
use_ragged: bool,
|
use_ragged: bool,
|
||||||
encoder_lens: Optional[torch.Tensor],
|
encoder_lens: Optional[torch.Tensor],
|
||||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
|
spec_info: Optional[
|
||||||
|
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
|
||||||
|
],
|
||||||
fixed_split_size: Optional[int] = None,
|
fixed_split_size: Optional[int] = None,
|
||||||
):
|
):
|
||||||
for wrapper_id in range(2):
|
for wrapper_id in range(2):
|
||||||
@@ -1006,7 +1026,9 @@ class FlashInferIndicesUpdaterPrefill:
|
|||||||
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
|
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
|
||||||
use_ragged: bool,
|
use_ragged: bool,
|
||||||
encoder_lens: Optional[torch.Tensor],
|
encoder_lens: Optional[torch.Tensor],
|
||||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
|
spec_info: Optional[
|
||||||
|
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
|
||||||
|
],
|
||||||
fixed_split_size: Optional[int] = None,
|
fixed_split_size: Optional[int] = None,
|
||||||
):
|
):
|
||||||
for wrapper_id in range(2):
|
for wrapper_id in range(2):
|
||||||
@@ -1049,7 +1071,9 @@ class FlashInferIndicesUpdaterPrefill:
|
|||||||
kv_indptr: torch.Tensor,
|
kv_indptr: torch.Tensor,
|
||||||
qo_indptr: torch.Tensor,
|
qo_indptr: torch.Tensor,
|
||||||
use_ragged: bool,
|
use_ragged: bool,
|
||||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
|
spec_info: Optional[
|
||||||
|
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
|
||||||
|
],
|
||||||
use_sliding_window_kv_pool: bool = False,
|
use_sliding_window_kv_pool: bool = False,
|
||||||
fixed_split_size: Optional[int] = None,
|
fixed_split_size: Optional[int] = None,
|
||||||
):
|
):
|
||||||
@@ -1078,7 +1102,7 @@ class FlashInferIndicesUpdaterPrefill:
|
|||||||
custom_mask = None
|
custom_mask = None
|
||||||
else:
|
else:
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
spec_info, (EagleDraftInput, EagleVerifyInput, NgramVerifyInput)
|
spec_info, (EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput)
|
||||||
)
|
)
|
||||||
kv_indices, kv_indptr, qo_indptr, custom_mask = (
|
kv_indices, kv_indptr, qo_indptr, custom_mask = (
|
||||||
spec_info.generate_attn_arg_prefill(
|
spec_info.generate_attn_arg_prefill(
|
||||||
|
|||||||
@@ -5,13 +5,20 @@ Support attention backend for FlashMLA.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import TYPE_CHECKING, Callable, Optional, Tuple, Union
|
from typing import TYPE_CHECKING, Callable, Dict, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import triton
|
import triton
|
||||||
from flash_mla import flash_mla_with_kvcache, get_mla_metadata
|
from flash_mla import flash_mla_with_kvcache, get_mla_metadata
|
||||||
|
|
||||||
|
from sglang.srt.configs.model_config import get_nsa_index_topk, is_deepseek_nsa
|
||||||
from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
|
from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
|
||||||
|
from sglang.srt.layers.attention.nsa.quant_k_cache import quantize_k_cache
|
||||||
|
from sglang.srt.layers.attention.nsa.utils import (
|
||||||
|
NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8,
|
||||||
|
NSA_KV_CACHE_STORE_FP8,
|
||||||
|
compute_nsa_seqlens,
|
||||||
|
)
|
||||||
from sglang.srt.layers.attention.utils import create_flashmla_kv_indices_triton
|
from sglang.srt.layers.attention.utils import create_flashmla_kv_indices_triton
|
||||||
from sglang.srt.layers.dp_attention import get_attention_tp_size
|
from sglang.srt.layers.dp_attention import get_attention_tp_size
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||||
@@ -74,10 +81,17 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
|
|||||||
self.scaling = model_runner.model_config.scaling
|
self.scaling = model_runner.model_config.scaling
|
||||||
self.data_type = model_runner.kv_cache_dtype
|
self.data_type = model_runner.kv_cache_dtype
|
||||||
self.q_data_type = model_runner.dtype
|
self.q_data_type = model_runner.dtype
|
||||||
self.kv_cache_dim = self.kv_lora_rank + self.qk_rope_head_dim
|
self.kv_cache_dim = model_runner.token_to_kv_pool.kv_cache_dim
|
||||||
|
|
||||||
self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
|
self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
|
||||||
|
|
||||||
|
self.use_nsa = is_deepseek_nsa(model_runner.model_config.hf_config)
|
||||||
|
self.nsa_index_topk = (
|
||||||
|
get_nsa_index_topk(model_runner.model_config.hf_config)
|
||||||
|
if self.use_nsa
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
||||||
|
|
||||||
bs = forward_batch.batch_size
|
bs = forward_batch.batch_size
|
||||||
@@ -100,10 +114,12 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
|
|||||||
self.req_to_token.stride(0),
|
self.req_to_token.stride(0),
|
||||||
max_seqlen_pad,
|
max_seqlen_pad,
|
||||||
)
|
)
|
||||||
mla_metadata, num_splits = get_mla_metadata(
|
mla_metadata, num_splits = _get_mla_metadata_wrapped(
|
||||||
forward_batch.seq_lens.to(torch.int32),
|
cache_seqlens=forward_batch.seq_lens.to(torch.int32),
|
||||||
self.num_q_heads,
|
seq_len_q=1,
|
||||||
1,
|
num_heads_q=self.num_q_heads,
|
||||||
|
num_heads_k=1,
|
||||||
|
nsa_index_topk=self.nsa_index_topk,
|
||||||
)
|
)
|
||||||
self.forward_metadata = FlashMLADecodeMetadata(
|
self.forward_metadata = FlashMLADecodeMetadata(
|
||||||
mla_metadata,
|
mla_metadata,
|
||||||
@@ -130,10 +146,12 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
|
|||||||
self.req_to_token.stride(0),
|
self.req_to_token.stride(0),
|
||||||
max_seqlen_pad,
|
max_seqlen_pad,
|
||||||
)
|
)
|
||||||
mla_metadata, num_splits = get_mla_metadata(
|
mla_metadata, num_splits = _get_mla_metadata_wrapped(
|
||||||
seq_lens.to(torch.int32),
|
cache_seqlens=seq_lens.to(torch.int32),
|
||||||
self.num_draft_tokens * self.num_q_heads,
|
seq_len_q=self.num_draft_tokens,
|
||||||
1,
|
num_heads_q=self.num_q_heads,
|
||||||
|
num_heads_k=1,
|
||||||
|
nsa_index_topk=self.nsa_index_topk,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Use FlashMLADecodeMetadata which has the attributes forward_extend expects
|
# Use FlashMLADecodeMetadata which has the attributes forward_extend expects
|
||||||
@@ -162,20 +180,28 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
|
|||||||
cuda_graph_kv_indices = block_kv_indices
|
cuda_graph_kv_indices = block_kv_indices
|
||||||
|
|
||||||
if self.num_draft_tokens:
|
if self.num_draft_tokens:
|
||||||
self.cuda_graph_mla_metadata, self.cuda_graph_num_splits = get_mla_metadata(
|
self.cuda_graph_mla_metadata, self.cuda_graph_num_splits = (
|
||||||
torch.ones(
|
_get_mla_metadata_wrapped(
|
||||||
max_bs, dtype=torch.int32, device=cuda_graph_kv_indices.device
|
cache_seqlens=torch.ones(
|
||||||
),
|
max_bs, dtype=torch.int32, device=cuda_graph_kv_indices.device
|
||||||
self.num_draft_tokens * self.num_q_heads,
|
),
|
||||||
1,
|
seq_len_q=self.num_draft_tokens,
|
||||||
|
num_heads_q=self.num_q_heads,
|
||||||
|
num_heads_k=1,
|
||||||
|
nsa_index_topk=self.nsa_index_topk,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.cuda_graph_mla_metadata, self.cuda_graph_num_splits = get_mla_metadata(
|
self.cuda_graph_mla_metadata, self.cuda_graph_num_splits = (
|
||||||
torch.ones(
|
_get_mla_metadata_wrapped(
|
||||||
max_bs, dtype=torch.int32, device=cuda_graph_kv_indices.device
|
cache_seqlens=torch.ones(
|
||||||
),
|
max_bs, dtype=torch.int32, device=cuda_graph_kv_indices.device
|
||||||
self.num_q_heads,
|
),
|
||||||
1,
|
seq_len_q=1,
|
||||||
|
num_heads_q=self.num_q_heads,
|
||||||
|
num_heads_k=1,
|
||||||
|
nsa_index_topk=self.nsa_index_topk,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
self.cuda_graph_kv_indices = cuda_graph_kv_indices
|
self.cuda_graph_kv_indices = cuda_graph_kv_indices
|
||||||
|
|
||||||
@@ -201,10 +227,12 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
|
|||||||
self.req_to_token.stride(0),
|
self.req_to_token.stride(0),
|
||||||
self.cuda_graph_kv_indices.stride(0),
|
self.cuda_graph_kv_indices.stride(0),
|
||||||
)
|
)
|
||||||
mla_metadata, num_splits = get_mla_metadata(
|
mla_metadata, num_splits = _get_mla_metadata_wrapped(
|
||||||
seq_lens.to(torch.int32),
|
cache_seqlens=seq_lens.to(torch.int32),
|
||||||
self.num_q_heads,
|
seq_len_q=1,
|
||||||
1,
|
num_heads_q=self.num_q_heads,
|
||||||
|
num_heads_k=1,
|
||||||
|
nsa_index_topk=self.nsa_index_topk,
|
||||||
)
|
)
|
||||||
self.cuda_graph_mla_metadata.copy_(mla_metadata)
|
self.cuda_graph_mla_metadata.copy_(mla_metadata)
|
||||||
self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
|
self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
|
||||||
@@ -226,10 +254,12 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
|
|||||||
self.req_to_token.stride(0),
|
self.req_to_token.stride(0),
|
||||||
self.cuda_graph_kv_indices.stride(0),
|
self.cuda_graph_kv_indices.stride(0),
|
||||||
)
|
)
|
||||||
mla_metadata, num_splits = get_mla_metadata(
|
mla_metadata, num_splits = _get_mla_metadata_wrapped(
|
||||||
seq_lens.to(torch.int32),
|
cache_seqlens=seq_lens.to(torch.int32),
|
||||||
self.num_draft_tokens * self.num_q_heads,
|
seq_len_q=self.num_draft_tokens,
|
||||||
1,
|
num_heads_q=self.num_q_heads,
|
||||||
|
num_heads_k=1,
|
||||||
|
nsa_index_topk=self.nsa_index_topk,
|
||||||
)
|
)
|
||||||
self.cuda_graph_mla_metadata.copy_(mla_metadata)
|
self.cuda_graph_mla_metadata.copy_(mla_metadata)
|
||||||
self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
|
self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
|
||||||
@@ -275,10 +305,12 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
|
|||||||
self.req_to_token.stride(0),
|
self.req_to_token.stride(0),
|
||||||
self.cuda_graph_kv_indices.stride(0),
|
self.cuda_graph_kv_indices.stride(0),
|
||||||
)
|
)
|
||||||
mla_metadata, num_splits = get_mla_metadata(
|
mla_metadata, num_splits = _get_mla_metadata_wrapped(
|
||||||
seq_lens.to(torch.int32),
|
cache_seqlens=seq_lens.to(torch.int32),
|
||||||
self.num_q_heads,
|
seq_len_q=1,
|
||||||
1,
|
num_heads_q=self.num_q_heads,
|
||||||
|
num_heads_k=1,
|
||||||
|
nsa_index_topk=self.nsa_index_topk,
|
||||||
)
|
)
|
||||||
self.cuda_graph_mla_metadata.copy_(mla_metadata)
|
self.cuda_graph_mla_metadata.copy_(mla_metadata)
|
||||||
self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
|
self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
|
||||||
@@ -300,10 +332,12 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
|
|||||||
self.req_to_token.stride(0),
|
self.req_to_token.stride(0),
|
||||||
self.cuda_graph_kv_indices.stride(0),
|
self.cuda_graph_kv_indices.stride(0),
|
||||||
)
|
)
|
||||||
mla_metadata, num_splits = get_mla_metadata(
|
mla_metadata, num_splits = _get_mla_metadata_wrapped(
|
||||||
seq_lens.to(torch.int32),
|
cache_seqlens=seq_lens.to(torch.int32),
|
||||||
self.num_draft_tokens * self.num_q_heads,
|
seq_len_q=self.num_draft_tokens,
|
||||||
1,
|
num_heads_q=self.num_q_heads,
|
||||||
|
num_heads_k=1,
|
||||||
|
nsa_index_topk=self.nsa_index_topk,
|
||||||
)
|
)
|
||||||
self.cuda_graph_mla_metadata.copy_(mla_metadata)
|
self.cuda_graph_mla_metadata.copy_(mla_metadata)
|
||||||
self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
|
self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
|
||||||
@@ -335,6 +369,7 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
|
|||||||
layer: RadixAttention,
|
layer: RadixAttention,
|
||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
save_kv_cache: bool = True,
|
save_kv_cache: bool = True,
|
||||||
|
topk_indices: Optional[torch.Tensor] = None,
|
||||||
):
|
):
|
||||||
cache_loc = forward_batch.out_cache_loc
|
cache_loc = forward_batch.out_cache_loc
|
||||||
|
|
||||||
@@ -349,13 +384,14 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
|
|||||||
)
|
)
|
||||||
bs = forward_batch.batch_size
|
bs = forward_batch.batch_size
|
||||||
k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
|
k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
|
||||||
|
k_cache = k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim)
|
||||||
|
|
||||||
reshape_q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)
|
reshape_q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)
|
||||||
if self.data_type == torch.float8_e4m3fn:
|
if (not self.use_nsa) and self.data_type == torch.float8_e4m3fn:
|
||||||
reshape_q_fp8 = reshape_q.to(torch.float8_e4m3fn)
|
reshape_q_fp8 = reshape_q.to(torch.float8_e4m3fn)
|
||||||
o, _ = flash_mla_with_kvcache(
|
o, _ = flash_mla_with_kvcache(
|
||||||
q=reshape_q_fp8,
|
q=reshape_q_fp8,
|
||||||
k_cache=k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim),
|
k_cache=k_cache,
|
||||||
block_table=self.forward_metadata.block_kv_indices[:bs],
|
block_table=self.forward_metadata.block_kv_indices[:bs],
|
||||||
cache_seqlens=forward_batch.seq_lens.to(torch.int32),
|
cache_seqlens=forward_batch.seq_lens.to(torch.int32),
|
||||||
head_dim_v=self.kv_lora_rank, # TODO Retrieve from config.
|
head_dim_v=self.kv_lora_rank, # TODO Retrieve from config.
|
||||||
@@ -369,17 +405,49 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
|
|||||||
|
|
||||||
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
|
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
|
||||||
else:
|
else:
|
||||||
|
block_table = self.forward_metadata.block_kv_indices[:bs]
|
||||||
|
cache_seqlens = forward_batch.seq_lens.to(torch.int32)
|
||||||
|
|
||||||
|
extra_kwargs: Dict
|
||||||
|
if self.use_nsa:
|
||||||
|
assert topk_indices is not None
|
||||||
|
extra_kwargs = dict(
|
||||||
|
indices=_compute_indices_in_kvcache(
|
||||||
|
block_table=block_table,
|
||||||
|
topk_indices=topk_indices.to(torch.int32),
|
||||||
|
page_size=self.page_size,
|
||||||
|
),
|
||||||
|
# doc says it is not used, but if pass in None then error
|
||||||
|
block_table=block_table,
|
||||||
|
is_fp8_kvcache=NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8,
|
||||||
|
)
|
||||||
|
cache_seqlens = compute_nsa_seqlens(
|
||||||
|
cache_seqlens, nsa_index_topk=self.nsa_index_topk
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
extra_kwargs = dict(
|
||||||
|
block_table=block_table,
|
||||||
|
causal=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
self.use_nsa
|
||||||
|
and NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8
|
||||||
|
and not NSA_KV_CACHE_STORE_FP8
|
||||||
|
):
|
||||||
|
# inefficiently quantize the whole cache
|
||||||
|
k_cache = quantize_k_cache(k_cache)
|
||||||
|
|
||||||
# todo: need check all causal True or False?
|
# todo: need check all causal True or False?
|
||||||
o, _ = flash_mla_with_kvcache(
|
o, _ = flash_mla_with_kvcache(
|
||||||
q=reshape_q,
|
q=reshape_q,
|
||||||
k_cache=k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim),
|
k_cache=k_cache,
|
||||||
block_table=self.forward_metadata.block_kv_indices[:bs],
|
cache_seqlens=cache_seqlens,
|
||||||
cache_seqlens=forward_batch.seq_lens.to(torch.int32),
|
|
||||||
head_dim_v=self.kv_lora_rank, # TODO Retrieve from config.
|
head_dim_v=self.kv_lora_rank, # TODO Retrieve from config.
|
||||||
tile_scheduler_metadata=self.forward_metadata.flashmla_metadata,
|
tile_scheduler_metadata=self.forward_metadata.flashmla_metadata,
|
||||||
num_splits=self.forward_metadata.num_splits,
|
num_splits=self.forward_metadata.num_splits,
|
||||||
softmax_scale=layer.scaling,
|
softmax_scale=layer.scaling,
|
||||||
causal=True,
|
**extra_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
|
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
|
||||||
@@ -539,3 +607,52 @@ class FlashMLAMultiStepDraftBackend:
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.common_template(forward_batch, call_fn)
|
self.common_template(forward_batch, call_fn)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_mla_metadata_wrapped(
|
||||||
|
*,
|
||||||
|
cache_seqlens: torch.Tensor,
|
||||||
|
seq_len_q: int,
|
||||||
|
num_heads_q: int,
|
||||||
|
num_heads_k: int,
|
||||||
|
nsa_index_topk: Optional[int],
|
||||||
|
):
|
||||||
|
if nsa_index_topk is not None:
|
||||||
|
assert nsa_index_topk is not None
|
||||||
|
return get_mla_metadata(
|
||||||
|
cache_seqlens=cache_seqlens,
|
||||||
|
# TODO doc says `num_q_tokens_per_q_seq * num_heads_q // num_heads_k`
|
||||||
|
# but the name looks like need seq_len_q?
|
||||||
|
num_q_tokens_per_head_k=seq_len_q * num_heads_q // num_heads_k,
|
||||||
|
num_heads_k=num_heads_k,
|
||||||
|
num_heads_q=num_heads_q,
|
||||||
|
is_fp8_kvcache=NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8,
|
||||||
|
topk=nsa_index_topk,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
assert nsa_index_topk is None
|
||||||
|
return get_mla_metadata(
|
||||||
|
cache_seqlens=cache_seqlens,
|
||||||
|
num_heads_per_head_k=seq_len_q * num_heads_q // num_heads_k,
|
||||||
|
num_heads_k=num_heads_k,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# TODO speedup
|
||||||
|
def _compute_indices_in_kvcache(block_table, topk_indices, page_size):
|
||||||
|
topk_indices_safe = topk_indices.masked_fill(topk_indices == -1, 0)
|
||||||
|
|
||||||
|
idx0 = torch.arange(block_table.size(0), device=topk_indices_safe.device).unsqueeze(
|
||||||
|
1
|
||||||
|
)
|
||||||
|
block_idx = block_table[idx0, topk_indices_safe // page_size]
|
||||||
|
offset = topk_indices_safe % page_size
|
||||||
|
indices_in_kvcache = block_idx * page_size + offset
|
||||||
|
|
||||||
|
# the kernel requires invalid entry to be -1
|
||||||
|
assert indices_in_kvcache.shape == topk_indices.shape
|
||||||
|
indices_in_kvcache[topk_indices == -1] = -1
|
||||||
|
|
||||||
|
# return: (batch_size, seqlen_q_ori, topk)
|
||||||
|
indices_in_kvcache = indices_in_kvcache[:, None, :]
|
||||||
|
return indices_in_kvcache
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ from typing import Optional, Union
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
||||||
|
from sglang.srt.layers.attention.nsa.nsa_indexer import BaseIndexerMetadata
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||||
@@ -138,3 +139,9 @@ class HybridAttnBackend(AttentionBackend):
|
|||||||
return backend.forward_extend(
|
return backend.forward_extend(
|
||||||
q, k, v, layer, forward_batch, save_kv_cache, **kwargs
|
q, k, v, layer, forward_batch, save_kv_cache, **kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def get_indexer_metadata(
|
||||||
|
self, layer_id: int, forward_batch: ForwardBatch
|
||||||
|
) -> Optional[BaseIndexerMetadata]:
|
||||||
|
backend = self._select_backend(forward_batch.forward_mode)
|
||||||
|
return backend.get_indexer_metadata(layer_id, forward_batch)
|
||||||
|
|||||||
121
python/sglang/srt/layers/attention/native_mla.py
Normal file
121
python/sglang/srt/layers/attention/native_mla.py
Normal file
@@ -0,0 +1,121 @@
|
|||||||
|
import math
|
||||||
|
from typing import Optional, Tuple, List
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
def cdiv(x: int, y: int):
|
||||||
|
return (x+y-1) // y
|
||||||
|
|
||||||
|
def native_mla_sparse_fwd(
|
||||||
|
q: torch.Tensor,
|
||||||
|
kv: torch.Tensor,
|
||||||
|
indices: torch.Tensor,
|
||||||
|
sm_scale: float,
|
||||||
|
d_v: int = 512,) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
|
|
||||||
|
s_q, _, d_qk = q.size()
|
||||||
|
s_kv = kv.size(0)
|
||||||
|
topk = indices.size(-1)
|
||||||
|
|
||||||
|
def log2sumexp2(a: torch.Tensor, dim: int) -> torch.Tensor:
|
||||||
|
return torch.logsumexp(a * math.log(2), dim=dim) * math.log2(math.e)
|
||||||
|
|
||||||
|
indices = indices[:, 0, :] # [s_q, topk]
|
||||||
|
invalid_indices_mask = (indices < 0) | (indices >= s_kv)
|
||||||
|
qs = q.float() # [s_q, h_q, d_qk]
|
||||||
|
kvs = kv[ :, 0, :].float() # [s_kv, d_qk]
|
||||||
|
|
||||||
|
kvs = torch.index_select(kvs, 0, indices.masked_fill(invalid_indices_mask, 0).flatten()).view(s_q, topk, d_qk) # [s_q, topk, d_qk]
|
||||||
|
attn_score = qs @ kvs.transpose(1, 2) # [s_q, h_q, topk]
|
||||||
|
attn_score.masked_fill_(invalid_indices_mask.unsqueeze(1), float('-inf'))
|
||||||
|
attn_score *= sm_scale * math.log2(math.e)
|
||||||
|
max_logits = torch.max(attn_score, dim=-1)[0] # [s_q, h_q]
|
||||||
|
lse = log2sumexp2(attn_score, dim=-1) # [s_q, h_q]
|
||||||
|
attn_score = torch.exp2(attn_score - lse.unsqueeze(-1)) # [s_q, h_q, topk]
|
||||||
|
result = attn_score @ kvs[:, :, :d_v]
|
||||||
|
return (max_logits, lse, result)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def native_mla_with_kvcache(
|
||||||
|
q: torch.Tensor, # [batch_size, s_q, h_q, d]
|
||||||
|
blocked_k: torch.Tensor, # [?, block_size, h_kv, d]
|
||||||
|
block_table: torch.Tensor, # [batch_size, ?]
|
||||||
|
cache_seqlens: torch.Tensor, # [batch_size]
|
||||||
|
dv: int,
|
||||||
|
is_causal: bool,
|
||||||
|
indices: Optional[torch.Tensor] = None # [batch_size, s_q, topk]
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
A reference implementation in PyTorch
|
||||||
|
"""
|
||||||
|
def get_topk_attn_mask(s_q: int, s_k: int, indices: torch.Tensor):
|
||||||
|
mask = torch.zeros(s_q, s_k, dtype=torch.bool)
|
||||||
|
for i in range(s_q):
|
||||||
|
cur_indices = indices[i]
|
||||||
|
valid_indices = cur_indices[cur_indices != -1]
|
||||||
|
mask[i, valid_indices] = True
|
||||||
|
return mask
|
||||||
|
|
||||||
|
def scaled_dot_product_attention(
|
||||||
|
batch_idx: int,
|
||||||
|
query: torch.Tensor, # [h_q, s_q, d]
|
||||||
|
kv: torch.Tensor, # [h_kv, s_k, d]
|
||||||
|
dv: int,
|
||||||
|
is_causal,
|
||||||
|
indices: Optional[torch.Tensor], # [s_q, topk]
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
h_q = query.size(0)
|
||||||
|
h_kv = kv.size(0)
|
||||||
|
s_q = query.shape[-2]
|
||||||
|
s_k = kv.shape[-2]
|
||||||
|
query = query.float()
|
||||||
|
kv = kv.float()
|
||||||
|
if h_kv != 1:
|
||||||
|
kv = kv.repeat_interleave(h_q // h_kv, dim=0)
|
||||||
|
kv[kv != kv] = 0.0
|
||||||
|
attn_weight = query @ kv.transpose(-2, -1) # [h_q, s_q, s_k]
|
||||||
|
if (is_causal and query.size(1) > 1) or indices is not None:
|
||||||
|
mask = torch.ones(s_q, s_k, dtype=torch.bool)
|
||||||
|
if is_causal:
|
||||||
|
assert indices is None
|
||||||
|
mask = mask.tril(diagonal=s_k - s_q)
|
||||||
|
if indices is not None:
|
||||||
|
mask &= get_topk_attn_mask(s_q, s_k, indices)
|
||||||
|
attn_bias = torch.zeros(s_q, s_k, dtype=torch.float)
|
||||||
|
attn_bias.masked_fill_(mask.logical_not(), float("-inf"))
|
||||||
|
attn_weight += attn_bias.to(q.dtype)
|
||||||
|
attn_weight /= math.sqrt(query.size(-1))
|
||||||
|
lse = attn_weight.logsumexp(dim=-1) # [h_q, s_q]
|
||||||
|
attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32)
|
||||||
|
output = attn_weight @ kv[..., :dv] # [h_q, s_q, dv]
|
||||||
|
# Correct for q tokens which has no attendable k
|
||||||
|
lonely_q_mask = (lse == float("-inf"))
|
||||||
|
output[lonely_q_mask.unsqueeze(-1).broadcast_to(h_q, s_q, dv)] = 0.0
|
||||||
|
lse[lonely_q_mask] = float("+inf")
|
||||||
|
|
||||||
|
return output, lse
|
||||||
|
|
||||||
|
b, s_q, h_q, d = q.size()
|
||||||
|
block_size = blocked_k.size(1)
|
||||||
|
h_kv = blocked_k.size(2)
|
||||||
|
cache_seqlens_cpu = cache_seqlens.cpu()
|
||||||
|
out_ref = torch.empty(b, s_q, h_q, dv, dtype=torch.float32)
|
||||||
|
lse_ref = torch.empty(b, h_q, s_q, dtype=torch.float32)
|
||||||
|
for i in range(b):
|
||||||
|
cur_len = cache_seqlens_cpu[i].item()
|
||||||
|
cur_num_blocks = cdiv(cur_len, block_size)
|
||||||
|
cur_block_indices = block_table[i][0: cur_num_blocks]
|
||||||
|
cur_kv = blocked_k[cur_block_indices].view(-1, h_kv, d)[:cur_len, ...]
|
||||||
|
cur_out, cur_lse = scaled_dot_product_attention(
|
||||||
|
i,
|
||||||
|
q[i].transpose(0, 1),
|
||||||
|
cur_kv.transpose(0, 1),
|
||||||
|
dv,
|
||||||
|
is_causal,
|
||||||
|
indices[i] if indices is not None else None
|
||||||
|
)
|
||||||
|
out_ref[i] = cur_out.transpose(0, 1)
|
||||||
|
lse_ref[i] = cur_lse
|
||||||
|
out_ref = out_ref.to(torch.bfloat16)
|
||||||
|
return out_ref, lse_ref
|
||||||
@@ -76,12 +76,14 @@ class NPUFusedMLAPreprocess(torch.nn.Module):
|
|||||||
self.rotary_emb = rotary_emb
|
self.rotary_emb = rotary_emb
|
||||||
self.layer_id = layer_id
|
self.layer_id = layer_id
|
||||||
self.has_preprocess_weights = False
|
self.has_preprocess_weights = False
|
||||||
|
self.dtype = None
|
||||||
|
|
||||||
self.q_lora_rank = self.q_b_proj.input_size # 1536
|
self.q_lora_rank = self.q_b_proj.input_size # 1536
|
||||||
self.kv_lora_rank = self.kv_a_layernorm.hidden_size # 512
|
self.kv_lora_rank = self.kv_a_layernorm.hidden_size # 512
|
||||||
self.num_local_heads = num_local_heads # tp
|
self.num_local_heads = num_local_heads # tp
|
||||||
self.qk_nope_head_dim = qk_nope_head_dim # 128
|
self.qk_nope_head_dim = qk_nope_head_dim # 128
|
||||||
self.qk_rope_head_dim = qk_rope_head_dim # 64
|
self.qk_rope_head_dim = qk_rope_head_dim # 64
|
||||||
|
self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
|
||||||
|
|
||||||
def preprocess_weights(self, hidden_states):
|
def preprocess_weights(self, hidden_states):
|
||||||
self.dummy = torch.empty(
|
self.dummy = torch.empty(
|
||||||
@@ -236,7 +238,83 @@ class NPUFusedMLAPreprocess(torch.nn.Module):
|
|||||||
slot_mapping = forward_batch.out_cache_loc.to(dtype=torch.int32)
|
slot_mapping = forward_batch.out_cache_loc.to(dtype=torch.int32)
|
||||||
return k_cache, v_cache, slot_mapping
|
return k_cache, v_cache, slot_mapping
|
||||||
|
|
||||||
def forward(self, positions, hidden_states, forward_batch, zero_allocator):
|
def forward_absorb_prepare_npu_rms_norm_cache(
|
||||||
|
self,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
forward_batch,
|
||||||
|
zero_allocator,
|
||||||
|
):
|
||||||
|
bsz, _ = hidden_states.view(-1, hidden_states.shape[-1]).shape
|
||||||
|
self.dtype = hidden_states.dtype
|
||||||
|
self.cos, self.sin = self.get_sin_cos(positions)
|
||||||
|
self.kvCache, self.kvCacheRope, self.slotmapping = (
|
||||||
|
self.get_kv_cache_and_cache_idx(forward_batch)
|
||||||
|
)
|
||||||
|
|
||||||
|
if not self.has_preprocess_weights:
|
||||||
|
self.has_preprocess_weights = True
|
||||||
|
|
||||||
|
cos, sin = self.cos, self.sin
|
||||||
|
|
||||||
|
if self.q_lora_rank is not None:
|
||||||
|
fused_qkv_a_proj_out = self.qkv_a_proj(hidden_states)[0]
|
||||||
|
q_lowrank, latent_cache = fused_qkv_a_proj_out.split(
|
||||||
|
[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
|
||||||
|
)
|
||||||
|
q = self.q_a_layernorm(q_lowrank)
|
||||||
|
q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
|
||||||
|
else:
|
||||||
|
q = self.q_proj(hidden_states)[0].view(
|
||||||
|
-1, self.num_local_heads, self.qk_head_dim
|
||||||
|
)
|
||||||
|
latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
|
||||||
|
|
||||||
|
q_nope, q_pe = torch.split(
|
||||||
|
q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
|
||||||
|
) # b*s,n,d
|
||||||
|
|
||||||
|
q_nope = q_nope.view(-1, self.num_local_heads, self.qk_nope_head_dim)
|
||||||
|
q_nope = torch.matmul(q_nope.transpose(0, 1), self.w_kc).transpose(0, 1)
|
||||||
|
|
||||||
|
q_pe = q_pe.view(-1, self.num_local_heads, 1, self.qk_rope_head_dim)
|
||||||
|
cos = cos.view(-1, 1, 1, self.qk_rope_head_dim)
|
||||||
|
sin = sin.view(-1, 1, 1, self.qk_rope_head_dim)
|
||||||
|
q_pe = torch_npu.npu_interleave_rope(q_pe, cos, sin) # (B,N,S,D)
|
||||||
|
q_pe = q_pe.view(cos.shape[0], self.num_local_heads, self.qk_rope_head_dim)
|
||||||
|
|
||||||
|
latent_cache = latent_cache.view(
|
||||||
|
-1, 1, 1, self.kv_lora_rank + self.qk_rope_head_dim
|
||||||
|
) # (B*S,N,1,D)
|
||||||
|
|
||||||
|
cache_mode = "PA_BNSD"
|
||||||
|
self.kvCache = self.kvCache.view(
|
||||||
|
-1,
|
||||||
|
forward_batch.attn_backend.page_size,
|
||||||
|
1,
|
||||||
|
forward_batch.attn_backend.kv_lora_rank,
|
||||||
|
)
|
||||||
|
self.kvCacheRope = self.kvCacheRope.view(
|
||||||
|
-1,
|
||||||
|
forward_batch.attn_backend.page_size,
|
||||||
|
1,
|
||||||
|
forward_batch.attn_backend.qk_rope_head_dim,
|
||||||
|
)
|
||||||
|
k_rope, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
|
||||||
|
latent_cache,
|
||||||
|
self.kv_a_layernorm.weight,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
self.slotmapping.to(torch.int64),
|
||||||
|
self.kvCacheRope,
|
||||||
|
self.kvCache,
|
||||||
|
epsilon=self.kv_a_layernorm.variance_epsilon,
|
||||||
|
cache_mode=cache_mode,
|
||||||
|
)
|
||||||
|
|
||||||
|
return (q_pe, k_rope, q_nope, k_nope, forward_batch, zero_allocator, positions)
|
||||||
|
|
||||||
|
def forward_mlapo(self, positions, hidden_states, forward_batch, zero_allocator):
|
||||||
input_dtype = hidden_states.dtype
|
input_dtype = hidden_states.dtype
|
||||||
if not self.has_preprocess_weights:
|
if not self.has_preprocess_weights:
|
||||||
self.preprocess_weights(hidden_states)
|
self.preprocess_weights(hidden_states)
|
||||||
@@ -298,3 +376,18 @@ class NPUFusedMLAPreprocess(torch.nn.Module):
|
|||||||
zero_allocator,
|
zero_allocator,
|
||||||
positions,
|
positions,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def forward(self, positions, hidden_states, forward_batch, zero_allocator):
|
||||||
|
_is_w8a8 = (
|
||||||
|
hasattr(self.qkv_a_proj.quant_method, "quantization_config")
|
||||||
|
and self.qkv_a_proj.quant_method.quantization_config.get_name()
|
||||||
|
== "w8a8_int8"
|
||||||
|
)
|
||||||
|
if _is_w8a8:
|
||||||
|
return self.forward_mlapo(
|
||||||
|
positions, hidden_states, forward_batch, zero_allocator
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return self.forward_absorb_prepare_npu_rms_norm_cache(
|
||||||
|
positions, hidden_states, forward_batch, zero_allocator
|
||||||
|
)
|
||||||
|
|||||||
3
python/sglang/srt/layers/attention/nsa/cuda/__init__.py
Normal file
3
python/sglang/srt/layers/attention/nsa/cuda/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
from .topk import fast_topk, fast_topk_transform
|
||||||
|
|
||||||
|
__all__ = ["fast_topk", "fast_topk_transform"]
|
||||||
505
python/sglang/srt/layers/attention/nsa/cuda/csrc/topk.cu
Normal file
505
python/sglang/srt/layers/attention/nsa/cuda/csrc/topk.cu
Normal file
@@ -0,0 +1,505 @@
|
|||||||
|
#include <ATen/core/TensorBase.h>
|
||||||
|
#include <c10/cuda/CUDAStream.h>
|
||||||
|
#include <c10/macros/Macros.h>
|
||||||
|
#include <c10/util/Exception.h>
|
||||||
|
#include <cassert>
|
||||||
|
#include <cstddef>
|
||||||
|
#include <cstdint>
|
||||||
|
#include <optional>
|
||||||
|
#include <pybind11/pybind11.h>
|
||||||
|
#include <pybind11/stl.h>
|
||||||
|
#include <torch/python.h>
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
constexpr int TopK = 2048;
|
||||||
|
constexpr int kThreadsPerBlock = 1024;
|
||||||
|
constexpr size_t kSmem = 32 * 1024 * sizeof(uint32_t); // 128KB
|
||||||
|
|
||||||
|
struct FastTopKParams {
|
||||||
|
const float *__restrict__ input; // [B, input_stride]
|
||||||
|
int32_t *__restrict__ indices; // [B, TopK]
|
||||||
|
int32_t *__restrict__ lengths; // [B]
|
||||||
|
int64_t input_stride;
|
||||||
|
bool use_tilelang;
|
||||||
|
};
|
||||||
|
|
||||||
|
// when length <= TopK, we can directly write the indices
|
||||||
|
__device__ void naive_topk_cuda(const float *__restrict__ score,
|
||||||
|
int32_t *__restrict__ indice, int32_t length) {
|
||||||
|
const auto tid = threadIdx.x;
|
||||||
|
for (int i = tid; i < TopK; i += kThreadsPerBlock) {
|
||||||
|
indice[i] = (i < length) ? i : -1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// keep the first `length` entries, set others to -1
|
||||||
|
__device__ void
|
||||||
|
naive_topk_transform(const float *__restrict__ score, int32_t length,
|
||||||
|
int32_t *__restrict__ dst_page_table,
|
||||||
|
const int32_t *__restrict__ src_page_table) {
|
||||||
|
const auto tid = threadIdx.x;
|
||||||
|
for (auto i = tid; i < TopK; i += kThreadsPerBlock) {
|
||||||
|
dst_page_table[i] = (i < length) ? src_page_table[i] : -1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
__device__ __forceinline__ uint8_t convert_to_uint8(float x) {
|
||||||
|
__half h = __float2half_rn(x);
|
||||||
|
uint16_t bits = __half_as_ushort(h);
|
||||||
|
uint16_t key = (bits & 0x8000) ? static_cast<uint16_t>(~bits & 0xFFFF)
|
||||||
|
: static_cast<uint16_t>(bits | 0x8000);
|
||||||
|
return static_cast<uint8_t>(key >> 8);
|
||||||
|
}
|
||||||
|
|
||||||
|
__device__ __forceinline__ uint32_t convert_to_uint32(float x) {
|
||||||
|
uint32_t bits = __float_as_uint(x);
|
||||||
|
return (bits & 0x80000000u) ? (~bits & 0xFFFFFFFFu) : (bits | 0x80000000u);
|
||||||
|
}
|
||||||
|
|
||||||
|
// One 8-bit radix pass of the iterative top-k selection.
//
// Histograms `length` candidates into RADIX (= LENGTH - 1) bins via `loader`
// (element index -> bin), finds the bin that straddles the top-k boundary,
// emits every element strictly above that bin into `index`, and defers the
// ties on the threshold bin into `s_remain_idx` for the next, finer pass.
//
// Template parameters:
//   Is_Epilogue - when true this is the final pass: threshold-bin ties are
//                 written straight to `index` instead of being deferred.
//   Indexer     - maps loop position i -> element index (identity for the
//                 first pass, shared-memory indirection afterwards).
//   Loader      - maps element index -> radix bin (0..RADIX-1).
//
// Returns the number of top-k slots still unresolved (to be taken from the
// threshold bin); a value <= 0 means the selection is complete.
template <bool Is_Epilogue = false, typename Indexer, typename Loader,
          int LENGTH, int MAX_REMAIN>
__device__ __forceinline__ auto
radix_topk(Indexer indexer, Loader loader, uint32_t length, int topk,
           int *__restrict__ index, int &__restrict__ s_counter,
           int (&__restrict__ s_histogram)[LENGTH],
           int &__restrict__ s_remain_cnt,
           int (&__restrict__ s_remain_idx)[MAX_REMAIN]) -> int {
  constexpr auto RADIX = LENGTH - 1;
  static_assert(RADIX > 1 && (RADIX & (RADIX - 1)) == 0,
                "RADIX must be power of 2");
  static_assert(RADIX <= kThreadsPerBlock);
  __shared__ uint32_t s_threshold_bin_id;

  const auto tx = threadIdx.x;
  // Zero the histogram (one slot per thread; requires RADIX < block size).
  if (tx < RADIX + 1)
    s_histogram[tx] = 0;
  __syncthreads();

  /// NOTE: Use uint32_t as the index
  for (auto i = tx; i < length; i += kThreadsPerBlock) {
    const auto idx = indexer(i);
    const auto bin = loader(idx);
    ::atomicAdd(&s_histogram[bin], 1);
  }
  __syncthreads();

  // cumsum (descending)
  if (tx == 0) {
    // Sentinel so bin RADIX-1 can also qualify as the threshold bin below.
    s_histogram[RADIX] = 0;
    s_remain_cnt = 0;
    for (int i = RADIX - 2; i >= 0; --i) {
      s_histogram[i] += s_histogram[i + 1];
    }
    // threshold bin: the first bin whose suffix count reaches topk.
    for (int i = 0; i < RADIX; i++) {
      if (s_histogram[i] >= topk && s_histogram[i + 1] < topk) {
        s_threshold_bin_id = i;
        break;
      }
    }
  }
  __syncthreads();

  const auto threshold_bin = s_threshold_bin_id;
  // Slots left to fill after taking every bin strictly above the threshold.
  const auto new_topk = topk - s_histogram[threshold_bin + 1];

  for (auto i = tx; i < length; i += kThreadsPerBlock) {
    const auto idx = indexer(i);
    const auto bin_id = static_cast<uint32_t>(loader(idx));
    if (bin_id > threshold_bin) {
      // Definitely within the top-k: emit immediately.
      index[::atomicAdd(&s_counter, 1)] = idx;
    } else if (bin_id == threshold_bin && new_topk > 0) {
      if constexpr (Is_Epilogue) {
        index[::atomicAdd(&s_counter, 1)] = idx;
      } else {
        // Tie on the threshold bin: defer to the next (finer) radix pass.
        if (const auto cnt = ::atomicAdd(&s_remain_cnt, 1);
            C10_LIKELY(cnt < MAX_REMAIN)) {
          s_remain_idx[cnt] = idx;
        }
      }
    }
  }
  __syncthreads();

  return new_topk;
}
|
||||||
|
|
||||||
|
// Block-wide top-k over input[0:length] for the length > TopK case.
//
// Runs up to five radix-select passes: one coarse pass on an 8-bit key
// derived from the fp16 rounding of each value, then up to four passes over
// the bytes of the 32-bit monotonic key (MSB first). Surviving candidate
// indices ping-pong between the two dynamic shared-memory buffers. Selected
// element indices are written (unordered) into `index`.
__device__ void fast_topk_cuda(const float *__restrict__ input,
                               int *__restrict__ index, int length,
                               int topk = TopK) {
  constexpr auto RADIX = 256;
  constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int));

  __shared__ int s_histogram[RADIX + 1];
  __shared__ int s_num_input[2];  // candidate counts of the two ping-pong buffers
  __shared__ int s_counter;       // number of indices already emitted

  // allocate for two rounds
  extern __shared__ int s_input_idx[][SMEM_INPUT_SIZE];
  s_counter = 0;

  // collect candidates
  const auto indexer = [](int idx) { return idx; };
  const auto loader = [&input](int idx) {
    return convert_to_uint8(input[idx]);
  };
  int new_topk = radix_topk(indexer, loader, length, topk, index, s_counter,
                            s_histogram, s_num_input[0], s_input_idx[0]);
  if (new_topk <= 0)
    return;

  // round 0: most significant byte of the 32-bit key
  const auto indexer_0 = [](int idx) { return s_input_idx[0][idx]; };
  const auto loader_0 = [&input](int idx) {
    return (convert_to_uint32(input[idx]) >> 24) & 0xFF;
  };
  new_topk = radix_topk(indexer_0, loader_0, s_num_input[0], new_topk, index,
                        s_counter, s_histogram, s_num_input[1], s_input_idx[1]);
  if (new_topk <= 0)
    return;

  // round 1
  const auto indexer_1 = [](int idx) { return s_input_idx[1][idx]; };
  const auto loader_1 = [&input](int idx) {
    return (convert_to_uint32(input[idx]) >> 16) & 0xFF;
  };
  new_topk = radix_topk(indexer_1, loader_1, s_num_input[1], new_topk, index,
                        s_counter, s_histogram, s_num_input[0], s_input_idx[0]);
  if (new_topk <= 0)
    return;

  // round 2
  const auto loader_2 = [&input](int idx) {
    return (convert_to_uint32(input[idx]) >> 8) & 0xFF;
  };
  new_topk = radix_topk(indexer_0, loader_2, s_num_input[0], new_topk, index,
                        s_counter, s_histogram, s_num_input[1], s_input_idx[1]);
  if (new_topk <= 0)
    return;

  // round 3: least significant byte
  const auto loader_3 = [&input](int idx) {
    return convert_to_uint32(input[idx]) & 0xFF;
  };
  // epilogue: no finer pass exists, ties are written out directly
  radix_topk<true>(indexer_1, loader_3, s_num_input[1], new_topk, index,
                   s_counter, s_histogram, s_num_input[0], s_input_idx[0]);
}
|
||||||
|
|
||||||
|
// Alternative top-k implementation (ported from a tilelang kernel): same
// radix-select scheme as fast_topk_cuda — one coarse fp16-byte pass plus up
// to four 8-bit refinement passes — but written as a single explicit loop
// instead of templated radix_topk calls.
__device__ void fast_topk_cuda_tl(const float *__restrict__ input,
                                  int *__restrict__ index, int length,
                                  int topk = TopK) {
  constexpr auto BLOCK_SIZE = 1024;
  constexpr auto RADIX = 256;
  constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int));

  __shared__ int s_threshold_bin_id;
  __shared__ int s_histogram[RADIX + 1];
  __shared__ int s_num_input[2];  // ping-pong candidate counts
  __shared__ int s_counter;       // number of indices already emitted

  // allocate for two rounds
  extern __shared__ int s_input_idx[][SMEM_INPUT_SIZE];

  int tx = threadIdx.x;

  // stage 1: 8bit coarse histogram
  if (tx < RADIX + 1)
    s_histogram[tx] = 0;
  __syncthreads();

  for (int idx = tx; idx < length; idx += BLOCK_SIZE) {
    const auto bin = convert_to_uint8(input[idx]);
    ::atomicAdd(&s_histogram[bin], 1);
  }
  __syncthreads();

  // cumsum (descending)
  if (tx == 0) {
    for (int i = RADIX - 2; i >= 0; --i) {
      s_histogram[i] += s_histogram[i + 1];
    }
    // threshold bin: first bin whose suffix count reaches topk
    for (int i = 0; i < RADIX; i++) {
      if (s_histogram[i] >= topk && s_histogram[i + 1] < topk) {
        s_threshold_bin_id = i;
        break;
      }
    }
    s_num_input[0] = 0;
    s_counter = 0;
  }
  __syncthreads();

  int threshold_bin = s_threshold_bin_id;
  // Slots still to fill from the threshold bin after taking all higher bins.
  int new_topk = topk - s_histogram[threshold_bin + 1];

  // collect candidates
  for (int idx = tx; idx < length; idx += BLOCK_SIZE) {
    const auto bin_id = static_cast<int>(convert_to_uint8(input[idx]));
    if (bin_id > threshold_bin) {
      // Definitely selected: emit directly.
      int pos = ::atomicAdd(&s_counter, 1);
      index[pos] = idx;
    } else if (bin_id == threshold_bin && new_topk > 0) {
      // Tie on the threshold bin: defer to the refinement passes.
      int pos = ::atomicAdd(&s_num_input[0], 1);
      if (pos < SMEM_INPUT_SIZE) {
        [[likely]] s_input_idx[0][pos] = idx;
      }
    }
  }
  __syncthreads();

  // stage 2: refine with 8bit radix passes
#pragma unroll 4
  for (int round = 0; round < 4; ++round) {
    if (new_topk <= 0)
      break;
    int r_idx = round % 2;  // buffer holding this round's input candidates

    // reset
    if (tx < RADIX + 1)
      s_histogram[tx] = 0;
    __syncthreads();

    int num_input = s_num_input[r_idx];
    // Histogram the current byte (MSB first) of the monotonic 32-bit key.
    for (int i = tx; i < num_input; i += BLOCK_SIZE) {
      int idx = s_input_idx[r_idx][i];
      uint32_t bin32 =
          (convert_to_uint32(input[idx]) >> (24 - round * 8)) & 0xFF;
      ::atomicAdd(&s_histogram[bin32], 1);
    }
    __syncthreads();

    if (tx == 0) {
      for (int i = RADIX - 2; i >= 0; --i)
        s_histogram[i] += s_histogram[i + 1];
      for (int i = 0; i < RADIX; i++) {
        if (s_histogram[i] >= new_topk && s_histogram[i + 1] < new_topk) {
          s_threshold_bin_id = i;
          break;
        }
      }
      // Reset the other buffer; it receives this round's surviving ties.
      s_num_input[r_idx ^ 1] = 0;
    }
    __syncthreads();

    new_topk -= s_histogram[s_threshold_bin_id + 1];
    int threshold_bin = s_threshold_bin_id;

    for (int i = tx; i < num_input; i += BLOCK_SIZE) {
      int idx = s_input_idx[r_idx][i];
      uint32_t bin32 =
          (convert_to_uint32(input[idx]) >> (24 - round * 8)) & 0xFF;
      if (bin32 > threshold_bin) {
        int pos = ::atomicAdd(&s_counter, 1);
        index[pos] = idx;
      } else if (bin32 == threshold_bin && new_topk > 0) {
        if (round == 3) {
          // Last byte: no finer pass exists, emit ties directly.
          int pos = ::atomicAdd(&s_counter, 1);
          index[pos] = idx;
        } else {
          int pos = ::atomicAdd(&s_num_input[r_idx ^ 1], 1);
          if (pos < SMEM_INPUT_SIZE)
            s_input_idx[r_idx ^ 1][pos] = idx;
        }
      }
    }
    __syncthreads();
  }
}
|
||||||
|
|
||||||
|
// One block per batch row: pick the trivial path when the row has at most
// TopK entries, otherwise one of the two radix-select implementations.
__global__ void topk_kernel(const FastTopKParams params) {
  const auto &[input, indices, lengths, input_stride, use_tilelang] = params;
  const auto bid = blockIdx.x;
  const auto length = *(lengths + bid);
  const auto indice = indices + bid * TopK;   // this row's output slice
  const auto score = input + bid * input_stride;
  if (length <= TopK) {
    return naive_topk_cuda(score, indice, length);
  } else {
    if (use_tilelang) {
      return fast_topk_cuda_tl(score, indice, length);
    } else {
      return fast_topk_cuda(score, indice, length);
    }
  }
}
|
||||||
|
|
||||||
|
// Decode-path fused kernel: per-row top-k over the scores, then gather the
// selected entries of src_page_table into dst_page_table. Rows with
// length <= TopK are copied directly with -1 padding.
__global__ void topk_kernel_transform_decode( // decode
    const FastTopKParams params, int32_t *__restrict__ dst_page_table,
    const int32_t *__restrict__ src_page_table, const int64_t src_stride) {
  const auto &[input, _, lengths, input_stride, use_tilelang] = params;
  const auto bid = blockIdx.x;
  const auto tid = threadIdx.x;
  const auto length = *(lengths + bid);
  const auto src_page_entry = src_page_table + bid * src_stride;
  const auto dst_page_entry = dst_page_table + bid * TopK;
  const auto score = input + bid * input_stride;
  if (length <= TopK) {
    return naive_topk_transform(score, length, dst_page_entry, src_page_entry);
  } else {
    // Selected indices land in shared memory, then get gathered below.
    __shared__ int s_indices[TopK];
    if (use_tilelang) {
      fast_topk_cuda_tl(score, s_indices, length);
    } else {
      fast_topk_cuda(score, s_indices, length);
    }
    // copy src[s_indices] to dst, we manually unroll here
    static_assert(TopK % kThreadsPerBlock == 0);
    static_assert(TopK / kThreadsPerBlock == 2);
    const auto idx_0 = tid;
    const auto pos_0 = s_indices[idx_0];
    dst_page_entry[idx_0] = src_page_entry[pos_0];
    const auto idx_1 = tid + kThreadsPerBlock;
    const auto pos_1 = s_indices[idx_1];
    dst_page_entry[idx_1] = src_page_entry[pos_1];
  }
}
|
||||||
|
|
||||||
|
// Prefill-path fused kernel: blocks are launched per expanded token, so each
// block first resolves which prefill request it belongs to via cu_seqlens,
// then runs top-k + page-table gather exactly like the decode kernel.
__global__ void topk_kernel_transform_prefill( // prefill
    const FastTopKParams params, int32_t *__restrict__ dst_page_table,
    const int32_t *__restrict__ src_page_table, const int64_t src_stride,
    const int32_t *__restrict__ cu_seqlens, const int64_t prefill_bs) {
  const auto &[input, _, lengths, input_stride, use_tilelang] = params;
  const auto bid = blockIdx.x;
  const auto tid = threadIdx.x;
  const auto length = *(lengths + bid);
  const auto dst_page_entry = dst_page_table + bid * TopK;
  const auto score = input + bid * input_stride;

  /// NOTE: prefill bs is usually small, we can just use a simple loop here
  /// We ensure that last cu_seqlens is equal to number of blocks launched
  assert(gridDim.x == cu_seqlens[prefill_bs] &&
         "Invalid cu_seqlens in topk-transform-prefill");
  __shared__ const int32_t *s_src_page_entry;
  if (tid == 0) {
    // Linear scan over cu_seqlens to locate this block's request.
    for (int64_t offset = 0; offset < prefill_bs; ++offset) {
      if (bid < cu_seqlens[offset + 1]) {
        s_src_page_entry = src_page_table + offset * src_stride;
        break;
      }
    }
  }
  __syncthreads();
  const auto src_page_entry = s_src_page_entry;

  if (length <= TopK) {
    return naive_topk_transform(score, length, dst_page_entry, src_page_entry);
  } else {
    __shared__ int s_indices[TopK];
    if (use_tilelang) {
      fast_topk_cuda_tl(score, s_indices, length);
    } else {
      fast_topk_cuda(score, s_indices, length);
    }
    // copy src[s_indices] to dst, we manually unroll here
    static_assert(TopK % kThreadsPerBlock == 0);
    static_assert(TopK / kThreadsPerBlock == 2);
    const auto idx_0 = tid;
    const auto pos_0 = s_indices[idx_0];
    dst_page_entry[idx_0] = src_page_entry[pos_0];
    const auto idx_1 = tid + kThreadsPerBlock;
    const auto pos_1 = s_indices[idx_1];
    dst_page_entry[idx_1] = src_page_entry[pos_1];
  }
}
|
||||||
|
|
||||||
|
// Validate the score/lengths (and optional indices) tensors and pack their
// raw pointers into a FastTopKParams for kernel launch.
auto get_params(at::Tensor score, at::Tensor lengths, bool use_tilelang,
                std::optional<at::Tensor> indices_opt = std::nullopt)
    -> FastTopKParams {
  const auto B = score.size(0);
  // Rows may be strided, but the inner dim must be dense.
  TORCH_CHECK(score.dim() == 2 && score.stride(1) == 1);
  TORCH_CHECK(lengths.dim() == 1 && lengths.is_contiguous());
  TORCH_CHECK(lengths.size(0) == B);
  int32_t *indices_data_ptr = nullptr;
  if (indices_opt.has_value()) {
    const auto &indices = indices_opt.value();
    TORCH_CHECK(indices.dim() == 2 && indices.is_contiguous());
    TORCH_CHECK(indices.size(0) == B);
    TORCH_CHECK(indices.size(1) == TopK);
    indices_data_ptr = indices.data_ptr<int32_t>();
  }

  return FastTopKParams{
      .input = score.data_ptr<float>(),
      .indices = indices_data_ptr,  // nullptr for the transform kernels
      .lengths = lengths.data_ptr<int32_t>(),
      .input_stride = score.stride(0),
      .use_tilelang = use_tilelang,
  };
}
|
||||||
|
|
||||||
|
// Raise kernel `f`'s dynamic shared-memory limit to `max_dynamic_smem`.
// The attribute call runs exactly once per (kernel, size) instantiation via
// function-local static initialization; the check re-runs on every call.
template <auto *f, size_t max_dynamic_smem>
auto setup_kernel_smem_once() -> void {
  [[maybe_unused]]
  static const auto result = [] {
    return ::cudaFuncSetAttribute(
        f, ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem);
  }();
  TORCH_CHECK(result == cudaSuccess,
              "set_up_kernel_once failed:", ::cudaGetErrorString(result));
}
|
||||||
|
|
||||||
|
// Python-facing entry: per-row top-k indices of `score` are written into
// `indices` ([B, TopK]; short rows are -1 padded by the naive path).
auto fast_topk_interface(at::Tensor score, at::Tensor indices,
                         at::Tensor lengths, bool use_tilelang) -> void {
  const auto params = get_params(score, lengths, use_tilelang, indices);
  const auto B = score.size(0);
  const auto stream = at::cuda::getCurrentCUDAStream().stream();
  const auto grid = dim3{static_cast<uint32_t>(B)};  // one block per row
  const auto block = dim3{kThreadsPerBlock};
  setup_kernel_smem_once<topk_kernel, kSmem>();
  topk_kernel<<<grid, block, kSmem, stream>>>(params);
  const auto result = cudaGetLastError();
  TORCH_CHECK(result == cudaSuccess,
              "topk kernel failed:", ::cudaGetErrorString(result));
}
|
||||||
|
|
||||||
|
// Python-facing entry for the fused top-k + page-table gather. Dispatches to
// the decode kernel when the batch is not expanded (prefill_bs == B),
// otherwise to the prefill kernel, which maps expanded rows back to their
// request via cu_seqlens.
auto fast_topk_transform_interface(at::Tensor score, at::Tensor lengths,
                                   at::Tensor dst_page_table,
                                   at::Tensor src_page_table,
                                   at::Tensor cu_seqlens,
                                   bool use_tilelang) -> void {
  // No indices tensor: results go straight into dst_page_table.
  const auto params = get_params(score, lengths, use_tilelang);
  const auto B = score.size(0);
  TORCH_CHECK(dst_page_table.dim() == 2 && dst_page_table.is_contiguous());
  TORCH_CHECK(src_page_table.dim() == 2 && src_page_table.stride(1) == 1);
  TORCH_CHECK(cu_seqlens.dim() == 1 && cu_seqlens.is_contiguous());
  const auto prefill_bs = cu_seqlens.size(0) - 1;
  TORCH_CHECK(dst_page_table.size(0) == B);
  TORCH_CHECK(dst_page_table.size(1) == TopK);
  TORCH_CHECK(src_page_table.size(0) == prefill_bs);
  TORCH_CHECK(prefill_bs <= B); // prefill_bs should be smaller than expanded bs

  // launch kernel
  const auto stream = at::cuda::getCurrentCUDAStream().stream();
  const auto grid = dim3{static_cast<uint32_t>(B)};  // one block per expanded row
  const auto block = dim3{kThreadsPerBlock};
  const auto src_stride = src_page_table.stride(0);

  // dispatch to decode or prefill
  const auto is_decode = (prefill_bs == B);
  if (is_decode) {
    setup_kernel_smem_once<topk_kernel_transform_decode, kSmem>();
    topk_kernel_transform_decode<<<grid, block, kSmem, stream>>>(
        params, dst_page_table.data_ptr<int32_t>(),
        src_page_table.data_ptr<int32_t>(), src_stride);
  } else {
    setup_kernel_smem_once<topk_kernel_transform_prefill, kSmem>();
    topk_kernel_transform_prefill<<<grid, block, kSmem, stream>>>(
        params, dst_page_table.data_ptr<int32_t>(),
        src_page_table.data_ptr<int32_t>(), src_stride,
        cu_seqlens.data_ptr<int32_t>(), prefill_bs);
  }

  const auto result = cudaGetLastError();
  TORCH_CHECK(result == cudaSuccess,
              "topk kernel failed:", ::cudaGetErrorString(result));
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
// Python bindings for the standalone top-k extension module.
PYBIND11_MODULE(topk_kernel, m) {
  m.def("fast_topk", &fast_topk_interface);
  m.def("fast_topk_transform", &fast_topk_transform_interface);
}
|
||||||
39
python/sglang/srt/layers/attention/nsa/cuda/topk.py
Normal file
39
python/sglang/srt/layers/attention/nsa/cuda/topk.py
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from .utils import load_kernel_module
|
||||||
|
|
||||||
|
|
||||||
|
def _load_topk_module() -> Any:
    """
    JIT-compile (if needed) and return the CUDA top-k extension module.

    Caching is handled by ``load_kernel_module`` itself, so repeated calls
    are cheap.
    """
    # The old docstring said "index manipulation module" — this loads topk.cu.
    return load_kernel_module("topk.cu", "topk_kernel")
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(dark): figure out why my cuda impl is a little slower....
# I believe it has something to do with unrolling loops (?)
# When True, use the tilelang-ported kernel path (fast_topk_cuda_tl) instead
# of the hand-written one.
_USE_TL = True
|
||||||
|
|
||||||
|
|
||||||
|
def fast_topk(
    score: torch.Tensor,
    indices: torch.Tensor,
    lengths: torch.Tensor,
) -> torch.Tensor:
    """Run the CUDA top-k kernel; selected indices are written into
    ``indices`` in place."""
    module = _load_topk_module()
    return module.fast_topk(score, indices, lengths, _USE_TL)
|
||||||
|
|
||||||
|
|
||||||
|
def fast_topk_transform(
    score: torch.Tensor,
    lengths: torch.Tensor,
    dst_page_table: torch.Tensor,
    src_page_table: torch.Tensor,
    cu_seqlens: torch.Tensor,
) -> torch.Tensor:
    """Fused top-k + page-table gather; the decode/prefill dispatch happens
    inside the C++ extension."""
    module = _load_topk_module()
    return module.fast_topk_transform(
        score, lengths, dst_page_table, src_page_table, cu_seqlens, _USE_TL
    )
|
||||||
44
python/sglang/srt/layers/attention/nsa/cuda/utils.py
Normal file
44
python/sglang/srt/layers/attention/nsa/cuda/utils.py
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
from functools import lru_cache
|
||||||
|
from typing import Any, Iterable
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def _prepare_for_load() -> str:
|
||||||
|
import os
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
warnings.filterwarnings(
|
||||||
|
"ignore", category=UserWarning, module="torch.utils.cpp_extension"
|
||||||
|
)
|
||||||
|
return os.path.dirname(os.path.abspath(__file__))
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache()
def load_kernel_module(
    path: str | Iterable[str],
    name: str,
    *,
    build: str = "build",
    cflags: Iterable[str] | None = None,
    cuda_flags: Iterable[str] | None = None,
    ldflags: Iterable[str] | None = None,
) -> Any:
    """
    JIT-compile and load a C++/CUDA extension from this package's ``csrc/``.

    Args:
        path: one source filename, or several, relative to ``csrc/``.
        name: extension module name passed to ``torch.utils.cpp_extension.load``.
        build: subdirectory (under this package) for build artifacts.
        cflags / cuda_flags / ldflags: extra compiler/linker flags; when
            omitted, ``-O3 -std=c++17`` is used for the compilers.

    Returns:
        The loaded extension module.

    NOTE(review): because of ``@lru_cache`` every argument must be hashable —
    passing a plain list for ``path`` raises TypeError before the body runs;
    use a str or tuple.
    """
    from torch.utils.cpp_extension import load

    if isinstance(path, str):
        path = (path,)

    abs_path = _prepare_for_load()
    build_dir = f"{abs_path}/{build}"
    os.makedirs(build_dir, exist_ok=True)
    return load(
        name=name,
        sources=[f"{abs_path}/csrc/{p}" for p in path],
        # fall back to -O3/C++17 when the caller supplies no flags
        extra_cflags=list(cflags or []) or ["-O3", "-std=c++17"],
        extra_cuda_cflags=list(cuda_flags or []) or ["-O3", "-std=c++17"],
        extra_ldflags=list(ldflags or []) or None,
        build_directory=build_dir,
    )
|
||||||
163
python/sglang/srt/layers/attention/nsa/dequant_k_cache.py
Normal file
163
python/sglang/srt/layers/attention/nsa/dequant_k_cache.py
Normal file
@@ -0,0 +1,163 @@
|
|||||||
|
import torch
|
||||||
|
import triton
|
||||||
|
import triton.language as tl
|
||||||
|
|
||||||
|
from sglang.srt.layers.attention.nsa.utils import NSA_DEQUANT_K_CACHE_FAST
|
||||||
|
|
||||||
|
|
||||||
|
def dequantize_k_cache(quant_k_cache):
    """Dequantize the packed k-cache, dispatching to the triton fast path or
    the reference torch implementation per NSA_DEQUANT_K_CACHE_FAST."""
    impl = (
        _dequantize_k_cache_fast_wrapped
        if NSA_DEQUANT_K_CACHE_FAST
        else _dequantize_k_cache_slow
    )
    return impl(quant_k_cache)
|
||||||
|
|
||||||
|
|
||||||
|
def _dequantize_k_cache_slow(
|
||||||
|
quant_k_cache: torch.Tensor, # (num_blocks, block_size, 1, bytes_per_token)
|
||||||
|
dv: int = 512,
|
||||||
|
tile_size: int = 128,
|
||||||
|
d: int = 576,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
De-quantize the k-cache
|
||||||
|
"""
|
||||||
|
assert dv % tile_size == 0
|
||||||
|
num_tiles = dv // tile_size
|
||||||
|
num_blocks, block_size, h_k, _ = quant_k_cache.shape
|
||||||
|
assert h_k == 1
|
||||||
|
result = torch.empty(
|
||||||
|
(num_blocks, block_size, d), dtype=torch.bfloat16, device=quant_k_cache.device
|
||||||
|
)
|
||||||
|
|
||||||
|
quant_k_cache = quant_k_cache.view(num_blocks, block_size, -1)
|
||||||
|
|
||||||
|
input_nope = quant_k_cache[..., :dv]
|
||||||
|
input_scale = quant_k_cache[..., dv : dv + num_tiles * 4].view(torch.float32)
|
||||||
|
input_rope = quant_k_cache[..., dv + num_tiles * 4 :].view(torch.bfloat16)
|
||||||
|
result[..., dv:] = input_rope
|
||||||
|
|
||||||
|
for tile_idx in range(0, num_tiles):
|
||||||
|
cur_nope = input_nope[
|
||||||
|
..., tile_idx * tile_size : (tile_idx + 1) * tile_size
|
||||||
|
].to(torch.float32)
|
||||||
|
cur_scales = input_scale[..., tile_idx].unsqueeze(-1)
|
||||||
|
result[..., tile_idx * tile_size : (tile_idx + 1) * tile_size] = (
|
||||||
|
cur_nope * cur_scales
|
||||||
|
)
|
||||||
|
|
||||||
|
result = result.view(num_blocks, block_size, 1, d)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def _dequantize_k_cache_fast_wrapped(
    quant_k_cache: torch.Tensor,
    dv: int = 512,
    tile_size: int = 128,
) -> torch.Tensor:
    """Shape adapter around ``_dequantize_k_cache_fast``: flattens the 4D
    paged cache to 2D tokens for the triton path, then restores 4D."""
    # TODO the final API may be 2D instead of 4D, thus we convert them here
    num_blocks, block_size, _, dim_quant = quant_k_cache.shape
    # Only the production layout is supported by the fast path.
    assert dv == 512
    assert dim_quant == 656
    assert tile_size == 128
    tokens_2d = quant_k_cache.view((-1, dim_quant))

    dequant_2d = _dequantize_k_cache_fast(tokens_2d)

    return dequant_2d.view(num_blocks, block_size, 1, -1)
|
||||||
|
|
||||||
|
|
||||||
|
def _dequantize_k_cache_fast(quant_k_cache, group_size: int = 128):
    """
    Triton-backed de-quantization of a flattened (num_tokens, 656) k-cache.

    Per-token layout: 512 fp8 nope values, one fp32 scale per 128-group
    (4 scales), then 64 bf16 rope values; output is (num_tokens, 576) bf16.
    """
    num_tokens, dim_quant = quant_k_cache.shape

    assert quant_k_cache.dtype == torch.float8_e4m3fn
    dim_nope = 512
    dim_rope = 64
    num_tiles = dim_nope // group_size
    assert dim_quant == 656

    output = torch.empty(
        (num_tokens, dim_nope + dim_rope),
        dtype=torch.bfloat16,
        device=quant_k_cache.device,
    )

    # One program per (token, 128-wide group): 4 nope groups + 1 rope group.
    num_blocks_per_token = triton.cdiv(dim_nope + dim_rope, group_size)
    assert num_blocks_per_token == 5

    assert dim_nope % group_size == 0
    NUM_NOPE_BLOCKS = dim_nope // group_size

    # Reinterpret the packed byte regions: fp8 values, fp32 scales, bf16 rope.
    input_nope_q = quant_k_cache[:, :dim_nope]
    input_nope_s = quant_k_cache[:, dim_nope : dim_nope + num_tiles * 4].view(
        torch.float32
    )
    input_rope = quant_k_cache[:, dim_nope + num_tiles * 4 :].view(torch.bfloat16)

    _dequantize_k_cache_fast_kernel[(num_tokens, num_blocks_per_token)](
        output,
        input_nope_q,
        input_nope_s,
        input_rope,
        output.stride(0),
        input_nope_q.stride(0),
        input_nope_s.stride(0),
        input_rope.stride(0),
        NUM_NOPE_BLOCKS=NUM_NOPE_BLOCKS,
        GROUP_SIZE=group_size,
        DIM_NOPE=dim_nope,
        DIM_ROPE=dim_rope,
    )

    return output
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
def _dequantize_k_cache_fast_kernel(
    output_ptr,
    input_nope_q_ptr,
    input_nope_s_ptr,
    input_rope_ptr,
    output_stride_0: int,
    input_nope_q_stride_0: int,
    input_nope_s_stride_0: int,
    input_rope_stride_0: int,
    NUM_NOPE_BLOCKS: tl.constexpr,
    GROUP_SIZE: tl.constexpr,
    DIM_NOPE: tl.constexpr,
    DIM_ROPE: tl.constexpr,
):
    """Grid is (num_tokens, NUM_NOPE_BLOCKS + rope blocks): nope programs
    de-scale one GROUP_SIZE slice, the remaining programs copy rope values."""
    token_id = tl.program_id(0)
    raw_block_id = tl.program_id(1)

    if raw_block_id < NUM_NOPE_BLOCKS:
        # a. dequant nope
        effective_block_id = raw_block_id

        offs_q = effective_block_id * GROUP_SIZE + tl.arange(0, GROUP_SIZE)
        mask = offs_q < DIM_NOPE
        ptr_q = input_nope_q_ptr + token_id * input_nope_q_stride_0 + offs_q
        # one fp32 scale per GROUP_SIZE-wide group
        ptr_s = input_nope_s_ptr + token_id * input_nope_s_stride_0 + effective_block_id

        y_q = tl.load(ptr_q, mask=mask, other=0.0).to(tl.float32)
        y_s = tl.load(ptr_s)

        y = (y_q * y_s).to(output_ptr.dtype.element_ty)

        dst_ptr = output_ptr + token_id * output_stride_0 + offs_q
        tl.store(dst_ptr, y, mask=mask)
    else:
        # b. copy rope
        effective_block_id = raw_block_id - NUM_NOPE_BLOCKS

        offs = effective_block_id * GROUP_SIZE + tl.arange(0, GROUP_SIZE)
        # rope is narrower than a group; mask the overhang
        mask = offs < DIM_ROPE

        src_ptr = input_rope_ptr + token_id * input_rope_stride_0 + offs
        dst_ptr = output_ptr + token_id * output_stride_0 + DIM_NOPE + offs

        data = tl.load(src_ptr, mask=mask).to(tl.bfloat16)
        tl.store(dst_ptr, data, mask=mask)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Intentional: this module has no standalone test — the round-trip UT
    # (quantize then dequantize) lives next to the quantizer.
    raise Exception("UT is in quant_k_cache.py")
|
||||||
135
python/sglang/srt/layers/attention/nsa/fallback_fp8.py
Normal file
135
python/sglang/srt/layers/attention/nsa/fallback_fp8.py
Normal file
@@ -0,0 +1,135 @@
|
|||||||
|
# fallback_fp8.py
|
||||||
|
# PyTorch fallback implementation for DeepGEMM-like fp8 logits ops
|
||||||
|
from sglang.srt.utils import ceil_div
|
||||||
|
import torch
|
||||||
|
|
||||||
|
@torch.no_grad()
def fallback_fp8_mqa_logits(q: torch.Tensor,
                            kv: torch.Tensor,
                            weights: torch.Tensor,
                            ks: torch.Tensor,
                            ke: torch.Tensor, cost_only: bool = False) -> torch.Tensor:
    """
    PyTorch fallback for fp8_mqa_logits (no real fp8; plain fp32 math).

    Args:
        q: (M, H, D) queries.
        kv: (N, D) keys.
        weights: (M, H) per-head weights.
        ks: (M,) inclusive start of each query's valid key range.
        ke: (M,) exclusive end of each query's valid key range.
        cost_only: when True, skip the matmul and return only the total
            number of (query, key) pairs inside the valid ranges.

    Returns:
        (M, N) fp32 logits with -inf outside [ks, ke), or a scalar count
        when cost_only is True.
    """
    seq_len_kv = kv.shape[0]

    if cost_only:
        start = ks.clamp(min=0, max=seq_len_kv)
        end = ke.clamp(min=0, max=seq_len_kv)
        count_ones_per_row = (end - start).clamp(min=0)
        return count_ones_per_row.sum()

    k = kv
    q = q.float()
    k = k.float()

    # Build the per-row valid-range mask on q's device. (Bug fix: the mask
    # tensors were previously created with device='cuda', which broke any
    # non-CUDA use of this fallback.)
    positions = torch.arange(0, seq_len_kv, device=q.device)
    mask_lo = positions[None, :] >= ks[:, None]
    mask_hi = positions[None, :] < ke[:, None]
    mask = mask_lo & mask_hi

    # Per-head scores, relu'd and head-weighted, then summed over heads.
    score = torch.einsum('mhd,nd->hmn', q, k)
    logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0)
    logits = logits.masked_fill(~mask, float('-inf'))

    return logits
|
||||||
|
|
||||||
|
# """
|
||||||
|
# PyTorch fallback for fp8_mqa_logits.
|
||||||
|
# No real fp8 used, just FP32.
|
||||||
|
# Args:
|
||||||
|
# q: (M, H, D) query
|
||||||
|
# k: (N, D) key
|
||||||
|
# weights: (M, H)
|
||||||
|
# ks: (M,) int32
|
||||||
|
# ke: (M,) int32
|
||||||
|
# Returns:
|
||||||
|
# logits: (M, N) with -inf outside of valid range
|
||||||
|
# """
|
||||||
|
# M, H, D = q.shape
|
||||||
|
# N = k[0].shape[0]
|
||||||
|
# logits = torch.full((M, N), float("-inf"), dtype=torch.float32, device=q.device)
|
||||||
|
|
||||||
|
# # for i in range(M):
|
||||||
|
# # start = max(ks[i].item(), 0)
|
||||||
|
# # end = min(ke[i].item(), N)
|
||||||
|
# # if start >= end:
|
||||||
|
# # continue
|
||||||
|
# # qi = q[i] # (H, D)
|
||||||
|
# # ki = k[start:end] # (L, D)
|
||||||
|
# # sim = torch.matmul(qi, ki.T) # (H, L)
|
||||||
|
# # weighted_sim = (sim.relu() * weights[i].unsqueeze(-1)).sum(dim=0) # (L,)
|
||||||
|
# # logits[i, start:end] = weighted_sim
|
||||||
|
# return logits
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
def fallback_fp8_paged_mqa_logits(q: torch.Tensor,
                                  kv_cache: torch.Tensor,
                                  weights: torch.Tensor,
                                  context_lens: torch.Tensor,
                                  block_tables: torch.Tensor,
                                  max_model_len: int) -> torch.Tensor:
    """
    PyTorch fallback for fp8_paged_mqa_logits (no real fp8; plain fp32 math).

    Args:
        q: (B, next_n, H, D) queries.
        kv_cache: (num_blocks, block_size, 1, D) paged keys.
        weights: (B * next_n, H) per-head weights.
        context_lens: (B,) context length per request.
        block_tables: (B, max_blocks) page table mapping logical to physical
            blocks.
        max_model_len: width of the output logits.

    Returns:
        (B * next_n, max_model_len) fp32 logits, -inf outside each row's
        causal/valid range.
    """
    batch_size, next_n, heads, dim = q.size()
    num_block, block_size, _, dim = kv_cache.size()
    logits = torch.full([batch_size * next_n, max_model_len], float('-inf'), device=q.device, dtype=torch.float32)
    context_lens = context_lens.tolist()
    for i in range(batch_size):
        context_len = context_lens[i]
        # Absolute positions of this request's next_n query tokens.
        q_offsets = torch.arange(context_len - next_n, context_len, device=q.device)
        # (H, next_n) weights for this request.
        weight_slice = weights[i * next_n:(i + 1) * next_n, :].transpose(0, 1).contiguous()
        for block_rk in range(ceil_div(context_len, block_size)):
            block_idx = block_tables[i][block_rk]
            qx, kx = q[i], kv_cache[block_idx]
            # Absolute key positions covered by this physical block.
            k_offsets = torch.arange(block_rk * block_size, (block_rk + 1) * block_size, device=q.device)
            # Keys must lie inside the context and not after the query (causal).
            mask = (k_offsets[None, :] < context_len) & (k_offsets[None, :] <= q_offsets[:, None])
            s = torch.where(mask[None, :, :], (qx.transpose(0, 1) @ kx.transpose(0, 1).transpose(1, 2)).to(logits.dtype), float('-inf'))
            # Head-weighted relu'd scores summed over heads.
            s = torch.relu(s) * weight_slice[..., None]
            s = s.sum(dim=0)
            logits[i * next_n:(i + 1) * next_n, block_rk * block_size: (block_rk + 1) * block_size] = torch.where(k_offsets[None, :] <= q_offsets[:, None], s, float('-inf'))
    return logits
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
PyTorch fallback for fp8_paged_mqa_logits.
|
||||||
|
No real fp8 used, just FP32.
|
||||||
|
Args:
|
||||||
|
q: (B, N, H, D)
|
||||||
|
kv_cache: (num_blocks, block_size, 1, D)
|
||||||
|
weights: (B * N, H)
|
||||||
|
context_lens: (B,)
|
||||||
|
block_tables: (B, max_blocks)
|
||||||
|
max_model_len: int
|
||||||
|
Returns:
|
||||||
|
logits: (B * N, max_model_len)
|
||||||
|
"""
|
||||||
|
B, N, H, D = q.shape
|
||||||
|
block_size = kv_cache.shape[1]
|
||||||
|
logits = torch.full((B * N, max_model_len), float("-inf"), dtype=torch.float32, device=q.device)
|
||||||
|
|
||||||
|
for i in range(B):
|
||||||
|
ctx_len = context_lens[i].item()
|
||||||
|
q_offsets = torch.arange(ctx_len - N, ctx_len, device=q.device)
|
||||||
|
weight_slice = weights[i * N:(i + 1) * N, :].transpose(0, 1).contiguous()
|
||||||
|
|
||||||
|
for br in range((ctx_len + block_size - 1) // block_size):
|
||||||
|
blk_idx = block_tables[i, br].item()
|
||||||
|
if blk_idx < 0:
|
||||||
|
continue
|
||||||
|
qx = q[i] # (N, H, D)
|
||||||
|
kx = kv_cache[blk_idx] # (block_size, 1, D)
|
||||||
|
kx = kx.squeeze(1) # (block_size, D)
|
||||||
|
k_offsets = torch.arange(br * block_size, (br + 1) * block_size, device=q.device)
|
||||||
|
|
||||||
|
mask = (k_offsets[None, :] < ctx_len) & (k_offsets[None, :] <= q_offsets[:, None]) # (N, block_size)
|
||||||
|
s = torch.where(mask[None, :, :],
|
||||||
|
torch.einsum('nhd,ld->hnl', qx, kx),
|
||||||
|
torch.full((H, N, block_size), float("-inf"), device=q.device))
|
||||||
|
s = s.relu() * weight_slice[..., None]
|
||||||
|
logits_slice = s.sum(dim=0) # (N, block_size)
|
||||||
|
|
||||||
|
mask_block = (k_offsets[None, :] <= q_offsets[:, None])
|
||||||
|
logits[i * N:(i + 1) * N, br * block_size:(br + 1) * block_size] = \
|
||||||
|
torch.where(mask_block, logits_slice, float("-inf"))
|
||||||
|
|
||||||
|
return logits
|
||||||
|
|
||||||
354
python/sglang/srt/layers/attention/nsa/index_buf_accessor.py
Normal file
354
python/sglang/srt/layers/attention/nsa/index_buf_accessor.py
Normal file
@@ -0,0 +1,354 @@
|
|||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import triton
|
||||||
|
import triton.language as tl
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from sglang.srt.mem_cache.memory_pool import NSATokenToKVPool
|
||||||
|
|
||||||
|
"""
|
||||||
|
k: data, 128 item per token, fp8
|
||||||
|
s: scale, 1 item per token, fp32
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class GetK:
    """Gather per-token index-K vectors (raw fp8 bytes) from the paged buffer.

    Per-page layout (see module docstring): the first
    ``page_size * index_head_dim`` bytes of a page row are K data; the per-token
    fp32 scales follow.
    """

    @classmethod
    def execute(cls, *args, **kwargs):
        return cls.torch_fast(*args, **kwargs)

    @classmethod
    def slow(
        cls, pool: "NSATokenToKVPool", buf, seq_len: int, page_indices: torch.Tensor
    ):
        # Reference implementation: copy one page at a time.
        num_pages = (seq_len + pool.page_size - 1) // pool.page_size
        seq_len_ = num_pages * pool.page_size
        index_k_fp8 = torch.empty(
            (seq_len_, pool.index_head_dim),
            dtype=torch.uint8,
            device=pool.device,
        )
        for i in range(num_pages):
            page_index = page_indices[i]
            index_k_fp8[i * pool.page_size : (i + 1) * pool.page_size] = buf[
                page_index
            ][: pool.page_size * pool.index_head_dim].view(-1, pool.index_head_dim)

        return index_k_fp8[:seq_len]

    @classmethod
    def torch_fast(
        cls, pool: "NSATokenToKVPool", buf, seq_len: int, page_indices: torch.Tensor
    ):
        """Vectorized gather of all tokens' K bytes in a single indexed read.

        :param page_indices: (num_pages,), int32
        :return: (seq_len, index_head_dim), uint8
        """

        # can handle per 128B instead of per element

        # page_indices: (num_pages,), element := a page index
        buf_numel_per_page = buf.shape[1]

        num_k_bytes_per_page = pool.page_size * pool.index_head_dim
        num_k_bytes_per_token = pool.index_head_dim

        # buf: (num_pages, page_size 64 * head_dim 128 + page_size 64 * fp32_nbytes 4), uint8
        # flat_buf: (whatever,), uint8
        flat_buf = buf.flatten()

        # flat_indices: (num_pages, num_k_bytes_per_page), int32, element := an
        # index into flat_buf that we want to access.
        # Fix: allocate the index tensor on the buffer's device instead of the
        # hardcoded "cuda", so the gather also works for CPU buffers.
        flat_indices = (page_indices * buf_numel_per_page)[:, None] + torch.arange(
            num_k_bytes_per_page, dtype=torch.int32, device=buf.device
        )[None, :]
        flat_indices = flat_indices.flatten()[: seq_len * num_k_bytes_per_token]

        out = flat_buf[flat_indices]
        # Fix: reshape by the pool's configured head dim rather than a
        # hardcoded 128, consistent with `slow` and the docstring above.
        return out.view(-1, pool.index_head_dim)
|
||||||
|
|
||||||
|
|
||||||
|
class GetS:
    """Gather per-token K quantization scales (raw fp32 bytes) from the paged buffer.

    Scales are stored after the K bytes within each page row, 4 bytes per token.
    """

    @classmethod
    def execute(cls, *args, **kwargs):
        return cls.torch_fast(*args, **kwargs)

    @classmethod
    def slow(
        cls, pool: "NSATokenToKVPool", buf, seq_len: int, page_indices: torch.Tensor
    ):
        # Reference implementation: copy one page at a time.
        num_pages = (seq_len + pool.page_size - 1) // pool.page_size
        seq_len_ = num_pages * pool.page_size
        # This path assumes exactly one fp32 scale (4 bytes) per token.
        assert pool.index_head_dim // pool.quant_block_size == 1
        index_k_scale_fp8 = torch.empty(
            (seq_len_, 4),
            dtype=torch.uint8,
            device=pool.device,
        )
        for i in range(num_pages):
            page_index = page_indices[i]
            index_k_scale_fp8[i * pool.page_size : (i + 1) * pool.page_size] = buf[
                page_index
            ][pool.page_size * pool.index_head_dim :].view(-1, 4)
        return index_k_scale_fp8[:seq_len]

    @classmethod
    def torch_fast(
        cls, pool: "NSATokenToKVPool", buf, seq_len: int, page_indices: torch.Tensor
    ):
        """Vectorized gather of all tokens' scale bytes in a single indexed read.

        :param page_indices: (num_pages,), int32
        :return: (seq_len, index_head_dim // quant_block_size * 4), uint8
        """
        buf_numel_per_page = buf.shape[1]

        num_s_bytes_per_page = buf.shape[1] - pool.page_size * pool.index_head_dim
        num_s_bytes_per_token = pool.index_head_dim // pool.quant_block_size * 4
        # Scales live immediately after the K bytes within each page row.
        s_offset_in_page = pool.page_size * pool.index_head_dim

        flat_buf = buf.flatten()
        # Fix: build indices on the buffer's device instead of hardcoded "cuda",
        # so the gather also works for CPU buffers.
        flat_indices = (
            (page_indices * buf_numel_per_page)[:, None]
            + torch.arange(num_s_bytes_per_page, dtype=torch.int32, device=buf.device)[
                None, :
            ]
            + s_offset_in_page
        )
        flat_indices = flat_indices.flatten()[: seq_len * num_s_bytes_per_token]

        out = flat_buf[flat_indices]
        # Fix: reshape by the computed per-token scale width rather than a
        # hardcoded 4 (identical when index_head_dim // quant_block_size == 1,
        # which `slow` asserts).
        return out.view(-1, num_s_bytes_per_token)
|
||||||
|
|
||||||
|
|
||||||
|
class SetK:
    """Scatter per-token index-K vectors into the K region of the paged buffer."""

    @classmethod
    def execute(cls, *args, buf, **kwargs):
        return cls.torch_fast(*args, **kwargs, buf=buf)

    @classmethod
    def slow(
        cls,
        pool: "NSATokenToKVPool",
        buf: torch.Tensor,
        loc: torch.Tensor,
        index_k: torch.Tensor,
    ):
        # Reference implementation: write one token at a time.
        for i in range(len(loc)):
            page_index = loc[i] // pool.page_size
            offset = loc[i] % pool.page_size
            buf[
                page_index,
                offset * pool.index_head_dim : (offset + 1) * pool.index_head_dim,
            ] = index_k[i].view(torch.uint8)

    @classmethod
    def torch_fast(
        cls,
        pool: "NSATokenToKVPool",
        buf: torch.Tensor,
        loc: torch.Tensor,
        index_k: torch.Tensor,
    ):
        """Vectorized scatter of all tokens' K bytes in a single indexed write.

        :param loc: (num_tokens_to_write,), element := destination token slot
        :param index_k: (num_tokens_to_write, index_head_dim); its bytes are
            written verbatim (viewed as uint8)
        """
        (num_tokens_to_write,) = loc.shape
        buf_numel_per_page = buf.shape[1]
        num_k_bytes_per_token = pool.index_head_dim

        # loc: (num_tokens_to_write,), int32, element := the token index to write to
        loc_page_index = loc // pool.page_size
        loc_token_offset_in_page = loc % pool.page_size

        flat_buf = buf.flatten()
        # Fix: build indices on the buffer's device instead of hardcoded
        # "cuda", so the scatter also works for CPU buffers.
        flat_indices = (
            (loc_page_index * buf_numel_per_page)[:, None]
            + (loc_token_offset_in_page * num_k_bytes_per_token)[:, None]
            + torch.arange(num_k_bytes_per_token, dtype=torch.int32, device=buf.device)[
                None, :
            ]
        )
        num_k_bytes_total = num_tokens_to_write * num_k_bytes_per_token
        flat_indices = flat_indices.flatten()[:num_k_bytes_total]
        flat_buf[flat_indices] = index_k.view(torch.uint8).flatten()
|
||||||
|
|
||||||
|
|
||||||
|
class SetS:
    """Scatter per-token K quantization scales into the scale region of the paged buffer."""

    @classmethod
    def execute(cls, *args, buf, **kwargs):
        return cls.torch_fast(*args, **kwargs, buf=buf)

    @classmethod
    def slow(
        cls,
        pool: "NSATokenToKVPool",
        buf: torch.Tensor,
        loc: torch.Tensor,
        index_k_scale: torch.Tensor,
    ):
        # Reference implementation: write one token's 4 scale bytes at a time.
        for i in range(len(loc)):
            page_index = loc[i] // pool.page_size
            offset = loc[i] % pool.page_size
            start = pool.page_size * pool.index_head_dim
            buf[page_index, start + offset * 4 : start + (offset + 1) * 4] = (
                index_k_scale[i].view(torch.uint8)
            )

    @classmethod
    def torch_fast(
        cls,
        pool: "NSATokenToKVPool",
        buf: torch.Tensor,
        loc: torch.Tensor,
        index_k_scale: torch.Tensor,
    ):
        """Vectorized scatter of all tokens' scale bytes in a single indexed write.

        :param loc: (num_tokens_to_write,), element := destination token slot
        :param index_k_scale: (num_tokens_to_write, 1) scales; their bytes are
            written verbatim (viewed as uint8)
        """
        (num_tokens_to_write,) = loc.shape
        buf_numel_per_page = buf.shape[1]
        num_s_bytes_per_token = 4
        # Scales live immediately after the K bytes within each page row.
        s_offset_in_page = pool.page_size * pool.index_head_dim

        # loc: (num_tokens_to_write,), int32, element := the token index to write to
        loc_page_index = loc // pool.page_size
        loc_token_offset_in_page = loc % pool.page_size

        flat_buf = buf.flatten()
        # Fix: build indices on the buffer's device instead of hardcoded
        # "cuda", so the scatter also works for CPU buffers.
        flat_indices = (
            (loc_page_index * buf_numel_per_page)[:, None]
            + s_offset_in_page
            + (loc_token_offset_in_page * num_s_bytes_per_token)[:, None]
            + torch.arange(num_s_bytes_per_token, dtype=torch.int32, device=buf.device)[
                None, :
            ]
        )
        number_s_bytes_total = num_tokens_to_write * num_s_bytes_per_token
        flat_indices = flat_indices.flatten()[:number_s_bytes_total]
        flat_buf[flat_indices] = index_k_scale.view(torch.uint8).flatten()
|
||||||
|
|
||||||
|
|
||||||
|
class SetKAndS:
    """Write per-token index-K vectors and their scales into the paged buffer.

    ``execute`` dispatches to the fused Triton kernel; ``vanilla`` is the
    reference path built from SetK + SetS, kept for the A/B comparison in the
    normally-disabled debug branch of ``execute``.
    """

    @classmethod
    def execute(cls, *args, buf, **kwargs):
        if 0:  # debug-only A/B check of vanilla vs triton paths (disabled)
            # print("SetK, SetS comparison test")
            buf_cloned = buf.clone()
            cls.vanilla(*args, **kwargs, buf=buf)
            cls.triton(*args, **kwargs, buf=buf_cloned)

            def _clear_token_0(target):
                # Zero token 0's K bytes and scale bytes before comparison.
                # NOTE(review): the 128 / 64*128 offsets assume head_dim=128
                # and page_size=64 — confirm if those ever change.
                target[0, :128] = target[0, 64 * 128 : 64 * 128 + 4] = 0

            _clear_token_0(buf)
            _clear_token_0(buf_cloned)

            assert torch.all(
                buf == buf_cloned
            ), f"{buf=} {buf_cloned=} {kwargs['loc'].tolist()=}"  # fix: was .to_list(), which torch.Tensor does not provide
            return

        cls.triton(*args, **kwargs, buf=buf)

    @classmethod
    def vanilla(cls, pool, buf, loc, index_k, index_k_scale):
        # Reference implementation: two separate vectorized scatters.
        SetK.execute(pool=pool, buf=buf, loc=loc, index_k=index_k)
        SetS.execute(pool=pool, buf=buf, loc=loc, index_k_scale=index_k_scale)

    @classmethod
    def triton(cls, pool, buf, loc, index_k, index_k_scale):
        # Fused kernel: one program per token writes both K and its scale.
        _set_k_and_s_triton(
            buf=buf,
            loc=loc,
            index_k=index_k,
            index_k_scale=index_k_scale,
            page_size=pool.page_size,
        )
|
||||||
|
|
||||||
|
|
||||||
|
def _set_k_and_s_triton(
    buf: torch.Tensor,
    loc: torch.Tensor,
    index_k: torch.Tensor,
    index_k_scale: torch.Tensor,
    page_size: int,
):
    """
    Launch the fused scatter kernel: one Triton program per token copies that
    token's 128-byte fp8 K vector and its 4-byte fp32 scale into the paged
    index buffer.

    :param buf: (num_pages, page_size 64 * (128B data + 4B scale)), uint8
    :param loc: (num_tokens_to_write,), int, element := the token index to write to
    :param index_k: (num_tokens_to_write, 128 elem), fp8
    :param index_k_scale: (num_tokens_to_write, 1 elem), fp32
    :return:
    """
    num_pages, buf_numel_per_page = buf.shape
    (num_tokens_to_write,) = loc.shape
    num_tokens_to_write_, index_head_dim = index_k.shape
    num_tokens_to_write__, scale_dim = index_k_scale.shape
    # Layout assumptions baked into the kernel: page_size=64, head_dim=128,
    # one fp32 scale per token.
    assert buf_numel_per_page == 64 * (128 + 4)
    assert num_tokens_to_write == num_tokens_to_write_ == num_tokens_to_write__
    assert index_head_dim == 128
    assert scale_dim == 1
    assert page_size == 64

    assert buf.dtype == torch.uint8
    assert loc.dtype == torch.int64, f"{loc.dtype=}"  # can be int32
    assert index_k.dtype == torch.float8_e4m3fn
    assert index_k_scale.dtype == torch.float32

    # Contiguity is required because the kernel computes flat offsets itself.
    assert buf.is_contiguous()
    assert loc.is_contiguous()
    assert index_k.is_contiguous()
    assert index_k_scale.is_contiguous()

    # Two typed aliases of the same bytes: K is stored as fp8 elements,
    # the scale as fp32 elements.
    buf_fp8 = buf.view(torch.float8_e4m3fn)
    buf_fp32 = buf.view(torch.float32)

    # Grid: one program per token to write.
    _set_k_and_s_triton_kernel[(num_tokens_to_write,)](
        buf_fp8,
        buf_fp32,
        loc,
        index_k,
        index_k_scale,
        index_k.stride(0),
        PAGE_SIZE=page_size,
        BUF_NUMEL_PER_PAGE=buf_numel_per_page,
        NUM_K_ELEMS_PER_TOKEN=index_head_dim,
        S_OFFSET_NBYTES_IN_PAGE=page_size * index_head_dim,
    )
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
def _set_k_and_s_triton_kernel(
    buf_fp8_ptr,
    buf_fp32_ptr,
    loc_ptr,
    index_k_ptr,
    index_k_scale_ptr,
    index_k_ptr_stride_0,
    PAGE_SIZE: tl.constexpr,
    BUF_NUMEL_PER_PAGE: tl.constexpr,
    NUM_K_ELEMS_PER_TOKEN: tl.constexpr,
    S_OFFSET_NBYTES_IN_PAGE: tl.constexpr,
):
    # One program per token: copy its K vector (fp8) and scale (fp32) into the
    # page that holds token slot `loc`.
    token_id = tl.program_id(0)

    # Destination token slot in the paged buffer.
    loc = tl.load(loc_ptr + token_id)

    # Source offsets for this token's K elements.
    in_k_offsets = token_id * index_k_ptr_stride_0 + tl.arange(0, NUM_K_ELEMS_PER_TOKEN)

    # no need for `mask`, since we read 128B for k and 4B for scale, both pow of 2
    k = tl.load(index_k_ptr + in_k_offsets)
    k_scale = tl.load(index_k_scale_ptr + token_id)

    # Translate the flat token slot into (page, offset-within-page).
    loc_page_index = loc // PAGE_SIZE
    loc_token_offset_in_page = loc % PAGE_SIZE

    # K destination offsets, in fp8-element units (1 byte each).
    out_k_offsets = (
        loc_page_index * BUF_NUMEL_PER_PAGE
        + loc_token_offset_in_page * NUM_K_ELEMS_PER_TOKEN
        + tl.arange(0, NUM_K_ELEMS_PER_TOKEN)
    )

    # "//4" b/c it is fp32 instead of uint8
    out_s_offset = (
        loc_page_index * BUF_NUMEL_PER_PAGE // 4
        + S_OFFSET_NBYTES_IN_PAGE // 4
        + loc_token_offset_in_page
    )

    tl.store(buf_fp8_ptr + out_k_offsets, k)
    tl.store(buf_fp32_ptr + out_s_offset, k_scale)
|
||||||
709
python/sglang/srt/layers/attention/nsa/nsa_indexer.py
Normal file
709
python/sglang/srt/layers/attention/nsa/nsa_indexer.py
Normal file
@@ -0,0 +1,709 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
|
||||||
|
|
||||||
|
from sglang.srt.layers.attention.nsa.fallback_fp8 import fallback_fp8_mqa_logits, fallback_fp8_paged_mqa_logits
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from einops import rearrange
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from sglang.srt.custom_op import CustomOp
|
||||||
|
from sglang.srt.debug_utils.dumper import dumper
|
||||||
|
from sglang.srt.utils import add_prefix, is_npu
|
||||||
|
|
||||||
|
if not is_npu():
|
||||||
|
from sglang.srt.layers.attention.nsa.tilelang_kernel import act_quant
|
||||||
|
#import deep_gemm
|
||||||
|
|
||||||
|
from sglang.srt.layers.attention.nsa.utils import NSA_DUAL_STREAM, NSA_USE_REAL_INDEXER
|
||||||
|
from sglang.srt.layers.dp_attention import get_attention_tp_group
|
||||||
|
from sglang.srt.layers.linear import ReplicatedLinear
|
||||||
|
from sglang.srt.layers.quantization import deep_gemm_wrapper
|
||||||
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||||
|
from sglang.srt.layers.rotary_embedding import get_rope_wrapper
|
||||||
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||||
|
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
|
||||||
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
|
from sglang.srt.utils import add_prefix, align, is_cuda
|
||||||
|
|
||||||
|
# try:
|
||||||
|
# import deep_gemm_v32
|
||||||
|
# except ImportError as e:
|
||||||
|
# print("Error when importing deep_gemm_v32, try deep_gemm")
|
||||||
|
# try:
|
||||||
|
# import deep_gemm as deep_gemm_v32
|
||||||
|
# except ImportError as e:
|
||||||
|
# print("Error when importing deep_gemm, skip")
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from sglang.srt.mem_cache.memory_pool import NSATokenToKVPool
|
||||||
|
|
||||||
|
DUAL_STREAM_TOKEN_THRESHOLD = 1024 if is_cuda() else 0
|
||||||
|
|
||||||
|
|
||||||
|
class BaseIndexerMetadata(ABC):
    """Backend-provided metadata interface consumed by the NSA indexer.

    Attention backends implement this to expose per-request sequence lengths,
    the paged KV layout (page size 64), and a possibly backend-specific top-k
    selection over the indexer logits.
    """

    @abstractmethod
    def get_seqlens_int32(self) -> torch.Tensor:
        """
        Return: (batch_size,) int32 tensor
        """

    @abstractmethod
    def get_page_table_64(self) -> torch.Tensor:
        """
        Return: (batch_size, num_blocks) int32, page table.
        The page size of the table is 64.
        """

    @abstractmethod
    def get_seqlens_expanded(self) -> torch.Tensor:
        """
        Return: (sum_extend_seq_len,) int32 tensor
        """

    @abstractmethod
    def topk_transform(
        self,
        logits: torch.Tensor,
        topk: int,
    ) -> torch.Tensor:
        """
        Perform topk selection on the logits and possibly transform the result.

        NOTE that attention backend may override this function to do some
        transformation, which means the result of this topk_transform may not
        be the topk indices of the input logits.

        Return: Anything, since it will be passed to the attention backend
            for further processing on sparse attention computation.
            Don't assume it is the topk indices of the input logits.
        """
|
||||||
|
|
||||||
|
def hadamard_transform_pytorch(x: torch.Tensor, scale: float) -> torch.Tensor:
    """Fast (Walsh-)Hadamard transform along the last dimension.

    Native PyTorch stand-in for the custom CUDA kernel with the same call
    signature. Produces the Sylvester ("natural") ordering, i.e. the result of
    the recursive [a + b, a - b] halving butterfly.

    Args:
        x (torch.Tensor): Input tensor of shape (*, N), where N is a power of 2.
        scale (float): The normalization factor to multiply the result by.

    Returns:
        torch.Tensor: The Hadamard transformed tensor.
    """
    n = x.shape[-1]
    # Iterative bottom-up butterfly. This performs exactly the same
    # additions/subtractions (in the same association order) as the previous
    # recursive formulation, without building a Python call stack of depth
    # log2(N) or allocating intermediate `cat` results per level.
    out = x
    span = 1
    while span < n:
        # View as (..., n // (2*span), 2, span) and combine sibling halves.
        blocks = out.reshape(*out.shape[:-1], n // (2 * span), 2, span)
        lo = blocks[..., 0, :]
        hi = blocks[..., 1, :]
        out = torch.stack((lo + hi, lo - hi), dim=-2).reshape(*out.shape[:-1], n)
        span *= 2
    # Apply the scale exactly once at the end. (Bug fix: the old recursive
    # version returned the input unscaled when N == 1, silently dropping
    # `scale` for that edge case.)
    return out * scale
|
||||||
|
|
||||||
|
|
||||||
|
def rotate_activation(x: torch.Tensor) -> torch.Tensor:
    """Apply the normalized Hadamard rotation to bf16 activations.

    Uses the pure-PyTorch transform with scale 1/sqrt(hidden_size); the fused
    CUDA kernel import is intentionally left disabled below.
    """
    assert x.dtype == torch.bfloat16
    # from fast_hadamard_transform import hadamard_transform
    dim = x.size(-1)
    is_power_of_two = (dim & (dim - 1)) == 0
    assert is_power_of_two, "Hidden size must be a power of 2 for Hadamard transform."
    return hadamard_transform_pytorch(x, scale=dim**-0.5)
|
||||||
|
|
||||||
|
|
||||||
|
class V32LayerNorm(nn.Module):
|
||||||
|
"""
|
||||||
|
Layer Normalization.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, dim: int, eps: float = 1e-6):
|
||||||
|
super().__init__()
|
||||||
|
self.dim = dim
|
||||||
|
self.eps = eps
|
||||||
|
self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))
|
||||||
|
self.bias = nn.Parameter(torch.zeros(dim, dtype=torch.float32))
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor):
|
||||||
|
return F.layer_norm(
|
||||||
|
x.float(), (self.dim,), self.weight, self.bias, self.eps
|
||||||
|
).type_as(x)
|
||||||
|
|
||||||
|
|
||||||
|
class Indexer(CustomOp):
|
||||||
|
    def __init__(
        self,
        hidden_size: int,
        index_n_heads: int,
        index_head_dim: int,
        rope_head_dim: int,
        index_topk: int,
        q_lora_rank: int,
        max_position_embeddings: int,
        rope_theta: float,
        layer_id: int,
        scale_fmt: Optional[str],
        block_size: int = 128,
        rope_scaling: Optional[Dict[str, Any]] = None,
        prefix: str = "",
        quant_config: Optional[QuantizationConfig] = None,
        alt_stream: Optional[torch.cuda.Stream] = None,
    ):
        """Construct the NSA indexer sub-module.

        Builds the low-rank query projection (``wq_b``), the key projection +
        fp32 layer norm (``wk`` / ``k_norm``), the unquantized per-head gate
        projection (``weights_proj``), and the partial-RoPE embedding used on
        the first ``rope_head_dim`` dims of each head.

        ``alt_stream`` enables the optional dual-stream overlap in
        ``_get_q_k_bf16``; ``scale_fmt`` / ``block_size`` configure
        quantization downstream.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.n_heads = index_n_heads
        self.head_dim = index_head_dim
        self.rope_head_dim = rope_head_dim
        self.index_topk = index_topk
        self.q_lora_rank = q_lora_rank
        self.layer_id = layer_id
        self.alt_stream = alt_stream
        if not is_npu():
            # Half the device's SMs (aligned to 8) is handed to the query GEMM
            # when dual-stream execution overlaps it with the key path.
            self.sm_count = torch.cuda.get_device_properties(0).multi_processor_count
            self.half_device_sm_count = align(self.sm_count // 2, 8)

        # Low-rank query up-projection: q_lora -> n_heads * head_dim.
        self.wq_b = ReplicatedLinear(
            self.q_lora_rank,
            self.n_heads * self.head_dim,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("wq_b", prefix),
        )
        # Single-head key projection from the hidden states.
        self.wk = ReplicatedLinear(
            self.hidden_size,
            self.head_dim,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("wk", prefix),
        )
        self.k_norm = V32LayerNorm(self.head_dim)
        # NOTE: weight_proj is not quantized
        self.weights_proj = ReplicatedLinear(
            self.hidden_size,
            self.n_heads,
            bias=False,
            prefix=add_prefix("weights_proj", prefix),
        )
        self.rotary_emb = get_rope_wrapper(
            rope_head_dim,
            rotary_dim=rope_head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,  # type: ignore
            rope_scaling=rope_scaling,
            is_neox_style=False,
            device=global_server_args_dict["device"],
        )
        self.block_size = block_size
        self.scale_fmt = scale_fmt
        self.softmax_scale = self.head_dim**-0.5
||||||
|
    def _forward_fake(
        self,
        x: torch.Tensor,
        q_lora: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        layer_id: int,
    ):
        """Fake (non-learned) indexer path returning dense "top-k" indices.

        For every query token, emits [0, ..., index_topk) with positions at or
        beyond the token's causal limit replaced by -1 — i.e. it selects all
        visible tokens rather than a learned subset. ``q_lora``, ``positions``
        and ``layer_id`` are unused here; only the batch layout matters.
        """
        bs = x.shape[0]
        # The dense arange below assumes the only supported topk value.
        assert self.index_topk == 2048
        ans = torch.arange(0, self.index_topk, dtype=torch.int32, device=x.device)[
            None, ...
        ].repeat(bs, 1)
        if forward_batch.forward_mode.is_extend():
            assert (
                forward_batch.extend_seq_lens_cpu is not None
                and forward_batch.seq_lens_cpu is not None
            )
            which = 0
            # One output row per extend query token. A query at absolute
            # position j may see positions <= j, so mask j+1 onwards.
            for i, (kv_len, qo_len) in enumerate(
                zip(
                    forward_batch.seq_lens_cpu.tolist(),
                    forward_batch.extend_seq_lens_cpu,
                    strict=True,
                )
            ):
                for j in range(kv_len - qo_len, kv_len):
                    ans[which, j + 1 :] = -1
                    which += 1
            # Every row of `ans` must correspond to exactly one query token.
            assert which == ans.shape[0]
        else:
            # Decode: one query per request; mask everything past its seq_len.
            assert forward_batch.seq_lens_cpu is not None
            for i, seq_len in enumerate(forward_batch.seq_lens_cpu.tolist()):
                ans[i, seq_len:] = -1

        return ans
|
||||||
|
|
||||||
|
def _get_logits_head_gate(self, x: torch.Tensor, q_scale: torch.Tensor):
|
||||||
|
weights, _ = self.weights_proj(x)
|
||||||
|
weights = weights * self.n_heads**-0.5
|
||||||
|
weights = weights.unsqueeze(-1) * q_scale * self.softmax_scale
|
||||||
|
return weights
|
||||||
|
|
||||||
|
    def _get_q_k_bf16(
        self,
        q_lora: torch.Tensor,
        x: torch.Tensor,
        positions: torch.Tensor,
        enable_dual_stream: bool,
    ):
        """Produce the indexer's bf16 queries and keys.

        Pipeline: project (wq_b / wk + k_norm), apply RoPE to the first
        ``rope_head_dim`` dims only, then apply the Hadamard rotation to the
        full vectors. When ``enable_dual_stream`` is set, the key path runs on
        ``self.alt_stream`` overlapped with the query path, with the query
        GEMM capped to half the device's SMs.
        """

        if enable_dual_stream:
            current_stream = torch.cuda.current_stream()
            # Make alt_stream see all work queued so far before forking.
            self.alt_stream.wait_stream(current_stream)

            # Query path on the current stream, limited to half the SMs so the
            # key path can run concurrently.
            with deep_gemm_wrapper.configure_deep_gemm_num_sms(
                self.half_device_sm_count
            ):
                query, _ = self.wq_b(q_lora)
                query = rearrange(query, "l (h d) -> l h d", d=self.head_dim)
                q_rope, _ = torch.split(
                    query,
                    [self.rope_head_dim, self.head_dim - self.rope_head_dim],
                    dim=-1,
                )
            # Key path on the alternate stream.
            with torch.cuda.stream(self.alt_stream):
                key, _ = self.wk(x)
                key = self.k_norm(key)

                k_rope, _ = torch.split(
                    key,
                    [self.rope_head_dim, self.head_dim - self.rope_head_dim],
                    dim=-1,
                )

            # Re-join before the (single-stream) RoPE below.
            current_stream.wait_stream(self.alt_stream)
        else:
            query, _ = self.wq_b(q_lora)
            # NOTE(review): the `after_*` clones below are captured only when
            # dumping is enabled and are not used in this method — presumably
            # kept for debugger/dump inspection; confirm before removing.
            if dumper._enable:
                after_wq_b = query.clone()
            query = rearrange(query, "l (h d) -> l h d", d=self.head_dim)

            q_rope, _ = torch.split(
                query, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1
            )

            key, _ = self.wk(x)
            if dumper._enable:
                after_wk = key.clone()
            key = self.k_norm(key)
            if dumper._enable:
                after_k_norm = key.clone()
            k_rope, _ = torch.split(
                key, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1
            )
        # Partial RoPE: rotate only the leading rope_head_dim dims, then write
        # them back into the full-width query/key tensors in place.
        q_rope, k_rope = self.rotary_emb(positions, q_rope, k_rope)
        query[..., : self.rope_head_dim] = q_rope
        key[..., : self.rope_head_dim] = k_rope

        if dumper._enable:
            q_before_hadamard = query.clone()
            k_before_hadamard = key.clone()

        if enable_dual_stream:
            # Overlap the two Hadamard rotations across streams as well.
            current_stream = torch.cuda.current_stream()
            self.alt_stream.wait_stream(current_stream)
            query = rotate_activation(query)

            with torch.cuda.stream(self.alt_stream):
                key = rotate_activation(key)
            current_stream.wait_stream(self.alt_stream)
        else:
            query = rotate_activation(query)
            key = rotate_activation(key)

        return query, key
|
||||||
|
|
||||||
|
    def _get_topk_paged(
        self,
        forward_batch: ForwardBatch,
        layer_id: int,
        q_fp8: torch.Tensor,
        weights: torch.Tensor,
        metadata: BaseIndexerMetadata,
    ) -> torch.Tensor:
        """Score cached tokens against the fp8 queries via the paged-MQA
        fallback and return the backend's (possibly transformed) top-k result.
        """
        if TYPE_CHECKING:
            assert isinstance(forward_batch.token_to_kv_pool, NSATokenToKVPool)

        page_size = forward_batch.token_to_kv_pool.page_size
        # NOTE(dark): blocksize = 64 is hardcoded in deep_gemm_v32
        assert page_size == 64, "only support page size 64"

        # NOTE(dark): this support extend/decode/decode+graph
        block_tables = metadata.get_page_table_64()

        # Upper bound on context length representable by this page table.
        max_seq_len = block_tables.shape[1] * page_size
        kv_cache_fp8 = forward_batch.token_to_kv_pool.get_index_k_with_scale_buffer(
            layer_id=layer_id
        )

        # NOTE(review): `blocksize` is only consumed by the commented-out
        # deep_gemm scheduling call below; it is otherwise unused here.
        blocksize = page_size
        seqlens_32 = metadata.get_seqlens_int32()
        # NOTE(dark): 132 is SM count on H200/B200, not magic number
        # schedule_metadata = deep_gemm_v32.get_paged_mqa_logits_metadata(
        #     seqlens_32, blocksize, self.sm_count
        # )

        assert len(q_fp8.shape) == 3
        q_fp8 = q_fp8.unsqueeze(1)  # the next_n dim is 1 now
        assert len(kv_cache_fp8.shape) == 2
        # Reshape the raw per-page byte buffer into (pages, block, kv_head,
        # head_dim + scale). 132 = 128 fp8 K bytes + 4 scale bytes per token,
        # matching the index buffer layout (see index_buf_accessor).
        block_kv = 64
        num_heads_kv = 1
        head_dim_with_sf = 132
        kv_cache_fp8 = kv_cache_fp8.view(
            kv_cache_fp8.shape[0], block_kv, num_heads_kv, head_dim_with_sf
        )
        assert len(weights.shape) == 3
        weights = weights.squeeze(2)

        # Pure-PyTorch fallback in place of the deep_gemm fp8 kernel.
        logits = fallback_fp8_paged_mqa_logits(
            q_fp8,
            kv_cache_fp8,
            weights,
            seqlens_32,
            block_tables,
            max_seq_len,
        )

        # NOTE(dark): logits should be cleaned in topk_transform
        topk_result = metadata.topk_transform(logits, self.index_topk)
        return topk_result
|
||||||
|
|
||||||
|
def _get_topk_ragged(
|
||||||
|
self,
|
||||||
|
forward_batch: ForwardBatch,
|
||||||
|
layer_id: int,
|
||||||
|
q_fp8: torch.Tensor,
|
||||||
|
weights: torch.Tensor,
|
||||||
|
metadata: BaseIndexerMetadata,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
assert isinstance(forward_batch.token_to_kv_pool, NSATokenToKVPool)
|
||||||
|
|
||||||
|
page_size = forward_batch.token_to_kv_pool.page_size
|
||||||
|
assert page_size == 64, "only support page size 64"
|
||||||
|
assert len(weights.shape) == 3
|
||||||
|
weights = weights.squeeze(-1)
|
||||||
|
k_fp8_list = []
|
||||||
|
k_scale_list = []
|
||||||
|
ks_list = []
|
||||||
|
offset = 0
|
||||||
|
|
||||||
|
block_tables = metadata.get_page_table_64()
|
||||||
|
|
||||||
|
assert (
|
||||||
|
forward_batch.seq_lens_cpu is not None
|
||||||
|
and forward_batch.extend_seq_lens_cpu is not None
|
||||||
|
)
|
||||||
|
|
||||||
|
for i in range(forward_batch.batch_size):
|
||||||
|
seq_len = forward_batch.seq_lens_cpu[i].item()
|
||||||
|
assert isinstance(seq_len, int)
|
||||||
|
k_fp8 = forward_batch.token_to_kv_pool.get_index_k_continuous(
|
||||||
|
layer_id,
|
||||||
|
seq_len,
|
||||||
|
block_tables[i],
|
||||||
|
)
|
||||||
|
k_scale = forward_batch.token_to_kv_pool.get_index_k_scale_continuous(
|
||||||
|
layer_id,
|
||||||
|
seq_len,
|
||||||
|
block_tables[i],
|
||||||
|
)
|
||||||
|
extend_seq_len = forward_batch.extend_seq_lens_cpu[i]
|
||||||
|
ks = torch.full((extend_seq_len,), offset, dtype=torch.int32, device="cuda")
|
||||||
|
k_fp8_list.append(k_fp8)
|
||||||
|
k_scale_list.append(k_scale)
|
||||||
|
ks_list.append(ks)
|
||||||
|
offset += extend_seq_len
|
||||||
|
|
||||||
|
k_fp8 = torch.cat(k_fp8_list, dim=0).view(torch.float8_e4m3fn)
|
||||||
|
k_scale = torch.cat(k_scale_list, dim=0).view(torch.float32).squeeze(-1)
|
||||||
|
kv_fp8 = (k_fp8, k_scale)
|
||||||
|
ks = torch.cat(ks_list, dim=0)
|
||||||
|
seq_lens_expanded = metadata.get_seqlens_expanded()
|
||||||
|
ke = ks + seq_lens_expanded
|
||||||
|
|
||||||
|
logits = fallback_fp8_mqa_logits(
|
||||||
|
q_fp8,
|
||||||
|
k_fp8,
|
||||||
|
weights,
|
||||||
|
ks,
|
||||||
|
ke
|
||||||
|
)
|
||||||
|
|
||||||
|
assert logits.shape[0] == len(seq_lens_expanded)
|
||||||
|
topk_result = metadata.topk_transform(logits, self.index_topk)
|
||||||
|
|
||||||
|
return topk_result
|
||||||
|
|
||||||
|
def _forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
q_lora: torch.Tensor,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
forward_batch: ForwardBatch,
|
||||||
|
layer_id: int,
|
||||||
|
) -> Optional[torch.Tensor]:
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
assert isinstance(forward_batch.token_to_kv_pool, NSATokenToKVPool)
|
||||||
|
|
||||||
|
metadata = forward_batch.attn_backend.get_indexer_metadata(
|
||||||
|
layer_id, forward_batch
|
||||||
|
)
|
||||||
|
|
||||||
|
enable_dual_stream = (
|
||||||
|
NSA_DUAL_STREAM
|
||||||
|
and self.alt_stream is not None
|
||||||
|
and get_is_capture_mode()
|
||||||
|
and q_lora.shape[0] > 0
|
||||||
|
and q_lora.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD
|
||||||
|
)
|
||||||
|
|
||||||
|
# skip NSA if attention backend choose to skip this batch
|
||||||
|
if metadata is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if not NSA_USE_REAL_INDEXER: # temporary
|
||||||
|
return self._forward_fake(x, q_lora, positions, forward_batch, layer_id)
|
||||||
|
|
||||||
|
query, key = self._get_q_k_bf16(q_lora, x, positions, enable_dual_stream)
|
||||||
|
|
||||||
|
q_fp8 = query.to(torch.float32)
|
||||||
|
k_fp8 = key.to(torch.float32)
|
||||||
|
q_scale = torch.ones((query.shape[0], 1), dtype=torch.float32, device="cuda")
|
||||||
|
k_scale = torch.ones((key.shape[0], 1), dtype=torch.float32, device="cuda")
|
||||||
|
|
||||||
|
if enable_dual_stream:
|
||||||
|
current_stream = torch.cuda.current_stream()
|
||||||
|
self.alt_stream.wait_stream(current_stream)
|
||||||
|
|
||||||
|
q_fp8, q_scale = act_quant(query, self.block_size, self.scale_fmt)
|
||||||
|
with torch.cuda.stream(self.alt_stream):
|
||||||
|
k_fp8, k_scale = act_quant(key, self.block_size, self.scale_fmt)
|
||||||
|
current_stream.wait_stream(self.alt_stream)
|
||||||
|
else:
|
||||||
|
q_fp8, q_scale = act_quant(query, self.block_size, self.scale_fmt)
|
||||||
|
k_fp8, k_scale = act_quant(key, self.block_size, self.scale_fmt)
|
||||||
|
|
||||||
|
# k_fp8: (seq_len, head_dim) fp8_e4m3fn
|
||||||
|
# k_buffer: (num_total_tokens + page_size, head_dim) fp8_e4m3fn
|
||||||
|
# k_scale: (seq_len, head_dim // block_size = 1) fp8_e4m3fn
|
||||||
|
# k_scale_cache: (num_total_tokens + page_size, head_dim // block_size = 1) fp8_e4m3fn
|
||||||
|
forward_batch.token_to_kv_pool.set_index_k_and_scale_buffer(
|
||||||
|
layer_id=layer_id,
|
||||||
|
loc=forward_batch.out_cache_loc,
|
||||||
|
index_k=k_fp8,
|
||||||
|
index_k_scale=k_scale,
|
||||||
|
)
|
||||||
|
|
||||||
|
weights = self._get_logits_head_gate(x, q_scale)
|
||||||
|
|
||||||
|
assert forward_batch.seq_lens_cpu is not None
|
||||||
|
if len(forward_batch.seq_lens_cpu) == 0:
|
||||||
|
# this seems b/c max-pad, no worries?
|
||||||
|
# if x.shape[0] != 0:
|
||||||
|
# print(
|
||||||
|
# "HACK: seq_lens empty but x not empty, hackily return all-invalid topk_result"
|
||||||
|
# )
|
||||||
|
return torch.full(
|
||||||
|
(x.shape[0], self.index_topk), -1, dtype=torch.int, device="cuda"
|
||||||
|
)
|
||||||
|
|
||||||
|
if forward_batch.forward_mode.is_decode_or_idle():
|
||||||
|
topk_result = self._get_topk_paged(
|
||||||
|
forward_batch, layer_id, q_fp8, weights, metadata
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
topk_result = self._get_topk_ragged(
|
||||||
|
forward_batch, layer_id, q_fp8, weights, metadata
|
||||||
|
)
|
||||||
|
|
||||||
|
return topk_result
|
||||||
|
|
||||||
|
def forward_cuda(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
q_lora: torch.Tensor,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
forward_batch: ForwardBatch,
|
||||||
|
layer_id: int,
|
||||||
|
) -> Optional[torch.Tensor]:
|
||||||
|
return self._forward(x, q_lora, positions, forward_batch, layer_id)
|
||||||
|
|
||||||
|
def forward_npu(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
q_lora: torch.Tensor,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
forward_batch: ForwardBatch,
|
||||||
|
layer_id: int,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
import custom_ops
|
||||||
|
import torch_npu
|
||||||
|
|
||||||
|
from sglang.srt.layers.dp_attention import (
|
||||||
|
get_attention_tp_rank,
|
||||||
|
get_attention_tp_size,
|
||||||
|
)
|
||||||
|
from sglang.srt.utils import get_bool_env_var
|
||||||
|
|
||||||
|
if forward_batch.attn_backend.forward_metadata.seq_lens_cpu_int is None:
|
||||||
|
actual_seq_lengths_kv = forward_batch.attn_backend.forward_metadata.seq_lens
|
||||||
|
else:
|
||||||
|
actual_seq_lengths_kv = (
|
||||||
|
forward_batch.attn_backend.forward_metadata.seq_lens_cpu_int
|
||||||
|
)
|
||||||
|
enable_index_cp = (
|
||||||
|
get_bool_env_var("SGLANG_USE_AG_AFTER_QLORA") and layer_id >= 4
|
||||||
|
)
|
||||||
|
is_prefill = forward_batch.forward_mode.is_extend()
|
||||||
|
|
||||||
|
attention_tp_rank = get_attention_tp_rank()
|
||||||
|
attention_tp_size = get_attention_tp_size()
|
||||||
|
|
||||||
|
cos_sin = self.rotary_emb.cos_sin_cache[positions]
|
||||||
|
cos, sin = cos_sin.chunk(2, dim=-1)
|
||||||
|
cos = cos.repeat(1, 2).view(-1, 1, 1, self.rope_head_dim)
|
||||||
|
sin = sin.repeat(1, 2).view(-1, 1, 1, self.rope_head_dim)
|
||||||
|
if is_prefill and enable_index_cp:
|
||||||
|
slice_length = cos.shape[0] // attention_tp_size
|
||||||
|
cos = cos[
|
||||||
|
slice_length
|
||||||
|
* attention_tp_rank : slice_length
|
||||||
|
* (attention_tp_rank + 1)
|
||||||
|
]
|
||||||
|
sin = sin[
|
||||||
|
slice_length
|
||||||
|
* attention_tp_rank : slice_length
|
||||||
|
* (attention_tp_rank + 1)
|
||||||
|
]
|
||||||
|
|
||||||
|
slot_mapping = forward_batch.out_cache_loc
|
||||||
|
block_table = forward_batch.attn_backend.forward_metadata.block_tables
|
||||||
|
|
||||||
|
bs = x.shape[0]
|
||||||
|
|
||||||
|
q = self.wq_b(q_lora)[0] # [bs, 1536] @ [1536, 64 * 128] = [bs, 64 * 128]
|
||||||
|
q = q.view(bs, self.n_heads, self.head_dim) # [bs, 64, 128]
|
||||||
|
q_pe, q_nope = torch.split(
|
||||||
|
q,
|
||||||
|
[self.rope_head_dim, self.head_dim - self.rope_head_dim],
|
||||||
|
dim=-1,
|
||||||
|
) # [bs, 64, 64 + 64]
|
||||||
|
|
||||||
|
q_pe = q_pe.view(bs, self.n_heads, 1, self.rope_head_dim)
|
||||||
|
q_pe = torch_npu.npu_interleave_rope(q_pe, cos, sin).view(
|
||||||
|
bs, self.n_heads, self.rope_head_dim
|
||||||
|
) # [bs, n, d]
|
||||||
|
q = torch.cat([q_pe, q_nope], dim=-1)
|
||||||
|
|
||||||
|
k_proj = self.wk(x)[0] # [b, s, 7168] @ [7168, 128] = [b, s, 128]
|
||||||
|
k = self.k_norm(k_proj)
|
||||||
|
k_pe, k_nope = torch.split(
|
||||||
|
k,
|
||||||
|
[self.rope_head_dim, self.head_dim - self.rope_head_dim],
|
||||||
|
dim=-1,
|
||||||
|
) # [bs, 64 + 64]
|
||||||
|
|
||||||
|
k_pe = k_pe.view(-1, 1, 1, self.rope_head_dim)
|
||||||
|
k_pe = torch_npu.npu_interleave_rope(k_pe, cos, sin).view(
|
||||||
|
bs, 1, self.rope_head_dim
|
||||||
|
) # [bs, 1, d]
|
||||||
|
k = torch.cat([k_pe, k_nope.unsqueeze(1)], dim=-1) # [bs, 1, 128]
|
||||||
|
|
||||||
|
if is_prefill and enable_index_cp:
|
||||||
|
k, local_k = (
|
||||||
|
torch.empty(
|
||||||
|
(k.shape[0] * attention_tp_size, k.shape[1], k.shape[2]),
|
||||||
|
dtype=k.dtype,
|
||||||
|
device=k.device,
|
||||||
|
),
|
||||||
|
k,
|
||||||
|
)
|
||||||
|
get_attention_tp_group().all_gather_into_tensor(k, local_k)
|
||||||
|
|
||||||
|
forward_batch.token_to_kv_pool.set_index_k_buffer(layer_id, slot_mapping, k)
|
||||||
|
|
||||||
|
indexer_input = {}
|
||||||
|
if is_prefill:
|
||||||
|
actual_seq_lengths_kv = forward_batch.seq_lens.to(device=q.device)
|
||||||
|
actual_seq_lengths_q = forward_batch.seq_lens.cumsum(dim=0).to(
|
||||||
|
device=q.device
|
||||||
|
)
|
||||||
|
if enable_index_cp:
|
||||||
|
actual_seq_lengths_q -= bs * attention_tp_rank
|
||||||
|
actual_seq_lengths_q = torch.max(
|
||||||
|
actual_seq_lengths_q,
|
||||||
|
torch.zeros_like(actual_seq_lengths_q).to(
|
||||||
|
device=actual_seq_lengths_q.device
|
||||||
|
),
|
||||||
|
)
|
||||||
|
actual_seq_lengths_q = torch.min(
|
||||||
|
actual_seq_lengths_q,
|
||||||
|
torch.full(actual_seq_lengths_q.shape, bs).to(
|
||||||
|
device=actual_seq_lengths_q.device
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
if forward_batch.attn_backend.forward_metadata.actual_seq_lengths_q is None:
|
||||||
|
actual_seq_lengths_q = torch.tensor(
|
||||||
|
[1 + i * 1 for i in range(bs)], dtype=torch.int32, device=k.device
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
actual_seq_lengths_q = (
|
||||||
|
forward_batch.attn_backend.forward_metadata.actual_seq_lengths_q
|
||||||
|
)
|
||||||
|
|
||||||
|
past_key_states = forward_batch.token_to_kv_pool.get_index_k_buffer(layer_id)
|
||||||
|
|
||||||
|
x = x.view(-1, self.hidden_size)
|
||||||
|
weights = self.weights_proj(x)[0]
|
||||||
|
block_table = (
|
||||||
|
block_table[: actual_seq_lengths_q.size()[0]] if is_prefill else block_table
|
||||||
|
)
|
||||||
|
|
||||||
|
topk_indices = torch.ops.custom.npu_lightning_indexer(
|
||||||
|
query=q.view(-1, self.n_heads, self.head_dim),
|
||||||
|
key=past_key_states,
|
||||||
|
weights=weights,
|
||||||
|
actual_seq_lengths_query=actual_seq_lengths_q.to(torch.int32),
|
||||||
|
actual_seq_lengths_key=actual_seq_lengths_kv.to(k.device).to(torch.int32),
|
||||||
|
block_table=block_table,
|
||||||
|
layout_query="TND",
|
||||||
|
layout_key="PA_BSND",
|
||||||
|
sparse_count=self.index_topk,
|
||||||
|
sparse_mode=3,
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_prefill and enable_index_cp:
|
||||||
|
topk_indices, local_topk_indices = (
|
||||||
|
torch.empty(
|
||||||
|
(
|
||||||
|
topk_indices.shape[0] * attention_tp_size,
|
||||||
|
topk_indices.shape[1],
|
||||||
|
topk_indices.shape[2],
|
||||||
|
),
|
||||||
|
dtype=topk_indices.dtype,
|
||||||
|
device=topk_indices.device,
|
||||||
|
),
|
||||||
|
topk_indices,
|
||||||
|
)
|
||||||
|
get_attention_tp_group().all_gather_into_tensor(
|
||||||
|
topk_indices, local_topk_indices
|
||||||
|
)
|
||||||
|
|
||||||
|
return topk_indices
|
||||||
255
python/sglang/srt/layers/attention/nsa/quant_k_cache.py
Normal file
255
python/sglang/srt/layers/attention/nsa/quant_k_cache.py
Normal file
@@ -0,0 +1,255 @@
|
|||||||
|
import torch
|
||||||
|
import triton
|
||||||
|
import triton.language as tl
|
||||||
|
|
||||||
|
from sglang.srt.layers.attention.nsa.utils import NSA_QUANT_K_CACHE_FAST
|
||||||
|
|
||||||
|
|
||||||
|
def quantize_k_cache(cache_k):
|
||||||
|
# TODO upstream can skip concat([k_nope, k_pe]) since we split them here
|
||||||
|
if NSA_QUANT_K_CACHE_FAST:
|
||||||
|
return _quantize_k_cache_fast_wrapped(cache_k)
|
||||||
|
else:
|
||||||
|
return _quantize_k_cache_slow(cache_k)
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from original
|
||||||
|
def _quantize_k_cache_slow(
|
||||||
|
input_k_cache: torch.Tensor, # (num_blocks, block_size, h_k, d)
|
||||||
|
dv: int = 512,
|
||||||
|
tile_size: int = 128,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Quantize the k-cache
|
||||||
|
Return a tensor with shape (num_blocks, block_size, h_k, dv + 4(dv/tile_size) + t(d-dv)) of dtype uint8_t, where t = input_k_cache.element_size()
|
||||||
|
For more detail about the layout of K/V, please refer to comments in flash_mla_interface.py or README.md
|
||||||
|
"""
|
||||||
|
assert dv % tile_size == 0
|
||||||
|
num_tiles = dv // tile_size
|
||||||
|
num_blocks, block_size, h_k, d = input_k_cache.shape
|
||||||
|
assert h_k == 1
|
||||||
|
input_k_cache = input_k_cache.squeeze(2) # [num_blocks, block_size, d]
|
||||||
|
input_elem_size = input_k_cache.element_size()
|
||||||
|
|
||||||
|
result = torch.empty(
|
||||||
|
(num_blocks, block_size, dv + num_tiles * 4 + input_elem_size * (d - dv)),
|
||||||
|
dtype=torch.float8_e4m3fn,
|
||||||
|
device=input_k_cache.device,
|
||||||
|
)
|
||||||
|
result_k_nope_part = result[..., :dv]
|
||||||
|
result_k_scale_factor = result[..., dv : dv + num_tiles * 4].view(torch.float32)
|
||||||
|
result_k_rope_part = result[..., dv + num_tiles * 4 :].view(input_k_cache.dtype)
|
||||||
|
result_k_rope_part[:] = input_k_cache[..., dv:]
|
||||||
|
|
||||||
|
for tile_idx in range(0, num_tiles):
|
||||||
|
cur_scale_factors_inv = (
|
||||||
|
torch.abs(
|
||||||
|
input_k_cache[..., tile_idx * tile_size : (tile_idx + 1) * tile_size]
|
||||||
|
)
|
||||||
|
.max(dim=-1)
|
||||||
|
.values
|
||||||
|
/ 448.0
|
||||||
|
) # [num_blocks, block_size]
|
||||||
|
result_k_scale_factor[:, :, tile_idx] = cur_scale_factors_inv
|
||||||
|
|
||||||
|
cur_scale_factors_inv.unsqueeze_(-1) # [num_blocks, block_size, 1]
|
||||||
|
cur_quantized_nope = (
|
||||||
|
input_k_cache[
|
||||||
|
..., tile_idx * tile_size : (tile_idx + 1) * tile_size
|
||||||
|
].float()
|
||||||
|
/ cur_scale_factors_inv.float()
|
||||||
|
).to(torch.float8_e4m3fn)
|
||||||
|
result_k_nope_part[..., tile_idx * tile_size : (tile_idx + 1) * tile_size] = (
|
||||||
|
cur_quantized_nope
|
||||||
|
)
|
||||||
|
|
||||||
|
result = result.view(num_blocks, block_size, 1, -1)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def _quantize_k_cache_fast_wrapped(
|
||||||
|
input_k_cache: torch.Tensor,
|
||||||
|
dv: int = 512,
|
||||||
|
tile_size: int = 128,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
# TODO the final API may be 2D instead of 4D, thus we convert them here
|
||||||
|
num_blocks, block_size, _, dim_nope_and_rope = input_k_cache.shape
|
||||||
|
assert dv == 512
|
||||||
|
assert dim_nope_and_rope == 512 + 64
|
||||||
|
assert tile_size == 128
|
||||||
|
input_k_cache = input_k_cache.view((-1, dim_nope_and_rope))
|
||||||
|
|
||||||
|
# TODO deliberately split into two tensors, then upstream can provide the two tensors instead of concat into one
|
||||||
|
k_nope = input_k_cache[:, :dv]
|
||||||
|
k_rope = input_k_cache[:, dv:]
|
||||||
|
|
||||||
|
output = _quantize_k_cache_fast(k_nope=k_nope, k_rope=k_rope)
|
||||||
|
|
||||||
|
return output.view(num_blocks, block_size, 1, -1)
|
||||||
|
|
||||||
|
|
||||||
|
def _quantize_k_cache_fast(k_nope, k_rope, group_size: int = 128):
|
||||||
|
"""
|
||||||
|
:param k_nope: (num_tokens, dim_nope 512)
|
||||||
|
:param k_rope: (num_tokens, dim_rope 64)
|
||||||
|
"""
|
||||||
|
|
||||||
|
assert k_nope.dtype == torch.bfloat16
|
||||||
|
assert k_rope.dtype == torch.bfloat16
|
||||||
|
|
||||||
|
num_tokens, dim_nope = k_nope.shape
|
||||||
|
num_tokens_, dim_rope = k_rope.shape
|
||||||
|
assert num_tokens == num_tokens_
|
||||||
|
assert dim_nope == 512
|
||||||
|
assert dim_rope == 64
|
||||||
|
assert k_nope.dtype == k_rope.dtype
|
||||||
|
num_tiles = dim_nope // group_size
|
||||||
|
|
||||||
|
assert k_nope.stride(1) == 1
|
||||||
|
assert k_rope.stride(1) == 1
|
||||||
|
|
||||||
|
output = torch.empty(
|
||||||
|
(num_tokens, dim_nope + num_tiles * 4 + k_rope.element_size() * dim_rope),
|
||||||
|
dtype=torch.float8_e4m3fn,
|
||||||
|
device=k_nope.device,
|
||||||
|
)
|
||||||
|
output_nope_q = output[..., :dim_nope]
|
||||||
|
output_nope_s = output[..., dim_nope : dim_nope + num_tiles * 4].view(torch.float32)
|
||||||
|
output_rope = output[..., dim_nope + num_tiles * 4 :].view(torch.bfloat16)
|
||||||
|
|
||||||
|
num_blocks_per_token = triton.cdiv(dim_nope + dim_rope, group_size)
|
||||||
|
assert num_blocks_per_token == 5
|
||||||
|
|
||||||
|
assert dim_nope % group_size == 0
|
||||||
|
NUM_NOPE_BLOCKS = dim_nope // group_size
|
||||||
|
|
||||||
|
_quantize_k_cache_fast_kernel[(num_tokens, num_blocks_per_token)](
|
||||||
|
output_nope_q,
|
||||||
|
output_nope_s,
|
||||||
|
output_rope,
|
||||||
|
k_nope,
|
||||||
|
k_rope,
|
||||||
|
output_nope_q.stride(0),
|
||||||
|
output_nope_s.stride(0),
|
||||||
|
output_rope.stride(0),
|
||||||
|
k_nope.stride(0),
|
||||||
|
k_rope.stride(0),
|
||||||
|
NUM_NOPE_BLOCKS=NUM_NOPE_BLOCKS,
|
||||||
|
GROUP_SIZE=group_size,
|
||||||
|
DIM_NOPE=dim_nope,
|
||||||
|
DIM_ROPE=dim_rope,
|
||||||
|
FP8_MIN=torch.finfo(torch.float8_e4m3fn).min,
|
||||||
|
FP8_MAX=torch.finfo(torch.float8_e4m3fn).max,
|
||||||
|
)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def _quantize_k_cache_fast_kernel(
|
||||||
|
output_nope_q_ptr,
|
||||||
|
output_nope_s_ptr,
|
||||||
|
output_rope_ptr,
|
||||||
|
k_nope_ptr,
|
||||||
|
k_rope_ptr,
|
||||||
|
output_nope_q_stride_0: int,
|
||||||
|
output_nope_s_stride_0: int,
|
||||||
|
output_rope_stride_0: int,
|
||||||
|
k_nope_stride_0: int,
|
||||||
|
k_rope_stride_0: int,
|
||||||
|
NUM_NOPE_BLOCKS: tl.constexpr,
|
||||||
|
GROUP_SIZE: tl.constexpr,
|
||||||
|
DIM_NOPE: tl.constexpr,
|
||||||
|
DIM_ROPE: tl.constexpr,
|
||||||
|
FP8_MIN: tl.constexpr,
|
||||||
|
FP8_MAX: tl.constexpr,
|
||||||
|
):
|
||||||
|
token_id = tl.program_id(0)
|
||||||
|
raw_block_id = tl.program_id(1)
|
||||||
|
|
||||||
|
if raw_block_id < NUM_NOPE_BLOCKS:
|
||||||
|
# a. quant nope
|
||||||
|
effective_block_id = raw_block_id
|
||||||
|
|
||||||
|
offs = effective_block_id * GROUP_SIZE + tl.arange(0, GROUP_SIZE)
|
||||||
|
mask = offs < DIM_NOPE
|
||||||
|
ptr = k_nope_ptr + token_id * k_nope_stride_0 + offs
|
||||||
|
|
||||||
|
y = tl.load(ptr, mask=mask, other=0.0).to(tl.float32)
|
||||||
|
|
||||||
|
# the ref impl do not have a `tl.maximum(... eps)`, so we remove it here
|
||||||
|
y_s = tl.max(tl.abs(y)) / FP8_MAX
|
||||||
|
y_s_inv = 1.0 / y_s
|
||||||
|
y_q = tl.clamp(y * y_s_inv, FP8_MIN, FP8_MAX).to(
|
||||||
|
output_nope_q_ptr.dtype.element_ty
|
||||||
|
)
|
||||||
|
|
||||||
|
dst_q_ptr = output_nope_q_ptr + token_id * output_nope_q_stride_0 + offs
|
||||||
|
dst_s_ptr = (
|
||||||
|
output_nope_s_ptr + token_id * output_nope_s_stride_0 + effective_block_id
|
||||||
|
)
|
||||||
|
|
||||||
|
tl.store(dst_q_ptr, y_q, mask=mask)
|
||||||
|
tl.store(dst_s_ptr, y_s)
|
||||||
|
else:
|
||||||
|
# b. copy rope
|
||||||
|
effective_block_id = raw_block_id - NUM_NOPE_BLOCKS
|
||||||
|
|
||||||
|
offs = effective_block_id * GROUP_SIZE + tl.arange(0, GROUP_SIZE)
|
||||||
|
mask = offs < DIM_ROPE
|
||||||
|
|
||||||
|
src_ptr = k_rope_ptr + token_id * k_rope_stride_0 + offs
|
||||||
|
dst_ptr = output_rope_ptr + token_id * output_rope_stride_0 + offs
|
||||||
|
|
||||||
|
data = tl.load(src_ptr, mask=mask)
|
||||||
|
tl.store(dst_ptr, data, mask=mask)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
for num_blocks, block_size in [
|
||||||
|
(1, 1),
|
||||||
|
(10, 64),
|
||||||
|
]:
|
||||||
|
dim_nope_and_rope = 512 + 64
|
||||||
|
|
||||||
|
input_k_cache = torch.randn(
|
||||||
|
(num_blocks, block_size, 1, dim_nope_and_rope),
|
||||||
|
dtype=torch.bfloat16,
|
||||||
|
device="cuda",
|
||||||
|
)
|
||||||
|
# temp debug
|
||||||
|
# input_k_cache = (576 - torch.arange(num_blocks * block_size * 1 * dim_nope_and_rope, device="cuda")).to(torch.bfloat16).reshape(num_blocks, block_size, 1, dim_nope_and_rope)
|
||||||
|
|
||||||
|
ref_quant = _quantize_k_cache_slow(input_k_cache)
|
||||||
|
actual_quant = _quantize_k_cache_fast_wrapped(input_k_cache)
|
||||||
|
# print(f"{input_k_cache=}")
|
||||||
|
# print(f"{ref_quant=}")
|
||||||
|
# print(f"{actual_quant=}")
|
||||||
|
# print(f"{ref_quant == actual_quant=}")
|
||||||
|
# print(f"{actual_quant.to(torch.float32) - ref_quant.to(torch.float32)=}")
|
||||||
|
# print(f"{ref_quant.view(torch.bfloat16)=}")
|
||||||
|
# print(f"{actual_quant.view(torch.bfloat16)=}")
|
||||||
|
# assert torch.all(ref_quant == actual_quant)
|
||||||
|
|
||||||
|
import dequant_k_cache
|
||||||
|
|
||||||
|
ref_ref_dequant = dequant_k_cache._dequantize_k_cache_slow(ref_quant)
|
||||||
|
ref_actual_dequant = dequant_k_cache._dequantize_k_cache_fast_wrapped(ref_quant)
|
||||||
|
actual_actual_dequant = dequant_k_cache._dequantize_k_cache_fast_wrapped(
|
||||||
|
actual_quant
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"{ref_ref_dequant=}")
|
||||||
|
print(f"{actual_actual_dequant=}")
|
||||||
|
print(f"{actual_actual_dequant - ref_ref_dequant=}")
|
||||||
|
print(f"{torch.mean(ref_ref_dequant - actual_actual_dequant)=}")
|
||||||
|
|
||||||
|
# TODO too different?
|
||||||
|
torch.testing.assert_close(
|
||||||
|
ref_ref_dequant, ref_actual_dequant, atol=0.2, rtol=0.2
|
||||||
|
)
|
||||||
|
torch.testing.assert_close(
|
||||||
|
ref_ref_dequant, actual_actual_dequant, atol=0.2, rtol=0.2
|
||||||
|
)
|
||||||
|
|
||||||
|
print("Passed")
|
||||||
813
python/sglang/srt/layers/attention/nsa/tilelang_kernel.py
Normal file
813
python/sglang/srt/layers/attention/nsa/tilelang_kernel.py
Normal file
@@ -0,0 +1,813 @@
|
|||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
# import tilelang
|
||||||
|
# import tilelang.language as T
|
||||||
|
import torch
|
||||||
|
|
||||||
|
# tilelang.set_log_level("WARNING")
|
||||||
|
|
||||||
|
# pass_configs = {
|
||||||
|
# tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
|
||||||
|
# tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
|
||||||
|
# tilelang.PassConfigKey.TL_DISABLE_FAST_MATH: True,
|
||||||
|
# }
|
||||||
|
|
||||||
|
BF16 = "bfloat16"
|
||||||
|
FP8 = "float8_e4m3"
|
||||||
|
FP32 = "float32"
|
||||||
|
|
||||||
|
'''
|
||||||
|
def fast_log2_ceil(x):
|
||||||
|
bits_x = T.reinterpret("uint32", x)
|
||||||
|
exp_x = (bits_x >> 23) & 0xFF
|
||||||
|
man_bits = bits_x & ((1 << 23) - 1)
|
||||||
|
return T.Cast("int32", exp_x - 127 + T.if_then_else(man_bits != 0, 1, 0))
|
||||||
|
|
||||||
|
|
||||||
|
def fast_pow2(x):
|
||||||
|
bits_x = (x + 127) << 23
|
||||||
|
return T.reinterpret("float32", bits_x)
|
||||||
|
|
||||||
|
|
||||||
|
def fast_round_scale(amax, fp8_max_inv):
|
||||||
|
return fast_pow2(fast_log2_ceil(amax * fp8_max_inv))
|
||||||
|
|
||||||
|
@tilelang.jit(pass_configs=pass_configs)
|
||||||
|
def act_quant_kernel(
|
||||||
|
N, in_dtype=BF16, out_dtype=FP8, scale_dtype=FP32, round_scale=False
|
||||||
|
):
|
||||||
|
M = T.symbolic("M")
|
||||||
|
fp8_min = -448.0
|
||||||
|
fp8_max = 448.0
|
||||||
|
fp8_max_inv = 1 / fp8_max
|
||||||
|
num_stages = 0 if round_scale else 2
|
||||||
|
blk_m = 32
|
||||||
|
group_size = 128
|
||||||
|
|
||||||
|
@T.prim_func
|
||||||
|
def act_quant_kernel_(
|
||||||
|
X: T.Tensor[(M, N), in_dtype],
|
||||||
|
Y: T.Tensor[(M, N), out_dtype],
|
||||||
|
S: T.Tensor[(M, T.ceildiv(N, group_size)), scale_dtype],
|
||||||
|
):
|
||||||
|
with T.Kernel(T.ceildiv(M, blk_m), T.ceildiv(N, group_size), threads=128) as (
|
||||||
|
pid_m,
|
||||||
|
pid_n,
|
||||||
|
):
|
||||||
|
x_shared = T.alloc_shared((blk_m, group_size), in_dtype)
|
||||||
|
x_local = T.alloc_fragment((blk_m, group_size), in_dtype)
|
||||||
|
amax_local = T.alloc_fragment((blk_m,), scale_dtype)
|
||||||
|
s_local = T.alloc_fragment((blk_m,), scale_dtype)
|
||||||
|
y_local = T.alloc_fragment((blk_m, group_size), out_dtype)
|
||||||
|
y_shared = T.alloc_shared((blk_m, group_size), out_dtype)
|
||||||
|
|
||||||
|
for _ in T.Pipelined(1, num_stages=num_stages):
|
||||||
|
T.copy(X[pid_m * blk_m, pid_n * group_size], x_shared)
|
||||||
|
T.copy(x_shared, x_local)
|
||||||
|
T.reduce_absmax(x_local, amax_local, dim=1)
|
||||||
|
for i in T.Parallel(blk_m):
|
||||||
|
amax_local[i] = T.max(amax_local[i], 1e-4)
|
||||||
|
if round_scale:
|
||||||
|
s_local[i] = fast_round_scale(amax_local[i], fp8_max_inv)
|
||||||
|
else:
|
||||||
|
s_local[i] = amax_local[i] * fp8_max_inv
|
||||||
|
for i, j in T.Parallel(blk_m, group_size):
|
||||||
|
y_local[i, j] = T.clamp(
|
||||||
|
x_local[i, j] / s_local[i], fp8_min, fp8_max
|
||||||
|
)
|
||||||
|
for i in T.Parallel(blk_m):
|
||||||
|
S[pid_m * blk_m + i, pid_n] = s_local[i]
|
||||||
|
T.copy(y_local, y_shared)
|
||||||
|
T.copy(y_shared, Y[pid_m * blk_m, pid_n * group_size])
|
||||||
|
|
||||||
|
return act_quant_kernel_
|
||||||
|
|
||||||
|
def act_quant(
|
||||||
|
x: torch.Tensor, block_size: int = 128, scale_fmt: Optional[str] = None
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Quantizes the input tensor `x` using block-wise quantization.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (torch.Tensor): The input tensor to be quantized. Must be contiguous and its last dimension size must be divisible by `block_size`.
|
||||||
|
block_size (int, optional): The size of the blocks to be used for quantization. Default is 128.
|
||||||
|
scale_fmt (Optional[str], optional): The format of the scale. Default is None.
|
||||||
|
Returns:
|
||||||
|
Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
|
||||||
|
- The quantized tensor with dtype `torch.float8_e4m3fn`.
|
||||||
|
- A tensor of scaling factors with dtype `torch.float32`.
|
||||||
|
"""
|
||||||
|
assert x.is_contiguous(), "Input tensor must be contiguous"
|
||||||
|
assert (
|
||||||
|
x.size(-1) % block_size == 0
|
||||||
|
), f"Last dimension size must be divisible by block_size (block_size={block_size})"
|
||||||
|
N = x.size(-1)
|
||||||
|
y = torch.empty_like(x, dtype=torch.float8_e4m3fn)
|
||||||
|
s = x.new_empty(*x.size()[:-1], N // block_size, dtype=torch.float32)
|
||||||
|
kernel = act_quant_kernel(N, round_scale=scale_fmt is not None)
|
||||||
|
kernel(x.view(-1, N), y.view(-1, N), s.view(-1, N // block_size))
|
||||||
|
return y, s
|
||||||
|
|
||||||
|
|
||||||
|
@tilelang.jit(out_idx=[4], pass_configs=pass_configs)
|
||||||
|
def fp8_index_kernel(h: int, d: int):
|
||||||
|
b = T.symbolic("b")
|
||||||
|
m = T.symbolic("m")
|
||||||
|
n = T.symbolic("n")
|
||||||
|
|
||||||
|
blk_n1 = 512
|
||||||
|
blk_n2 = 128
|
||||||
|
|
||||||
|
@T.prim_func
|
||||||
|
def fp8_index_kernel_(
|
||||||
|
q: T.Tensor[(b, m, h, d), FP8],
|
||||||
|
q_s: T.Tensor[(b, m, h), FP32],
|
||||||
|
k: T.Tensor[(b, n, d), FP8],
|
||||||
|
k_s: T.Tensor[(b, n), FP32],
|
||||||
|
o: T.Tensor[(b, m, n), FP32],
|
||||||
|
) -> None:
|
||||||
|
with T.Kernel(b, m, T.ceildiv(n, blk_n1)) as (i_b, i_m, i1_n):
|
||||||
|
q_smem = T.alloc_shared((h, d), FP8)
|
||||||
|
T.copy(q[i_b, i_m, 0, 0], q_smem)
|
||||||
|
|
||||||
|
q_s_frag = T.alloc_fragment(h, FP32)
|
||||||
|
T.copy(q_s[i_b, i_m, 0], q_s_frag)
|
||||||
|
|
||||||
|
for i2_n in T.Pipelined(blk_n1 // blk_n2, num_stages=2):
|
||||||
|
k_smem = T.alloc_shared((blk_n2, d), FP8)
|
||||||
|
T.copy(k[i_b, i1_n * blk_n1 + i2_n * blk_n2, 0], k_smem)
|
||||||
|
|
||||||
|
k_s_frag = T.alloc_fragment(blk_n2, FP32)
|
||||||
|
T.copy(k_s[i_b, i1_n * blk_n1 + i2_n * blk_n2], k_s_frag)
|
||||||
|
|
||||||
|
logits = T.alloc_fragment((blk_n2, h), FP32)
|
||||||
|
T.gemm(
|
||||||
|
k_smem,
|
||||||
|
q_smem,
|
||||||
|
logits,
|
||||||
|
transpose_A=False,
|
||||||
|
transpose_B=True,
|
||||||
|
clear_accum=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
for i_h, i3_n in T.Parallel(h, blk_n2):
|
||||||
|
logits[i3_n, i_h] = T.max(logits[i3_n, i_h], 0) * q_s_frag[i_h]
|
||||||
|
|
||||||
|
logits_sum = T.alloc_fragment(blk_n2, FP32)
|
||||||
|
T.reduce_sum(logits, logits_sum, dim=1)
|
||||||
|
|
||||||
|
for i3_n in T.Parallel(blk_n2):
|
||||||
|
logits_sum[i3_n] *= k_s_frag[i3_n]
|
||||||
|
|
||||||
|
T.copy(logits_sum, o[i_b, i_m, i1_n * blk_n1 + i2_n * blk_n2])
|
||||||
|
|
||||||
|
return fp8_index_kernel_
|
||||||
|
|
||||||
|
|
||||||
|
def fp8_index(
|
||||||
|
q: torch.Tensor,
|
||||||
|
q_s: torch.Tensor,
|
||||||
|
k: torch.Tensor,
|
||||||
|
k_s: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Perform index score using FP8 precision.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
q (torch.Tensor): The Q tensor, must be contiguous.
|
||||||
|
q_s (torch.Tensor): The scaling factor for Q (float), must be contiguous.
|
||||||
|
k (torch.Tensor): The K tensor, must be contiguous.
|
||||||
|
k_s (torch.Tensor): The scaling factor for K (e8m0 here), must be contiguous.
|
||||||
|
|
||||||
|
fp8 q @ fp8 k -> fp32 logits
|
||||||
|
relu(fp32 logits) * q_s (weights) -> fp32 logits
|
||||||
|
fp32 logits -> fp32 logits_sum
|
||||||
|
fp32 logits_sum * k_s (e8m0) -> fp32 index_score
|
||||||
|
"""
|
||||||
|
return fp8_index_kernel(q.shape[2], q.shape[3])(q, q_s, k, k_s)
|
||||||
|
|
||||||
|
|
||||||
|
@tilelang.jit(
|
||||||
|
out_idx=[-1],
|
||||||
|
pass_configs={
|
||||||
|
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
|
||||||
|
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
def sparse_attention_fwd_kernel_v1(
|
||||||
|
num_heads,
|
||||||
|
dim,
|
||||||
|
tail_dim,
|
||||||
|
topk,
|
||||||
|
*,
|
||||||
|
kv_group=1,
|
||||||
|
sm_scale=None,
|
||||||
|
is_causal=True,
|
||||||
|
block_I=64,
|
||||||
|
num_stages=2,
|
||||||
|
threads=256,
|
||||||
|
):
|
||||||
|
assert dim == tilelang.math.next_power_of_2(
|
||||||
|
dim
|
||||||
|
), f"haven't check padding correctness yet, dim={dim}"
|
||||||
|
assert tail_dim == tilelang.math.next_power_of_2(
|
||||||
|
tail_dim
|
||||||
|
), f"haven't check padding correctness yet, dim={tail_dim}"
|
||||||
|
assert is_causal == True, "non-casual is not supported"
|
||||||
|
assert (
|
||||||
|
topk % block_I == 0
|
||||||
|
), "otherwise will load some index=0 thus causing wrong kv to be loaded"
|
||||||
|
if sm_scale is None:
|
||||||
|
sm_scale = (1.0 / (dim + tail_dim)) ** 0.5 * 1.44269504 # log2(e)
|
||||||
|
else:
|
||||||
|
sm_scale = sm_scale * 1.44269504 # log2(e)
|
||||||
|
|
||||||
|
batch = T.symbolic("batch")
|
||||||
|
seq_len = T.symbolic("seq_len")
|
||||||
|
seq_len_kv = T.symbolic("seq_len_kv")
|
||||||
|
|
||||||
|
head_kv = num_heads // kv_group
|
||||||
|
q_shape = [batch, seq_len, num_heads, dim + tail_dim]
|
||||||
|
kv_shape = [batch, seq_len_kv, kv_group, dim + tail_dim]
|
||||||
|
o_shape = [batch, seq_len, num_heads, dim]
|
||||||
|
indices_shape = [batch, seq_len, kv_group, topk]
|
||||||
|
indices_dtype = "int32"
|
||||||
|
dtype = "bfloat16"
|
||||||
|
accum_dtype = "float"
|
||||||
|
|
||||||
|
H = head_kv
|
||||||
|
padded_H = max(tilelang.math.next_power_of_2(head_kv), 16)
|
||||||
|
if padded_H != H:
|
||||||
|
assert kv_group == 1
|
||||||
|
BI = block_I
|
||||||
|
NI = tilelang.cdiv(topk, block_I)
|
||||||
|
D = dim
|
||||||
|
D_tail = tail_dim
|
||||||
|
|
||||||
|
if head_kv > 64:
|
||||||
|
assert head_kv % 64 == 0, "head_kv should be a multiple of 64"
|
||||||
|
REPLICATE_H = head_kv // 64
|
||||||
|
else:
|
||||||
|
REPLICATE_H = 1
|
||||||
|
|
||||||
|
H_per_block = padded_H if REPLICATE_H == 1 else 64
|
||||||
|
|
||||||
|
@T.prim_func
|
||||||
|
def main(
|
||||||
|
Q: T.Tensor(q_shape, dtype), # type: ignore
|
||||||
|
KV: T.Tensor(kv_shape, dtype), # type: ignore
|
||||||
|
Indices: T.Tensor(indices_shape, indices_dtype), # type: ignore
|
||||||
|
Output: T.Tensor(o_shape, dtype), # type: ignore
|
||||||
|
):
|
||||||
|
with T.Kernel(seq_len * REPLICATE_H, batch, kv_group, threads=threads) as (
|
||||||
|
bx,
|
||||||
|
by,
|
||||||
|
bz,
|
||||||
|
):
|
||||||
|
Q_shared = T.alloc_shared([H_per_block, D], dtype)
|
||||||
|
Q_tail_shared = T.alloc_shared([H_per_block, D_tail], dtype)
|
||||||
|
KV_shared = T.alloc_shared([BI, D], dtype)
|
||||||
|
K_tail_shared = T.alloc_shared([BI, D_tail], dtype)
|
||||||
|
O_shared = T.alloc_shared([H_per_block, D], dtype)
|
||||||
|
mask = T.alloc_fragment([BI], "bool")
|
||||||
|
|
||||||
|
acc_o = T.alloc_fragment([H_per_block, D], accum_dtype)
|
||||||
|
acc_s = T.alloc_fragment([H_per_block, BI], accum_dtype)
|
||||||
|
S_shared = T.alloc_shared([H_per_block, BI], dtype)
|
||||||
|
sumexp = T.alloc_fragment([H_per_block], accum_dtype)
|
||||||
|
sumexp_i = T.alloc_fragment([H_per_block], accum_dtype)
|
||||||
|
alpha = T.alloc_fragment([H_per_block], accum_dtype)
|
||||||
|
m_i = T.alloc_fragment([H_per_block], accum_dtype)
|
||||||
|
m_i_prev = T.alloc_fragment([H_per_block], accum_dtype)
|
||||||
|
|
||||||
|
T.fill(acc_o, 0)
|
||||||
|
T.fill(sumexp, 0)
|
||||||
|
T.fill(m_i, -(2**30)) # avoid -inf - inf to cause nan
|
||||||
|
|
||||||
|
b_i, g_i = by, bz
|
||||||
|
s_i = bx if REPLICATE_H == 1 else (bx // REPLICATE_H)
|
||||||
|
q_i = s_i
|
||||||
|
max_kv_i = q_i
|
||||||
|
|
||||||
|
H0 = g_i * padded_H + (0 if REPLICATE_H == 1 else (bx % REPLICATE_H) * 64)
|
||||||
|
H1 = H0 + H_per_block
|
||||||
|
|
||||||
|
T.copy(Q[b_i, s_i, H0:H1, :D], Q_shared)
|
||||||
|
T.copy(Q[b_i, s_i, H0:H1, D:], Q_tail_shared)
|
||||||
|
|
||||||
|
for i_i in T.Pipelined(NI, num_stages=num_stages):
|
||||||
|
|
||||||
|
for bi_i in T.Parallel(BI):
|
||||||
|
mask[bi_i] = Indices[b_i, s_i, g_i, i_i * BI + bi_i] >= 0
|
||||||
|
|
||||||
|
for bi_i, d_i in T.Parallel(BI, D):
|
||||||
|
KV_shared[bi_i, d_i] = KV[
|
||||||
|
b_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i, d_i
|
||||||
|
]
|
||||||
|
for bi_i, d_i in T.Parallel(BI, D_tail):
|
||||||
|
K_tail_shared[bi_i, d_i] = KV[
|
||||||
|
b_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i, D + d_i
|
||||||
|
]
|
||||||
|
|
||||||
|
for h_i, bi_i in T.Parallel(H_per_block, BI):
|
||||||
|
acc_s[h_i, bi_i] = T.if_then_else(
|
||||||
|
mask[bi_i], 0, -T.infinity(acc_s.dtype)
|
||||||
|
)
|
||||||
|
T.gemm(
|
||||||
|
Q_shared,
|
||||||
|
KV_shared,
|
||||||
|
acc_s,
|
||||||
|
transpose_B=True,
|
||||||
|
policy=T.GemmWarpPolicy.FullCol,
|
||||||
|
)
|
||||||
|
T.gemm(
|
||||||
|
Q_tail_shared,
|
||||||
|
K_tail_shared,
|
||||||
|
acc_s,
|
||||||
|
transpose_B=True,
|
||||||
|
policy=T.GemmWarpPolicy.FullCol,
|
||||||
|
)
|
||||||
|
T.copy(m_i, m_i_prev)
|
||||||
|
T.reduce_max(acc_s, m_i, dim=1, clear=False)
|
||||||
|
for h_i in T.Parallel(H_per_block):
|
||||||
|
alpha[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale)
|
||||||
|
for h_i, bi_i in T.Parallel(H_per_block, BI):
|
||||||
|
acc_s[h_i, bi_i] = T.exp2(
|
||||||
|
acc_s[h_i, bi_i] * sm_scale - m_i[h_i] * sm_scale
|
||||||
|
)
|
||||||
|
T.reduce_sum(acc_s, sumexp_i, dim=1) # is this a accumulate operator?
|
||||||
|
for h_i in T.Parallel(H_per_block):
|
||||||
|
sumexp[h_i] = sumexp[h_i] * alpha[h_i] + sumexp_i[h_i]
|
||||||
|
for h_i, d_i in T.Parallel(H_per_block, D):
|
||||||
|
acc_o[h_i, d_i] = acc_o[h_i, d_i] * alpha[h_i]
|
||||||
|
|
||||||
|
T.copy(acc_s, S_shared)
|
||||||
|
T.gemm(S_shared, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol)
|
||||||
|
|
||||||
|
# Rescale
|
||||||
|
for h_i, d_i in T.Parallel(H_per_block, D):
|
||||||
|
acc_o[h_i, d_i] /= sumexp[h_i]
|
||||||
|
for h_i in T.Parallel(H_per_block):
|
||||||
|
sumexp[h_i] = T.log2(sumexp[h_i]) + m_i[h_i] * sm_scale
|
||||||
|
|
||||||
|
T.copy(acc_o, O_shared)
|
||||||
|
T.copy(acc_o, Output[b_i, s_i, H0:H1, :])
|
||||||
|
|
||||||
|
return main
|
||||||
|
|
||||||
|
|
||||||
|
@tilelang.jit(
    out_idx=[-1],
    compile_flags=[
        "-O3",
        "-Wno-deprecated-declarations",
        "-U__CUDA_NO_HALF_OPERATORS__",
        "-U__CUDA_NO_HALF_CONVERSIONS__",
        "-U__CUDA_NO_HALF2_OPERATORS__",
        "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
        "--expt-relaxed-constexpr",
        "--expt-extended-lambda",
        "--ptxas-options=-v,--register-usage-level=10",
        "-DNDEBUG",
    ],
)  # type: ignore
def sparse_attention_fwd_kernel_v2(
    num_heads: int,
    dim: int,
    tail_dim: int,
    topk: int,
    *,
    kv_group: int = 1,
    sm_scale: Optional[float] = None,
    block_I: int = 64,
):
    """Build a warp-specialized tilelang forward kernel for top-k sparse attention.

    The generated prim_func attends each query position only to the `topk`
    KV slots named by `Indices`; negative indices mark invalid/padded slots
    and are masked to -inf before the softmax.

    Thread layout (384 threads per block, three groups of 128):
      * tx <   128: consumer A — runs the online softmax and accumulates the
                    LEFT half (columns 0:D//2) of the output.
      * 128-255:    consumer B — reuses S_shared/alpha_shared produced by A and
                    accumulates the RIGHT half (columns D//2:D).
      * tx >= 256:  producer — gathers the indexed KV rows into double-buffered
                    shared memory via async copies.

    sm_scale is pre-multiplied by log2(e) because the softmax is computed with
    exp2/log2 instead of exp/log.
    """
    assert dim == tilelang.math.next_power_of_2(
        dim
    ), f"haven't check padding correctness yet, dim={dim}"
    assert tail_dim == tilelang.math.next_power_of_2(
        tail_dim
    ), f"haven't check padding correctness yet, dim={tail_dim}"
    assert (
        topk % block_I == 0
    ), "otherwise will load some index=0 thus causing wrong kv to be loaded"
    if sm_scale is None:
        sm_scale = (1.0 / (dim + tail_dim)) ** 0.5 * 1.44269504  # log2(e)
    else:
        sm_scale = sm_scale * 1.44269504  # log2(e)
    # 3 warpgroups of 128 threads: two consumers + one producer.
    threads = 384

    # Symbolic (runtime-bound) problem sizes.
    batch = T.symbolic("batch")
    qo_len = T.symbolic("seq_len")
    num_pages = T.symbolic("num_pages")

    q_shape = [batch, qo_len, num_heads, dim + tail_dim]
    kv_shape = [batch, num_pages, kv_group, dim + tail_dim]
    o_shape = [batch, qo_len, num_heads, dim]
    indices_shape = [batch, qo_len, kv_group, topk]

    indices_dtype = "int32"
    dtype = "bfloat16"
    accum_dtype = "float"

    # Pad head count up to a power of two (min 16) for the tensor-core GEMMs.
    H = num_heads
    padded_H = max(tilelang.math.next_power_of_2(num_heads), 16)
    if padded_H != H:
        assert kv_group == 1
    BI = block_I
    NI = tilelang.cdiv(topk, block_I)
    # The main loop processes two KV tiles (two buffers) per iteration.
    assert NI % 2 == 0, "NI should be a multiple of 2"
    D = dim
    D_tail = tail_dim
    # More than 64 heads: replicate the kernel grid along the head dimension.
    if num_heads > 64:
        assert num_heads % 64 == 0, "head_kv should be a multiple of 64"
        REPLICATE_H = num_heads // 64
    else:
        REPLICATE_H = 1

    H_per_block = padded_H if REPLICATE_H == 1 else 64

    @T.prim_func
    def main(
        Q: T.Tensor(q_shape, dtype),  # type: ignore
        KV: T.Tensor(kv_shape, dtype),  # type: ignore
        Indices: T.Tensor(indices_shape, indices_dtype),  # type: ignore
        Output: T.Tensor(o_shape, dtype),  # type: ignore
    ):
        """
        Q: [b, qo_len, H, D + D_tail] (bfloat16)
        KV: [b, num_pages, kv_group, D + D_tail] (bfloat16)
        Indices: [b, qo_len, kv_group, topk] (int32)
        """

        with T.Kernel(qo_len * REPLICATE_H, batch, 1, threads=threads) as (bx, by, bz):  # type: ignore
            # Q is split into left/right halves so each consumer warpgroup
            # owns one half of the head dimension.
            Q_shared_l = T.alloc_shared([H_per_block, D // 2], dtype)
            Q_shared_r = T.alloc_shared([H_per_block, D // 2], dtype)
            Q_tail_shared = T.alloc_shared([H_per_block, D_tail], dtype)
            # Double-buffered KV tiles (buffer 0 / buffer 1), also split l/r.
            KV_shared_0_l = T.alloc_shared([BI, D // 2], dtype)
            KV_shared_0_r = T.alloc_shared([BI, D // 2], dtype)
            KV_shared_1_l = T.alloc_shared([BI, D // 2], dtype)
            KV_shared_1_r = T.alloc_shared([BI, D // 2], dtype)
            K_tail_shared_0 = T.alloc_shared([BI, D_tail], dtype)
            K_tail_shared_1 = T.alloc_shared([BI, D_tail], dtype)
            # Output staging reuses the Q tiles (Q is consumed by then).
            O_shared_l = Q_shared_l
            O_shared_r = Q_shared_r
            # Per-slot validity flags for each buffer (index >= 0).
            is_kv_valid_0 = T.alloc_shared([BI], "bool", scope="shared")
            is_kv_valid_1 = T.alloc_shared([BI], "bool", scope="shared")

            acc_o_l = T.alloc_fragment([H_per_block, D // 2], accum_dtype)
            acc_o_r = T.alloc_fragment([H_per_block, D // 2], accum_dtype)
            acc_s = T.alloc_fragment([H_per_block, BI], accum_dtype)
            S_shared = T.alloc_shared([H_per_block, BI], dtype)
            # Online-softmax state: running sum of exp, running max, and the
            # rescale factor alpha shared with consumer B.
            sumexp = T.alloc_fragment([H_per_block], accum_dtype)
            sum_exp_shared = T.alloc_shared([H_per_block], accum_dtype)
            sumexp_i = T.alloc_fragment([H_per_block], accum_dtype)
            alpha_shared = T.alloc_shared([H_per_block], accum_dtype, scope="shared")
            alpha_local = T.alloc_fragment([H_per_block], accum_dtype)
            m_i = T.alloc_fragment([H_per_block], accum_dtype)
            m_i_prev = T.alloc_fragment([H_per_block], accum_dtype)
            indices_local = T.alloc_local([1], indices_dtype)
            indices_tmp = T.alloc_local([1], indices_dtype)

            # Barriers coordinating producer/consumer hand-offs per buffer.
            bar_q = T.alloc_barrier(arrive_count=384)
            bar_k_0_ready = T.alloc_barrier(arrive_count=128)
            bar_k_1_ready = T.alloc_barrier(arrive_count=128)
            bar_k_0_free = T.alloc_barrier(arrive_count=256)
            bar_k_1_free = T.alloc_barrier(arrive_count=256)
            bar_sScale_and_sS_ready = T.alloc_barrier(arrive_count=256)
            bar_sScale_and_sS_free = T.alloc_barrier(arrive_count=256)

            # Per-warpgroup internal sync barriers.
            bar_0_128 = T.alloc_barrier(arrive_count=128)
            bar_1_128 = T.alloc_barrier(arrive_count=128)
            bar_2_128 = T.alloc_barrier(arrive_count=128)
            bar_final = T.alloc_barrier(arrive_count=128)

            b_i, g_i = by, bz
            s_i = bx if REPLICATE_H == 1 else bx // REPLICATE_H

            # Head slice handled by this block.
            H0 = g_i * padded_H + (0 if REPLICATE_H == 1 else (bx % REPLICATE_H) * 64)
            H1 = H0 + H_per_block

            tx = T.get_thread_binding()

            # Stage this query's Q into shared memory once, up front.
            T.copy(Q[b_i, s_i, H0:H1, 0 : D // 2], Q_shared_l)
            T.copy(Q[b_i, s_i, H0:H1, D // 2 : D], Q_shared_r)
            T.copy(Q[b_i, s_i, H0:H1, D:], Q_tail_shared)
            T.barrier_arrive(bar_q)

            if tx < 128:
                # ---- Consumer A: softmax + left-half output accumulation ----
                T.set_max_nreg(240, 1)
                T.fill(sumexp, 0)
                T.fill(m_i, -(2**30))  # avoid -inf - inf to cause nan
                T.fill(acc_o_l, 0)
                T.barrier_wait(bar_q, 0)

                for i_i in T.serial(T.ceildiv(NI, 2)):
                    # Buffer 0
                    # with sync_at(bar_0_128, 0):
                    T.barrier_wait(bar_k_0_ready[0], (i_i & 1))
                    T.barrier_arrive(bar_0_128)
                    T.barrier_wait(bar_0_128, 0)

                    # Invalid slots contribute -inf logits (masked out).
                    for h_i, bi_i in T.Parallel(H_per_block, BI):
                        acc_s[h_i, bi_i] = T.if_then_else(
                            is_kv_valid_0[bi_i], 0, -T.infinity(acc_s.dtype)
                        )
                    # S = Q @ K^T over both halves plus the tail dims.
                    T.gemm(
                        Q_shared_l, KV_shared_0_l, acc_s, transpose_B=True, wg_wait=-1
                    )
                    T.gemm(
                        Q_shared_r, KV_shared_0_r, acc_s, transpose_B=True, wg_wait=-1
                    )
                    T.gemm(
                        Q_tail_shared,
                        K_tail_shared_0,
                        acc_s,
                        transpose_B=True,
                        wg_wait=-1,
                    )

                    T.wait_wgmma(0)

                    if i_i != 0:
                        T.barrier_arrive(bar_sScale_and_sS_free)
                        T.barrier_wait(bar_sScale_and_sS_free, ((i_i * 2) & 1) ^ 1)

                    # Online-softmax update (exp2 domain).
                    T.copy(m_i, m_i_prev)
                    T.reduce_max(acc_s, m_i, dim=1, clear=False)
                    for h_i in T.Parallel(H_per_block):
                        alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale)
                    for h_i, bi_i in T.Parallel(H_per_block, BI):
                        acc_s[h_i, bi_i] = T.exp2(
                            acc_s[h_i, bi_i] * sm_scale - m_i[h_i] * sm_scale
                        )
                    T.reduce_sum(
                        acc_s, sumexp_i, dim=1
                    )  # is this a accumulate operator?
                    for h_i in T.Parallel(H_per_block):
                        sumexp[h_i] = sumexp[h_i] * alpha_local[h_i] + sumexp_i[h_i]
                    for h_i, d_i in T.Parallel(H_per_block, D // 2):
                        acc_o_l[h_i, d_i] *= alpha_local[h_i]
                    # Publish alpha for consumer B.
                    T.copy(alpha_local, alpha_shared)

                    T.copy(acc_s, S_shared)
                    T.gemm(S_shared, KV_shared_0_l, acc_o_l)

                    T.barrier_arrive(bar_sScale_and_sS_ready)
                    T.barrier_arrive(bar_k_0_free[0])

                    # Buffer 1 (same sequence as buffer 0)
                    T.barrier_wait(bar_k_1_ready[0], (i_i & 1))
                    T.barrier_arrive(bar_0_128)
                    T.barrier_wait(bar_0_128, 1)

                    for h_i, bi_i in T.Parallel(H_per_block, BI):
                        acc_s[h_i, bi_i] = T.if_then_else(
                            is_kv_valid_1[bi_i], 0, -T.infinity(acc_s.dtype)
                        )
                    T.gemm(
                        Q_shared_l, KV_shared_1_l, acc_s, transpose_B=True, wg_wait=-1
                    )
                    T.gemm(
                        Q_shared_r, KV_shared_1_r, acc_s, transpose_B=True, wg_wait=-1
                    )
                    T.gemm(
                        Q_tail_shared,
                        K_tail_shared_1,
                        acc_s,
                        transpose_B=True,
                        wg_wait=-1,
                    )

                    T.wait_wgmma(0)

                    T.barrier_arrive(bar_sScale_and_sS_free)
                    T.barrier_wait(bar_sScale_and_sS_free, ((i_i * 2 + 1) & 1) ^ 1)

                    T.copy(m_i, m_i_prev)
                    T.reduce_max(acc_s, m_i, dim=1, clear=False)
                    for h_i in T.Parallel(H_per_block):
                        alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale)
                    for h_i, bi_i in T.Parallel(H_per_block, BI):
                        acc_s[h_i, bi_i] = T.exp2(
                            acc_s[h_i, bi_i] * sm_scale - m_i[h_i] * sm_scale
                        )
                    T.reduce_sum(
                        acc_s, sumexp_i, dim=1
                    )  # is this a accumulate operator?
                    for h_i in T.Parallel(H_per_block):
                        sumexp[h_i] = sumexp[h_i] * alpha_local[h_i] + sumexp_i[h_i]
                    for h_i, d_i in T.Parallel(H_per_block, D // 2):
                        acc_o_l[h_i, d_i] *= alpha_local[h_i]
                    T.copy(alpha_local, alpha_shared)

                    T.copy(acc_s, S_shared)
                    T.gemm(S_shared, KV_shared_1_l, acc_o_l)

                    T.barrier_arrive(bar_sScale_and_sS_ready)
                    T.barrier_arrive(bar_k_1_free[0])

                # Rescale
                for h_i in T.Parallel(H_per_block):
                    sum_exp_shared[h_i] = sumexp[h_i]
                T.barrier_arrive(bar_final)
                for h_i, d_i in T.Parallel(H_per_block, D // 2):
                    acc_o_l[h_i, d_i] /= sumexp[h_i]
                for h_i in T.Parallel(H_per_block):
                    sumexp[h_i] = T.log2(sumexp[h_i]) + m_i[h_i] * sm_scale
                T.copy(acc_o_l, O_shared_l)
                T.copy(O_shared_l, Output[b_i, s_i, H0:H1, 0 : D // 2])
            elif tx >= 128 and tx < 256:
                # ---- Consumer B: right-half output accumulation ----
                # T.set_max_nreg(168, 1)
                T.fill(acc_o_r, 0)
                for i_i in T.serial(T.ceildiv(NI, 2)):
                    # Buffer 0
                    T.barrier_arrive(bar_sScale_and_sS_ready)
                    T.barrier_wait(bar_sScale_and_sS_ready, ((i_i * 2) & 1))
                    T.barrier_arrive(bar_1_128)
                    T.barrier_wait(bar_1_128, 0)
                    # Rescale by the alpha consumer A published, then accumulate.
                    for h_i, d_i in T.Parallel(H_per_block, D // 2):
                        acc_o_r[h_i, d_i] *= alpha_shared[h_i]
                    T.gemm(S_shared, KV_shared_0_r, acc_o_r)
                    T.barrier_arrive(bar_k_0_free[0])
                    T.barrier_arrive(bar_sScale_and_sS_free)

                    # Buffer 1
                    T.barrier_arrive(bar_sScale_and_sS_ready)
                    T.barrier_wait(bar_sScale_and_sS_ready, ((i_i * 2 + 1) & 1))
                    T.barrier_arrive(bar_1_128)
                    T.barrier_wait(bar_1_128, 1)
                    for h_i, d_i in T.Parallel(H_per_block, D // 2):
                        acc_o_r[h_i, d_i] *= alpha_shared[h_i]
                    T.gemm(S_shared, KV_shared_1_r, acc_o_r)
                    T.barrier_arrive(bar_k_1_free[0])
                    if i_i != T.ceildiv(NI, 2) - 1:
                        T.barrier_arrive(bar_sScale_and_sS_free)

                # Rescale
                T.barrier_wait(bar_final, 0)
                for h_i, d_i in T.Parallel(H_per_block, D // 2):
                    acc_o_r[h_i, d_i] /= sum_exp_shared[h_i]

                T.copy(acc_o_r, O_shared_r)
                T.copy(O_shared_r, Output[b_i, s_i, H0:H1, D // 2 : D])
            elif tx >= 256:
                # producer
                T.set_max_nreg(80, 0)
                # Last valid index seen; invalid slots re-load this row so the
                # address is always in bounds (the value is masked later).
                indices_local[0] = 0
                for i_i in T.serial(T.ceildiv(NI, 2)):
                    # Buffer 0
                    T.barrier_wait(bar_k_0_free[0], ((i_i & 1) ^ 1))
                    T.barrier_arrive(bar_2_128)
                    T.barrier_wait(bar_2_128, 0)

                    # 128 producer threads: 16 rows per pass * 4 passes = BI=64
                    # rows; each group of 8 threads copies one row, 8 elems per
                    # vectorized store.
                    for r in T.serial(4):
                        indices_tmp[0] = Indices[
                            b_i, s_i, g_i, (i_i * 2) * BI + r * 16 + (tx - 256) // 8
                        ]
                        is_kv_valid_0[r * 16 + (tx - 256) // 8] = indices_tmp[0] >= 0
                        if is_kv_valid_0[r * 16 + (tx - 256) // 8]:
                            indices_local[0] = indices_tmp[0]

                        with T.attr("default", "async_scope", 1):  # type: ignore
                            for u in T.serial(4):
                                for v in T.vectorized(8):
                                    KV_shared_0_l[
                                        r * 16 + (tx - 256) // 8,
                                        64 * u + (tx - 256) % 8 * 8 + v,
                                    ] = KV[
                                        b_i,
                                        indices_local[0],
                                        g_i,
                                        64 * u + (tx - 256) % 8 * 8 + v,
                                    ]
                                    KV_shared_0_r[
                                        r * 16 + (tx - 256) // 8,
                                        64 * u + (tx - 256) % 8 * 8 + v,
                                    ] = KV[
                                        b_i,
                                        indices_local[0],
                                        g_i,
                                        D // 2 + 64 * u + (tx - 256) % 8 * 8 + v,
                                    ]
                        with T.attr("default", "async_scope", 1):  # type: ignore
                            for v in T.vectorized(8):
                                K_tail_shared_0[
                                    r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + v
                                ] = KV[
                                    b_i,
                                    indices_local[0],
                                    g_i,
                                    D + (tx - 256) % 8 * 8 + v,
                                ]

                    T.cp_async_barrier_noinc(bar_k_0_ready[0])

                    # Buffer 1
                    T.barrier_wait(bar_k_1_free[0], ((i_i & 1) ^ 1))
                    T.barrier_arrive(bar_2_128)
                    T.barrier_wait(bar_2_128, 1)

                    for r in T.serial(4):
                        indices_tmp[0] = Indices[
                            b_i, s_i, g_i, (i_i * 2 + 1) * BI + r * 16 + (tx - 256) // 8
                        ]
                        is_kv_valid_1[r * 16 + (tx - 256) // 8] = indices_tmp[0] >= 0
                        if is_kv_valid_1[r * 16 + (tx - 256) // 8]:
                            indices_local[0] = indices_tmp[0]

                        with T.attr("default", "async_scope", 1):  # type: ignore
                            for u in T.serial(4):
                                for v in T.vectorized(8):
                                    KV_shared_1_l[
                                        r * 16 + (tx - 256) // 8,
                                        64 * u + (tx - 256) % 8 * 8 + v,
                                    ] = KV[
                                        b_i,
                                        indices_local[0],
                                        g_i,
                                        64 * u + (tx - 256) % 8 * 8 + v,
                                    ]
                                    KV_shared_1_r[
                                        r * 16 + (tx - 256) // 8,
                                        64 * u + (tx - 256) % 8 * 8 + v,
                                    ] = KV[
                                        b_i,
                                        indices_local[0],
                                        g_i,
                                        D // 2 + 64 * u + (tx - 256) % 8 * 8 + v,
                                    ]
                        with T.attr("default", "async_scope", 1):  # type: ignore
                            for v in T.vectorized(8):
                                K_tail_shared_1[
                                    r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + v
                                ] = KV[
                                    b_i,
                                    indices_local[0],
                                    g_i,
                                    D + (tx - 256) % 8 * 8 + v,
                                ]

                    T.cp_async_barrier_noinc(bar_k_1_ready[0])

    return main
|
||||||
|
|
||||||
|
def tilelang_sparse_fwd(
    q: torch.Tensor,
    kv: torch.Tensor,
    indices: torch.Tensor,
    sm_scale: float,
    d_v: int = 512,
) -> torch.Tensor:
    """Run the tilelang sparse-attention forward kernel on unbatched inputs.

    Args:
        q: [seq_len, num_heads, d_v + tail_dim] query tensor.
        kv: [num_pages, kv_group, d_v + tail_dim] key/value cache.
        indices: [seq_len, kv_group, topk] top-k KV indices (topk must be 2048).
        sm_scale: softmax scale applied inside the kernel.
        d_v: value dimension; the remainder of the head dim is the rope tail.

    Returns:
        [seq_len, num_heads, d_v] attention output.
    """
    assert q.dim() == 3 and kv.dim() == 3 and indices.dim() == 3
    _, heads, head_dim = q.shape
    rope_dim = head_dim - d_v
    selected = indices.shape[-1]
    assert selected == 2048
    # NOTE(dark): v2 offers better performance than v1
    kernel = sparse_attention_fwd_kernel_v2(
        heads, d_v, rope_dim, selected, sm_scale=sm_scale
    )
    # The kernel expects a leading batch dimension; add a singleton batch.
    batched = (q.unsqueeze(0), kv.unsqueeze(0), indices.unsqueeze(0))
    return kernel(*batched)  # type: ignore
|
||||||
|
'''
|
||||||
|
def act_quant(
|
||||||
|
x: torch.Tensor,
|
||||||
|
block_size: int = 128,
|
||||||
|
scale_fmt: Optional[str] = None
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
PyTorch fallback for act_quant
|
||||||
|
Block-wise FP8 E4M3 quantization
|
||||||
|
"""
|
||||||
|
if not x.is_contiguous():
|
||||||
|
x = x.contiguous()
|
||||||
|
|
||||||
|
N = x.size(-1)
|
||||||
|
assert N % block_size == 0, f"Last dim {N} must be divisible by block_size={block_size}"
|
||||||
|
|
||||||
|
# Reshape to blocks
|
||||||
|
x_2d = x.view(-1, N)
|
||||||
|
x_blocks = x_2d.view(-1, block_size)
|
||||||
|
|
||||||
|
# Compute absmax per block
|
||||||
|
amax = x_blocks.abs().amax(dim=1, keepdim=True).clamp(min=1e-4)
|
||||||
|
|
||||||
|
# FP8 E4M3 max value is ~448
|
||||||
|
fp8_max = 448.0
|
||||||
|
scale = amax / fp8_max
|
||||||
|
|
||||||
|
if scale_fmt is not None:
|
||||||
|
# Simulate rounded scale (power-of-2 rounding)
|
||||||
|
scale = torch.round(scale * 256) / 256
|
||||||
|
|
||||||
|
# Quantize and clamp
|
||||||
|
y_blocks = torch.clamp(torch.round(x_blocks / scale), -fp8_max, fp8_max)
|
||||||
|
|
||||||
|
# Convert to FP8
|
||||||
|
q = y_blocks.view_as(x_2d).to(torch.float8_e4m3fn)
|
||||||
|
|
||||||
|
# Reshape scale
|
||||||
|
s = scale.view(x_2d.size(0), N // block_size).to(torch.float32)
|
||||||
|
s = s.view(*x.shape[:-1], N // block_size)
|
||||||
|
|
||||||
|
return q.view_as(x), s
|
||||||
65
python/sglang/srt/layers/attention/nsa/topk.py
Normal file
65
python/sglang/srt/layers/attention/nsa/topk.py
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
from sglang.srt.utils import align
|
||||||
|
|
||||||
|
# NOTE(dark): flashmla P requires `params.topk % (2*B_TOPK) == 0`,
# where `B_TOPK=64`. So we align to 128 by default.
# Default alignment applied to top-k sizes / padded sequence lengths below.
_TOPK_ALIGNMENT = 128
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(dark): maybe this torch_op can support torch.compile
|
||||||
|
def _fast_topk_torch(
    input: torch.Tensor, seq_lens: torch.Tensor, topk: int, alignment: int
) -> torch.Tensor:
    """Torch fallback for per-row top-k index selection.

    Args:
        input: [bs, max_seq_len] scores; entries past `seq_lens[b]` are
            treated as invalid. NOTE: invalid entries of `input` are
            overwritten with -inf IN PLACE on the slow path.
        seq_lens: [bs] valid lengths per row.
        topk: number of indices to select (must be a multiple of `alignment`
            on the slow path).
        alignment: padding granularity for the position grid.

    Returns:
        Integer index tensor; invalid slots hold -1. On the short-sequence
        fast path the result has `align(max_seq_len, alignment)` columns,
        otherwise exactly `topk` columns.
    """
    batch, seqlen = input.shape
    assert len(seq_lens) == batch

    padded_len = align(seqlen, alignment)
    # invalid[b, j] is True for positions past row b's sequence length.
    col_ids = torch.arange(padded_len, device=input.device, dtype=seq_lens.dtype)
    col_ids = col_ids.unsqueeze(0).expand(batch, -1)
    invalid = col_ids >= seq_lens.unsqueeze(1)

    # NOTE(dark): when every valid position fits, skip topk entirely and
    # return all valid indices (invalid slots as -1).
    if padded_len <= topk:
        return col_ids.masked_fill(invalid, -1)

    assert topk % alignment == 0

    # In place: force out-of-bound scores to -inf so they sort last.
    input = input.masked_fill_(invalid[:, :seqlen], float("-inf"))
    top = input.topk(topk, dim=-1, sorted=True)
    # Sorted descending, so the -inf entries occupy exactly the tail slots
    # that `invalid[:, :topk]` marks.
    return top.indices.masked_fill_(invalid[:, :topk], -1)
|
||||||
|
|
||||||
|
|
||||||
|
def fast_topk_impl(
    input: torch.Tensor,
    seq_lens: torch.Tensor,
    topk: int,
    alignment: int = _TOPK_ALIGNMENT,
) -> torch.Tensor:
    """Select top-`topk` score indices per row; invalid slots are -1.

    Thin dispatch point — currently always routed to the torch fallback
    (`_fast_topk_torch`).
    """
    return _fast_topk_torch(input, seq_lens, topk, alignment)
|
||||||
|
|
||||||
|
|
||||||
|
def fast_topk_transform_fused_cuda(
    input: torch.Tensor,
    seq_lens: torch.Tensor,
    topk: int,
    dst_page_table: torch.Tensor,
    src_page_table: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    alignment: int = _TOPK_ALIGNMENT,
) -> torch.Tensor:
    """Fused top-k + page-table transform via the NSA CUDA extension.

    Only supports topk == 2048 (which also satisfies the alignment
    requirement); `seq_lens` bounds the valid scores per row.
    """
    # Imported lazily so the CUDA extension is only required on this path.
    from sglang.srt.layers.attention.nsa.cuda import fast_topk_transform

    assert topk == 2048 and topk % alignment == 0
    kernel_args = dict(
        score=input,
        lengths=seq_lens,
        dst_page_table=dst_page_table,
        src_page_table=src_page_table,
        cu_seqlens=cu_seqlens_q,
    )
    return fast_topk_transform(**kernel_args)
|
||||||
144
python/sglang/srt/layers/attention/nsa/transform_index.py
Normal file
144
python/sglang/srt/layers/attention/nsa/transform_index.py
Normal file
@@ -0,0 +1,144 @@
|
|||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import triton
|
||||||
|
import triton.language as tl
|
||||||
|
|
||||||
|
|
||||||
|
def transform_index_page_table_prefill(**kwargs):
    # Public prefill entry point; currently always routed to the torch
    # reference implementation (the triton fast path exists below but is
    # not selected here).
    return transform_index_page_table_prefill_ref(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def transform_index_page_table_decode(**kwargs):
    # Public decode entry point; currently always routed to the torch
    # reference implementation (the triton fast path exists below but is
    # not selected here).
    return transform_index_page_table_decode_ref(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
def transform_index_page_table_decode_kernel(
    page_table_ptr: torch.Tensor,
    topk_indices_ptr: torch.Tensor,
    result_ptr: torch.Tensor,
    page_size: tl.constexpr,
    max_seqlen_k: tl.constexpr,
):
    # One program per request row: gathers page_table[req, topk_indices[req, :]]
    # into result[req, :], writing -1 wherever the top-k index is negative
    # (padded/invalid slot). `page_size` is unused here; the caller asserts
    # page_size == 1.
    TOPK: tl.constexpr = 2048
    req_id = tl.program_id(0)
    # Advance base pointers to this request's row.
    page_table_ptr = page_table_ptr + req_id * max_seqlen_k
    topk_indices_ptr = topk_indices_ptr + req_id * TOPK
    result_ptr = result_ptr + req_id * TOPK

    offset = tl.arange(0, TOPK)  # topk should be 2048
    loaded_topk_indices = tl.load(topk_indices_ptr + offset)
    mask = loaded_topk_indices >= 0
    # Gather kv page ids only where the index is valid; masked lanes are
    # never dereferenced.
    loaded_kv_indices = tl.load(page_table_ptr + loaded_topk_indices, mask=mask)
    tl.store(result_ptr + offset, loaded_kv_indices, mask=mask)
    # Mark the invalid slots with -1.
    tl.store(result_ptr + offset, -1, mask=~mask)
|
||||||
|
|
||||||
|
|
||||||
|
def transform_index_page_table_decode_fast(
    page_table: torch.Tensor,
    topk_indices: torch.Tensor,
    result: Optional[torch.Tensor] = None,
    page_size: int = 1,
) -> torch.Tensor:
    """
    Transform the page table according to topk indices for sparse topk attention.

    Args:
        page_table: [qo_len, max_seqlen_k], the original page table
        topk_indices: [qo_len, topk], the topk indices for each query position
    Returns:
        transformed_page_table: [qo_len, topk], the transformed page table
        For out-of-bound indices in topk_indices, this should be filled with -1.
    """
    assert page_size == 1
    assert page_table.shape[0] == topk_indices.shape[0]
    assert topk_indices.shape[1] == 2048
    num_queries = topk_indices.shape[0]
    kv_len = page_table.shape[1]
    if result is None:
        result = torch.empty_like(topk_indices, dtype=torch.int32)
    # One triton program per query row.
    transform_index_page_table_decode_kernel[(num_queries,)](
        page_table,
        topk_indices,
        result,
        page_size,
        max_seqlen_k=kv_len,
    )
    return result
|
||||||
|
|
||||||
|
|
||||||
|
def transform_index_page_table_prefill_fast(
    page_table: torch.Tensor,
    topk_indices: torch.Tensor,
    extend_lens_cpu: List[int],
    page_size: int = 1,
) -> torch.Tensor:
    """Prefill variant: apply the decode fast path per request slice.

    `topk_indices` rows are grouped by request; request i owns the next
    `extend_lens_cpu[i]` rows and shares `page_table[i]`.
    """
    # TODO(baizhou): can be implemented with another triton kernel
    assert page_size == 1
    result = torch.empty_like(topk_indices, dtype=torch.int32)
    assert len(extend_lens_cpu) == page_table.shape[0]
    row_start = 0
    for req_idx, extend_len in enumerate(extend_lens_cpu):
        row_end = row_start + extend_len
        transform_index_page_table_decode_fast(
            page_table[req_idx].unsqueeze(0).expand(extend_len, -1),
            topk_indices[row_start:row_end],
            result=result[row_start:row_end],
        )
        row_start = row_end
    assert row_start == topk_indices.shape[0]
    return result
|
||||||
|
|
||||||
|
|
||||||
|
def transform_index_page_table_decode_ref(
    page_table: torch.Tensor,
    topk_indices: torch.Tensor,
    result: Optional[torch.Tensor] = None,
    page_size: int = 1,
) -> torch.Tensor:
    """Torch reference: gather page ids at the top-k positions.

    Negative entries of `topk_indices` yield -1 in the result. Writes into
    `result` when provided, otherwise allocates an int32 tensor.
    """
    assert page_size == 1
    assert page_table.shape[0] == topk_indices.shape[0]
    if result is None:
        result = torch.empty_like(topk_indices, dtype=torch.int32)
    assert result.shape == topk_indices.shape
    # Gather with indices clamped to 0 so torch.gather never sees a negative
    # index; the out-of-range slots are patched to -1 afterwards.
    safe_idx = topk_indices.clamp(min=0).long()
    torch.gather(page_table, dim=1, index=safe_idx, out=result)
    result[topk_indices < 0] = -1
    return result
|
||||||
|
|
||||||
|
|
||||||
|
def transform_index_page_table_prefill_ref(
    page_table: torch.Tensor,
    topk_indices: torch.Tensor,
    extend_lens_cpu: List[int],
    page_size: int = 1,
) -> torch.Tensor:
    """Torch reference prefill transform: decode ref applied per request slice.

    Mirrors `transform_index_page_table_prefill_fast`, but routes each slice
    through the torch reference implementation.
    """
    assert page_size == 1
    result = torch.empty_like(topk_indices, dtype=torch.int32)
    assert len(extend_lens_cpu) == page_table.shape[0]
    start = 0
    for req_idx, extend_len in enumerate(extend_lens_cpu):
        stop = start + extend_len
        transform_index_page_table_decode_ref(
            page_table[req_idx].unsqueeze(0).expand(extend_len, -1),
            topk_indices[start:stop],
            result=result[start:stop],
        )
        start = stop
    assert start == topk_indices.shape[0]
    return result
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Smoke test: the triton fast path must match the torch reference.
    bs, topk, max_seqlen = 10, 2048, 3000
    # Use int32 explicitly: torch.randint/full default to int64, but the ref
    # path gathers with out= an int32 result tensor, and torch.gather requires
    # the out tensor's dtype to match the input's — int64 inputs would raise.
    page_table = torch.randint(
        0, 100, (bs, max_seqlen), device="cuda", dtype=torch.int32
    )
    topk_indices = torch.full((bs, topk), -1, device="cuda", dtype=torch.int32)
    topk_indices[:, :1600] = torch.arange(1600).unsqueeze(0).repeat(bs, 1)
    ref_result = transform_index_page_table_decode_ref(page_table, topk_indices)
    result = transform_index_page_table_decode_fast(page_table, topk_indices)
    assert torch.all(result == ref_result)
    print("Passed")
|
||||||
@@ -0,0 +1,57 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
class DummyModel(nn.Module):
    """Micro-benchmark model comparing two orderings of the same logit gate."""

    def __init__(self, d_in=2048, n_heads=128, softmax_scale=0.5):
        super().__init__()
        self.weights_proj = nn.Linear(d_in, 1024)
        self.n_heads = n_heads
        self.softmax_scale = softmax_scale

    def _get_logits_head_gate_orig(self, x: torch.Tensor, q_scale: torch.Tensor):
        # Baseline: scale the projection first, then broadcast against q_scale.
        proj = self.weights_proj(x)
        proj = proj * self.n_heads**-0.5
        q_scale = q_scale.unsqueeze(1)  # (B,1,1)
        gated = proj.unsqueeze(-1) * q_scale * self.softmax_scale
        return gated

    def _get_logits_head_gate_opt(self, x: torch.Tensor, q_scale: torch.Tensor):
        # Optimized: fold both scalar factors into q_scale once, then multiply.
        proj = self.weights_proj(x)
        q_scale = q_scale.unsqueeze(1)  # (B,1,1)
        combined = self.n_heads**-0.5 * q_scale * self.softmax_scale  # (B,1,1)
        gated = proj.unsqueeze(-1) * combined  # (B,1024,1)
        return gated
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Time both gate computations on a fixed input and check they agree."""
    torch.manual_seed(0)
    model = DummyModel(d_in=2048, n_heads=128, softmax_scale=0.5)
    x = torch.randn(128, 2048)  # batch=128, d_in=2048
    q_scale = torch.randn(128, 1)

    import time

    def _run_timed(fn, label):
        # 1000 iterations; report wall-clock time, return the last output.
        begin = time.time()
        for _ in range(1000):
            out = fn(x, q_scale)
        print(label, time.time() - begin)
        return out

    out_orig = _run_timed(model._get_logits_head_gate_orig, "Original version time:")
    out_opt = _run_timed(model._get_logits_head_gate_opt, "Optimized version time:")

    print("Difference:", (out_orig - out_opt).abs().max().item())
    assert torch.allclose(out_orig, out_opt), "Mismatch between original and optimized"


if __name__ == "__main__":
    main()
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
Original version time: 0.49235057830810547
|
||||||
|
Optimized version time: 0.4087331295013428
|
||||||
|
Difference: 1.4901161193847656e-08
|
||||||
|
"""
|
||||||
32
python/sglang/srt/layers/attention/nsa/utils.py
Normal file
32
python/sglang/srt/layers/attention/nsa/utils.py
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
# temp NSA debugging environ
|
||||||
|
from sglang.srt.utils import get_bool_env_var
|
||||||
|
|
||||||
|
NSA_USE_REAL_INDEXER = get_bool_env_var("SGLANG_NSA_USE_REAL_INDEXER", "true")
|
||||||
|
NSA_DUAL_STREAM = get_bool_env_var("SGLANG_NSA_DUAL_STREAM", "true")
|
||||||
|
NSA_FUSE_TOPK = get_bool_env_var("SGLANG_NSA_FUSE_TOPK", "true")
|
||||||
|
|
||||||
|
NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8 = get_bool_env_var(
|
||||||
|
"SGLANG_NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8", "true"
|
||||||
|
)
|
||||||
|
NSA_KV_CACHE_STORE_FP8 = get_bool_env_var("SGLANG_NSA_KV_CACHE_STORE_FP8", "false")
|
||||||
|
NSA_QUANT_K_CACHE_FAST = get_bool_env_var("SGLANG_NSA_QUANT_K_CACHE_FAST", "false")
|
||||||
|
NSA_DEQUANT_K_CACHE_FAST = get_bool_env_var("SGLANG_NSA_DEQUANT_K_CACHE_FAST", "false")
|
||||||
|
|
||||||
|
|
||||||
|
def _print_bool_env_vars():
|
||||||
|
msg = ""
|
||||||
|
for k, v in globals().items():
|
||||||
|
if k.startswith("NSA_") and isinstance(v, bool):
|
||||||
|
msg += f"{k}={v} "
|
||||||
|
print(msg, flush=True)
|
||||||
|
|
||||||
|
|
||||||
|
_print_bool_env_vars()
|
||||||
|
|
||||||
|
|
||||||
|
if not NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8:
|
||||||
|
assert not NSA_KV_CACHE_STORE_FP8
|
||||||
|
|
||||||
|
|
||||||
|
def compute_nsa_seqlens(original_seq_lens, nsa_index_topk: int):
|
||||||
|
return original_seq_lens.clamp(max=nsa_index_topk)
|
||||||
869
python/sglang/srt/layers/attention/nsa_backend.py
Normal file
869
python/sglang/srt/layers/attention/nsa_backend.py
Normal file
@@ -0,0 +1,869 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Dict,
|
||||||
|
List,
|
||||||
|
Literal,
|
||||||
|
Optional,
|
||||||
|
Tuple,
|
||||||
|
TypeAlias,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
|
||||||
|
|
||||||
|
from sglang.srt.configs.model_config import get_nsa_index_topk, is_deepseek_nsa
|
||||||
|
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
||||||
|
from sglang.srt.layers.attention.nsa.dequant_k_cache import dequantize_k_cache
|
||||||
|
from sglang.srt.layers.attention.nsa.nsa_indexer import BaseIndexerMetadata
|
||||||
|
from sglang.srt.layers.attention.nsa.quant_k_cache import quantize_k_cache
|
||||||
|
from sglang.srt.layers.attention.nsa.topk import (
|
||||||
|
fast_topk_impl,
|
||||||
|
fast_topk_transform_fused_cuda,
|
||||||
|
)
|
||||||
|
from sglang.srt.layers.attention.nsa.transform_index import (
|
||||||
|
transform_index_page_table_decode,
|
||||||
|
transform_index_page_table_prefill,
|
||||||
|
)
|
||||||
|
from sglang.srt.layers.attention.nsa.utils import (
|
||||||
|
NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8,
|
||||||
|
NSA_FUSE_TOPK,
|
||||||
|
NSA_KV_CACHE_STORE_FP8,
|
||||||
|
compute_nsa_seqlens,
|
||||||
|
)
|
||||||
|
from sglang.srt.layers.dp_attention import get_attention_tp_size
|
||||||
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||||
|
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
||||||
|
from sglang.srt.two_batch_overlap import global_server_args_dict
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
|
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class NSAFlashMLAMetadata:
|
||||||
|
"""Metadata only needed by FlashMLA"""
|
||||||
|
|
||||||
|
flashmla_metadata: torch.Tensor
|
||||||
|
num_splits: torch.Tensor
|
||||||
|
|
||||||
|
def slice(self, sli):
|
||||||
|
return NSAFlashMLAMetadata(
|
||||||
|
flashmla_metadata=self.flashmla_metadata,
|
||||||
|
num_splits=self.num_splits[sli],
|
||||||
|
)
|
||||||
|
|
||||||
|
def copy_(self, other: "NSAFlashMLAMetadata"):
|
||||||
|
self.flashmla_metadata.copy_(other.flashmla_metadata)
|
||||||
|
self.num_splits.copy_(other.num_splits)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class NSAMetadata:
|
||||||
|
page_size: int
|
||||||
|
|
||||||
|
# Sequence lengths for the forward batch
|
||||||
|
cache_seqlens_int32: torch.Tensor
|
||||||
|
# Maximum sequence length for query
|
||||||
|
max_seq_len_q: int
|
||||||
|
# Maximum sequence length for key
|
||||||
|
max_seq_len_k: int
|
||||||
|
# Cumulative sequence lengths for query
|
||||||
|
cu_seqlens_q: torch.Tensor
|
||||||
|
# Cumulative sequence lengths for key
|
||||||
|
cu_seqlens_k: torch.Tensor
|
||||||
|
# Page table, the index of KV Cache Tables/Blocks
|
||||||
|
# this table is always with page_size = 1
|
||||||
|
page_table_1: torch.Tensor
|
||||||
|
|
||||||
|
# NOTE(dark): This will property be used in:
|
||||||
|
# 1. dense decode/prefill, we use paged flash attention, need real_page_table
|
||||||
|
# 2. sparse decode/prefill, indexer need real_page_table to compute the score
|
||||||
|
real_page_table: torch.Tensor
|
||||||
|
|
||||||
|
# NSA metadata (nsa prefill are expanded)
|
||||||
|
nsa_cache_seqlens_int32: torch.Tensor # this seqlens is clipped to `topk`
|
||||||
|
nsa_cu_seqlens_q: torch.Tensor # must be arange(0, len(nsa_cu_seqlens_k))
|
||||||
|
nsa_cu_seqlens_k: torch.Tensor # cumsum of `nsa_cache_seqlens_int32`
|
||||||
|
nsa_extend_seq_lens_list: List[int]
|
||||||
|
nsa_seqlens_expanded: torch.Tensor # expanded, unclipped `seqlens`
|
||||||
|
nsa_max_seqlen_q: Literal[1] = 1 # always 1 for decode, variable for extend
|
||||||
|
|
||||||
|
flashmla_metadata: Optional[NSAFlashMLAMetadata] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class NSAIndexerMetadata(BaseIndexerMetadata):
|
||||||
|
attn_metadata: NSAMetadata
|
||||||
|
|
||||||
|
def get_seqlens_int32(self) -> torch.Tensor:
|
||||||
|
return self.attn_metadata.cache_seqlens_int32
|
||||||
|
|
||||||
|
def get_page_table_64(self) -> torch.Tensor:
|
||||||
|
return self.attn_metadata.real_page_table
|
||||||
|
|
||||||
|
def get_seqlens_expanded(self) -> torch.Tensor:
|
||||||
|
return self.attn_metadata.nsa_seqlens_expanded
|
||||||
|
|
||||||
|
def topk_transform(
|
||||||
|
self,
|
||||||
|
logits: torch.Tensor,
|
||||||
|
topk: int,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
if not NSA_FUSE_TOPK:
|
||||||
|
return fast_topk_impl(logits, self.get_seqlens_expanded(), topk)
|
||||||
|
|
||||||
|
# NOTE(dark): if fused, we return a transformed page table directly
|
||||||
|
dst_page_table = torch.empty(
|
||||||
|
(logits.shape[0], topk), dtype=torch.int32, device=logits.device
|
||||||
|
)
|
||||||
|
fast_topk_transform_fused_cuda(
|
||||||
|
input=logits,
|
||||||
|
seq_lens=self.get_seqlens_expanded(),
|
||||||
|
topk=topk,
|
||||||
|
dst_page_table=dst_page_table,
|
||||||
|
src_page_table=self.attn_metadata.page_table_1,
|
||||||
|
cu_seqlens_q=self.attn_metadata.cu_seqlens_q,
|
||||||
|
)
|
||||||
|
return dst_page_table
|
||||||
|
|
||||||
|
|
||||||
|
def compute_cu_seqlens(seqlens: torch.Tensor) -> torch.Tensor:
|
||||||
|
assert seqlens.dtype == torch.int32 and seqlens.is_cuda
|
||||||
|
return torch.nn.functional.pad(
|
||||||
|
torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
_NSA_IMPL_T: TypeAlias = Literal[
|
||||||
|
"flashmla_prefill", "flashmla_decode", "fa3", "tilelang"
|
||||||
|
]
|
||||||
|
|
||||||
|
NSA_PREFILL_IMPL: _NSA_IMPL_T
|
||||||
|
NSA_DECODE_IMPL: _NSA_IMPL_T
|
||||||
|
|
||||||
|
|
||||||
|
class NativeSparseAttnBackend(AttentionBackend):
|
||||||
|
def __init__(self, model_runner: ModelRunner):
|
||||||
|
super().__init__()
|
||||||
|
self.forward_metadata: NSAMetadata
|
||||||
|
self.device = model_runner.device
|
||||||
|
assert isinstance(model_runner.page_size, int)
|
||||||
|
self.real_page_size = model_runner.page_size
|
||||||
|
self.num_splits = (
|
||||||
|
1 if model_runner.server_args.enable_deterministic_inference else 0
|
||||||
|
)
|
||||||
|
self.use_nsa = is_deepseek_nsa(model_runner.model_config.hf_config)
|
||||||
|
assert self.use_nsa, "NSA backend only supports DeepSeek NSA"
|
||||||
|
self.nsa_index_topk = get_nsa_index_topk(model_runner.model_config.hf_config)
|
||||||
|
self.max_context_len = model_runner.model_config.context_len
|
||||||
|
self.num_q_heads = (
|
||||||
|
model_runner.model_config.num_attention_heads // get_attention_tp_size()
|
||||||
|
)
|
||||||
|
self.kv_cache_dim = model_runner.token_to_kv_pool.kv_cache_dim
|
||||||
|
|
||||||
|
assert model_runner.req_to_token_pool is not None
|
||||||
|
self.req_to_token = model_runner.req_to_token_pool.req_to_token
|
||||||
|
|
||||||
|
global NSA_PREFILL_IMPL, NSA_DECODE_IMPL
|
||||||
|
NSA_PREFILL_IMPL = model_runner.server_args.nsa_prefill
|
||||||
|
NSA_DECODE_IMPL = model_runner.server_args.nsa_decode
|
||||||
|
|
||||||
|
self._arange_buf = torch.arange(16384, device=self.device, dtype=torch.int32)
|
||||||
|
|
||||||
|
def get_device_int32_arange(self, l: int) -> torch.Tensor:
|
||||||
|
if l > len(self._arange_buf):
|
||||||
|
next_pow_of_2 = 1 << (l - 1).bit_length()
|
||||||
|
self._arange_buf = torch.arange(
|
||||||
|
next_pow_of_2, device=self.device, dtype=torch.int32
|
||||||
|
)
|
||||||
|
return self._arange_buf[:l]
|
||||||
|
|
||||||
|
def _transform_table_1_to_real(self, page_table: torch.Tensor) -> torch.Tensor:
|
||||||
|
page_size = self.real_page_size
|
||||||
|
if page_size == 1:
|
||||||
|
return page_table
|
||||||
|
max_seqlen_k = page_table.shape[1]
|
||||||
|
strided_indices = torch.arange(
|
||||||
|
0, max_seqlen_k, page_size, device=page_table.device, dtype=torch.int32
|
||||||
|
)
|
||||||
|
return page_table[:, strided_indices] // page_size
|
||||||
|
|
||||||
|
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
||||||
|
"""Init the metadata for a forward pass."""
|
||||||
|
batch_size = forward_batch.batch_size
|
||||||
|
device = forward_batch.seq_lens.device
|
||||||
|
|
||||||
|
assert (
|
||||||
|
forward_batch.spec_info is None
|
||||||
|
), "Spec decoding is not supported for NSA backend now"
|
||||||
|
cache_seqlens_int32 = forward_batch.seq_lens.to(torch.int32)
|
||||||
|
cu_seqlens_k = compute_cu_seqlens(cache_seqlens_int32)
|
||||||
|
assert forward_batch.seq_lens_cpu is not None
|
||||||
|
max_seqlen_k = int(forward_batch.seq_lens_cpu.max().item())
|
||||||
|
page_table = forward_batch.req_to_token_pool.req_to_token[
|
||||||
|
forward_batch.req_pool_indices, :max_seqlen_k
|
||||||
|
]
|
||||||
|
|
||||||
|
if forward_batch.forward_mode.is_decode_or_idle():
|
||||||
|
extend_seq_lens_cpu = [1] * batch_size
|
||||||
|
max_seqlen_q = 1
|
||||||
|
cu_seqlens_q = self.get_device_int32_arange(batch_size + 1)
|
||||||
|
seqlens_expanded = cache_seqlens_int32
|
||||||
|
elif forward_batch.forward_mode.is_extend():
|
||||||
|
assert (
|
||||||
|
forward_batch.extend_seq_lens_cpu is not None
|
||||||
|
and forward_batch.extend_seq_lens is not None
|
||||||
|
and forward_batch.extend_prefix_lens_cpu is not None
|
||||||
|
), "All of them must not be None"
|
||||||
|
extend_seq_lens_cpu = forward_batch.extend_seq_lens_cpu
|
||||||
|
assert forward_batch.extend_seq_lens is not None
|
||||||
|
if any(forward_batch.extend_prefix_lens_cpu):
|
||||||
|
max_seqlen_q = max(extend_seq_lens_cpu)
|
||||||
|
cu_seqlens_q = compute_cu_seqlens(
|
||||||
|
forward_batch.extend_seq_lens.to(torch.int32)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
max_seqlen_q = max_seqlen_k
|
||||||
|
cu_seqlens_q = cu_seqlens_k
|
||||||
|
seqlens_expanded = torch.cat(
|
||||||
|
[
|
||||||
|
torch.arange(
|
||||||
|
kv_len - qo_len + 1,
|
||||||
|
kv_len + 1,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
for qo_len, kv_len in zip(
|
||||||
|
forward_batch.extend_seq_lens_cpu,
|
||||||
|
forward_batch.seq_lens_cpu.tolist(),
|
||||||
|
strict=True,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
assert False, f"Unsupported {forward_batch.forward_mode = }"
|
||||||
|
|
||||||
|
# 1D, expanded seqlens (1D means cheap to compute, so always compute it)
|
||||||
|
nsa_cache_seqlens_int32 = compute_nsa_seqlens(
|
||||||
|
original_seq_lens=seqlens_expanded,
|
||||||
|
nsa_index_topk=self.nsa_index_topk,
|
||||||
|
)
|
||||||
|
nsa_cu_seqlens_k = compute_cu_seqlens(nsa_cache_seqlens_int32)
|
||||||
|
nsa_cu_seqlens_q = self.get_device_int32_arange(len(nsa_cu_seqlens_k))
|
||||||
|
|
||||||
|
metadata = NSAMetadata(
|
||||||
|
page_size=self.real_page_size,
|
||||||
|
cache_seqlens_int32=cache_seqlens_int32,
|
||||||
|
max_seq_len_q=max_seqlen_q,
|
||||||
|
max_seq_len_k=max_seqlen_k,
|
||||||
|
cu_seqlens_q=cu_seqlens_q,
|
||||||
|
cu_seqlens_k=cu_seqlens_k,
|
||||||
|
page_table_1=page_table,
|
||||||
|
flashmla_metadata=(
|
||||||
|
self._compute_flashmla_metadata(
|
||||||
|
cache_seqlens=nsa_cache_seqlens_int32,
|
||||||
|
seq_len_q=1, # TODO handle MTP which is not 1
|
||||||
|
)
|
||||||
|
if NSA_DECODE_IMPL == "flashmla_decode"
|
||||||
|
else None
|
||||||
|
),
|
||||||
|
nsa_cache_seqlens_int32=nsa_cache_seqlens_int32,
|
||||||
|
nsa_cu_seqlens_q=nsa_cu_seqlens_q,
|
||||||
|
nsa_cu_seqlens_k=nsa_cu_seqlens_k,
|
||||||
|
nsa_seqlens_expanded=seqlens_expanded,
|
||||||
|
nsa_extend_seq_lens_list=extend_seq_lens_cpu,
|
||||||
|
real_page_table=self._transform_table_1_to_real(page_table),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.forward_metadata = metadata
|
||||||
|
|
||||||
|
def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
|
||||||
|
"""Initialize CUDA graph state for the attention backend.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
max_bs (int): Maximum batch size to support in CUDA graphs
|
||||||
|
|
||||||
|
This creates fixed-size tensors that will be reused during CUDA graph replay
|
||||||
|
to avoid memory allocations.
|
||||||
|
"""
|
||||||
|
self.decode_cuda_graph_metadata: Dict = {
|
||||||
|
"cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
|
||||||
|
"cu_seqlens_q": torch.arange(
|
||||||
|
0, max_bs + 1, dtype=torch.int32, device=self.device
|
||||||
|
),
|
||||||
|
"cu_seqlens_k": torch.zeros(
|
||||||
|
max_bs + 1, dtype=torch.int32, device=self.device
|
||||||
|
),
|
||||||
|
# fake page_table for sparse_prefill
|
||||||
|
"page_table": torch.zeros(
|
||||||
|
max_bs,
|
||||||
|
self.max_context_len,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=self.device,
|
||||||
|
),
|
||||||
|
"flashmla_metadata": (
|
||||||
|
self._compute_flashmla_metadata(
|
||||||
|
cache_seqlens=torch.ones(
|
||||||
|
max_bs, dtype=torch.int32, device=self.device
|
||||||
|
),
|
||||||
|
seq_len_q=1, # TODO handle MTP which is not 1
|
||||||
|
)
|
||||||
|
if NSA_DECODE_IMPL == "flashmla_decode"
|
||||||
|
else None
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
def init_forward_metadata_capture_cuda_graph(
|
||||||
|
self,
|
||||||
|
bs: int,
|
||||||
|
num_tokens: int,
|
||||||
|
req_pool_indices: torch.Tensor,
|
||||||
|
seq_lens: torch.Tensor,
|
||||||
|
encoder_lens: Optional[torch.Tensor],
|
||||||
|
forward_mode: ForwardMode,
|
||||||
|
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||||
|
):
|
||||||
|
"""Initialize forward metadata for capturing CUDA graph."""
|
||||||
|
assert forward_mode.is_decode_or_idle(), "Only support decode for now"
|
||||||
|
assert (
|
||||||
|
spec_info is None
|
||||||
|
), "Speculative decoding is not supported for NSA backend now"
|
||||||
|
|
||||||
|
# Normal Decode
|
||||||
|
# Get sequence information
|
||||||
|
cache_seqlens_int32 = seq_lens.to(torch.int32)
|
||||||
|
cu_seqlens_k = compute_cu_seqlens(cache_seqlens_int32)
|
||||||
|
|
||||||
|
# Use max context length for seq_len_k
|
||||||
|
page_table_1 = self.decode_cuda_graph_metadata["page_table"][:bs, :]
|
||||||
|
max_seq_len_k = page_table_1.shape[1]
|
||||||
|
|
||||||
|
# Precompute page table
|
||||||
|
# Precompute cumulative sequence lengths
|
||||||
|
|
||||||
|
# NOTE(dark): this is always arange, since we are decoding
|
||||||
|
cu_seqlens_q = self.decode_cuda_graph_metadata["cu_seqlens_q"][: bs + 1]
|
||||||
|
nsa_cache_seqlens_int32 = compute_nsa_seqlens(
|
||||||
|
cache_seqlens_int32, nsa_index_topk=self.nsa_index_topk
|
||||||
|
)
|
||||||
|
nsa_cu_seqlens_k = compute_cu_seqlens(nsa_cache_seqlens_int32)
|
||||||
|
nsa_cu_seqlens_q = self.get_device_int32_arange(len(nsa_cu_seqlens_k))
|
||||||
|
real_page_table = self._transform_table_1_to_real(page_table_1)
|
||||||
|
|
||||||
|
if NSA_DECODE_IMPL == "flashmla_decode":
|
||||||
|
flashmla_metadata = self.decode_cuda_graph_metadata[
|
||||||
|
"flashmla_metadata"
|
||||||
|
].slice(slice(0, bs + 1))
|
||||||
|
flashmla_metadata.copy_(
|
||||||
|
self._compute_flashmla_metadata(
|
||||||
|
cache_seqlens=nsa_cache_seqlens_int32,
|
||||||
|
seq_len_q=1, # TODO handle MTP which is not 1
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
flashmla_metadata = None
|
||||||
|
|
||||||
|
metadata = NSAMetadata(
|
||||||
|
page_size=self.real_page_size,
|
||||||
|
cache_seqlens_int32=cache_seqlens_int32,
|
||||||
|
max_seq_len_q=1,
|
||||||
|
max_seq_len_k=max_seq_len_k,
|
||||||
|
cu_seqlens_q=cu_seqlens_q,
|
||||||
|
cu_seqlens_k=cu_seqlens_k,
|
||||||
|
page_table_1=page_table_1,
|
||||||
|
flashmla_metadata=flashmla_metadata,
|
||||||
|
nsa_cache_seqlens_int32=nsa_cache_seqlens_int32,
|
||||||
|
nsa_cu_seqlens_q=nsa_cu_seqlens_q,
|
||||||
|
nsa_cu_seqlens_k=nsa_cu_seqlens_k,
|
||||||
|
nsa_seqlens_expanded=cache_seqlens_int32,
|
||||||
|
real_page_table=real_page_table,
|
||||||
|
nsa_extend_seq_lens_list=[1] * bs,
|
||||||
|
)
|
||||||
|
self.decode_cuda_graph_metadata[bs] = metadata
|
||||||
|
self.forward_metadata = metadata
|
||||||
|
|
||||||
|
def init_forward_metadata_replay_cuda_graph(
|
||||||
|
self,
|
||||||
|
bs: int,
|
||||||
|
req_pool_indices: torch.Tensor,
|
||||||
|
seq_lens: torch.Tensor,
|
||||||
|
seq_lens_sum: int,
|
||||||
|
encoder_lens: Optional[torch.Tensor],
|
||||||
|
forward_mode: ForwardMode,
|
||||||
|
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||||
|
seq_lens_cpu: Optional[torch.Tensor],
|
||||||
|
out_cache_loc: Optional[torch.Tensor] = None,
|
||||||
|
):
|
||||||
|
"""Initialize forward metadata for replaying CUDA graph."""
|
||||||
|
assert seq_lens_cpu is not None
|
||||||
|
assert forward_mode.is_decode_or_idle(), "Only support decode for now"
|
||||||
|
assert (
|
||||||
|
spec_info is None
|
||||||
|
), "Speculative decoding is not supported for NSA backend now"
|
||||||
|
seq_lens = seq_lens[:bs]
|
||||||
|
seq_lens_cpu = seq_lens_cpu[:bs]
|
||||||
|
req_pool_indices = req_pool_indices[:bs]
|
||||||
|
|
||||||
|
# Normal Decode
|
||||||
|
metadata: NSAMetadata = self.decode_cuda_graph_metadata[bs]
|
||||||
|
max_len = int(seq_lens_cpu.max().item())
|
||||||
|
|
||||||
|
cache_seqlens = seq_lens.to(torch.int32)
|
||||||
|
metadata.cache_seqlens_int32.copy_(cache_seqlens)
|
||||||
|
metadata.cu_seqlens_k[1:].copy_(
|
||||||
|
torch.cumsum(cache_seqlens, dim=0, dtype=torch.int32)
|
||||||
|
)
|
||||||
|
page_indices = self.req_to_token[req_pool_indices, :max_len]
|
||||||
|
metadata.page_table_1[:, :max_len].copy_(page_indices)
|
||||||
|
assert (
|
||||||
|
metadata.nsa_cache_seqlens_int32 is not None
|
||||||
|
and metadata.nsa_cu_seqlens_k is not None
|
||||||
|
and self.nsa_index_topk is not None
|
||||||
|
)
|
||||||
|
nsa_cache_seqlens = compute_nsa_seqlens(cache_seqlens, self.nsa_index_topk)
|
||||||
|
metadata.nsa_cache_seqlens_int32.copy_(nsa_cache_seqlens)
|
||||||
|
metadata.nsa_cu_seqlens_k[1:].copy_(
|
||||||
|
torch.cumsum(nsa_cache_seqlens, dim=0, dtype=torch.int32)
|
||||||
|
)
|
||||||
|
# NOTE(dark): (nsa-) cu_seqlens_q is always arange, no need to copy
|
||||||
|
|
||||||
|
assert self.real_page_size == metadata.page_size
|
||||||
|
if self.real_page_size > 1:
|
||||||
|
real_table = self._transform_table_1_to_real(page_indices)
|
||||||
|
new_len = real_table.shape[1]
|
||||||
|
metadata.real_page_table[:, :new_len].copy_(real_table)
|
||||||
|
else:
|
||||||
|
assert metadata.real_page_table is metadata.page_table_1
|
||||||
|
|
||||||
|
if NSA_DECODE_IMPL == "flashmla_decode":
|
||||||
|
metadata.flashmla_metadata.copy_(
|
||||||
|
self._compute_flashmla_metadata(
|
||||||
|
cache_seqlens=nsa_cache_seqlens,
|
||||||
|
seq_len_q=1, # TODO handle MTP which is not 1
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.forward_metadata = metadata
|
||||||
|
|
||||||
|
def forward_extend(
|
||||||
|
self,
|
||||||
|
q: torch.Tensor,
|
||||||
|
k: torch.Tensor,
|
||||||
|
v: torch.Tensor,
|
||||||
|
layer: RadixAttention,
|
||||||
|
forward_batch: ForwardBatch,
|
||||||
|
save_kv_cache=True,
|
||||||
|
# For multi-head latent attention
|
||||||
|
q_rope: Optional[torch.Tensor] = None,
|
||||||
|
k_rope: Optional[torch.Tensor] = None,
|
||||||
|
topk_indices: Optional[torch.Tensor] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
assert (
|
||||||
|
not forward_batch.forward_mode.is_target_verify()
|
||||||
|
and not forward_batch.forward_mode.is_draft_extend()
|
||||||
|
), "NSA backend doesn't support speculative decoding"
|
||||||
|
if k is not None:
|
||||||
|
assert v is not None
|
||||||
|
if save_kv_cache:
|
||||||
|
cache_loc = (
|
||||||
|
forward_batch.out_cache_loc
|
||||||
|
if not layer.is_cross_attention
|
||||||
|
else forward_batch.encoder_out_cache_loc
|
||||||
|
)
|
||||||
|
forward_batch.token_to_kv_pool.set_mla_kv_buffer( # type: ignore
|
||||||
|
layer,
|
||||||
|
cache_loc,
|
||||||
|
k,
|
||||||
|
k_rope,
|
||||||
|
)
|
||||||
|
|
||||||
|
metadata = self.forward_metadata
|
||||||
|
causal = not layer.is_cross_attention
|
||||||
|
assert causal, "NSA is causal only"
|
||||||
|
|
||||||
|
# For fa3 interface version compatibility, we put new fields into conditional keyword args
|
||||||
|
kwargs = {}
|
||||||
|
|
||||||
|
# Do absorbed multi-latent attention
|
||||||
|
assert q_rope is not None
|
||||||
|
kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
|
||||||
|
|
||||||
|
# when store in fp8 and compute in fp8, no need to convert dtype
|
||||||
|
if not (NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8 and NSA_KV_CACHE_STORE_FP8):
|
||||||
|
kv_cache = kv_cache.to(q.dtype)
|
||||||
|
|
||||||
|
if q_rope is not None:
|
||||||
|
q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
|
||||||
|
q_rope = q_rope.view(
|
||||||
|
-1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
|
||||||
|
q_nope = q_all[:, :, : layer.v_head_dim]
|
||||||
|
q_rope = q_all[:, :, layer.v_head_dim :]
|
||||||
|
|
||||||
|
# NOTE(dark): here, we use page size = 1
|
||||||
|
|
||||||
|
if NSA_FUSE_TOPK:
|
||||||
|
page_table_1 = topk_indices
|
||||||
|
else:
|
||||||
|
assert metadata.nsa_extend_seq_lens_list is not None
|
||||||
|
page_table_1 = transform_index_page_table_prefill(
|
||||||
|
page_table=metadata.page_table_1,
|
||||||
|
topk_indices=topk_indices,
|
||||||
|
extend_lens_cpu=metadata.nsa_extend_seq_lens_list,
|
||||||
|
page_size=1,
|
||||||
|
)
|
||||||
|
# if NSA_PREFILL_IMPL == "tilelang":
|
||||||
|
# from sglang.srt.layers.attention.nsa.tilelang_kernel import (
|
||||||
|
# tilelang_sparse_fwd,
|
||||||
|
# )
|
||||||
|
|
||||||
|
# if q_rope is not None:
|
||||||
|
# q_all = torch.cat([q_nope, q_rope], dim=-1)
|
||||||
|
# return self._forward_tilelang(
|
||||||
|
# q_all=q_all,
|
||||||
|
# kv_cache=kv_cache,
|
||||||
|
# page_table_1=page_table_1,
|
||||||
|
# sm_scale=layer.scaling,
|
||||||
|
# v_head_dim=layer.v_head_dim,
|
||||||
|
# )
|
||||||
|
# elif NSA_PREFILL_IMPL == "flashmla_prefill":
|
||||||
|
|
||||||
|
|
||||||
|
# Skip tilelang dependencies
|
||||||
|
if NSA_PREFILL_IMPL == "tilelang" or NSA_PREFILL_IMPL == "flashmla_prefill":
|
||||||
|
if q_rope is not None:
|
||||||
|
q_all = torch.cat([q_nope, q_rope], dim=-1)
|
||||||
|
return self._forward_flashmla_prefill(
|
||||||
|
q_all=q_all,
|
||||||
|
kv_cache=kv_cache,
|
||||||
|
page_table_1=page_table_1,
|
||||||
|
sm_scale=layer.scaling,
|
||||||
|
v_head_dim=layer.v_head_dim,
|
||||||
|
)
|
||||||
|
elif NSA_PREFILL_IMPL == "flashmla_decode":
|
||||||
|
if q_rope is not None:
|
||||||
|
q_all = torch.cat([q_nope, q_rope], dim=-1)
|
||||||
|
return self._forward_flashmla_decode(
|
||||||
|
q_all=q_all,
|
||||||
|
kv_cache=kv_cache,
|
||||||
|
sm_scale=layer.scaling,
|
||||||
|
v_head_dim=layer.v_head_dim,
|
||||||
|
# TODO optimize args
|
||||||
|
layer=layer,
|
||||||
|
forward_batch=forward_batch,
|
||||||
|
metadata=metadata,
|
||||||
|
topk_indices=topk_indices,
|
||||||
|
block_table=metadata.real_page_table,
|
||||||
|
)
|
||||||
|
elif NSA_PREFILL_IMPL == "fa3":
|
||||||
|
return self._forward_fa3(
|
||||||
|
q_rope=q_rope,
|
||||||
|
kv_cache=kv_cache,
|
||||||
|
v_head_dim=layer.v_head_dim,
|
||||||
|
q_nope=q_nope,
|
||||||
|
page_table=page_table_1,
|
||||||
|
cache_seqlens=metadata.nsa_cache_seqlens_int32,
|
||||||
|
cu_seqlens_q=metadata.nsa_cu_seqlens_q,
|
||||||
|
cu_seqlens_k=metadata.nsa_cu_seqlens_k,
|
||||||
|
max_seqlen_q=metadata.nsa_max_seqlen_q,
|
||||||
|
sm_scale=layer.scaling,
|
||||||
|
logit_cap=layer.logit_cap,
|
||||||
|
page_size=1,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported {NSA_PREFILL_IMPL = }")
|
||||||
|
|
||||||
|
def forward_decode(
|
||||||
|
self,
|
||||||
|
q: torch.Tensor,
|
||||||
|
k: torch.Tensor,
|
||||||
|
v: torch.Tensor,
|
||||||
|
layer: RadixAttention,
|
||||||
|
forward_batch: ForwardBatch,
|
||||||
|
save_kv_cache=True,
|
||||||
|
# For multi-head latent attention
|
||||||
|
q_rope: Optional[torch.Tensor] = None,
|
||||||
|
k_rope: Optional[torch.Tensor] = None,
|
||||||
|
topk_indices: Optional[torch.Tensor] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
if k is not None:
|
||||||
|
assert v is not None
|
||||||
|
if save_kv_cache:
|
||||||
|
cache_loc = (
|
||||||
|
forward_batch.out_cache_loc
|
||||||
|
if not layer.is_cross_attention
|
||||||
|
else forward_batch.encoder_out_cache_loc
|
||||||
|
)
|
||||||
|
forward_batch.token_to_kv_pool.set_mla_kv_buffer( # type: ignore
|
||||||
|
layer,
|
||||||
|
cache_loc,
|
||||||
|
k,
|
||||||
|
k_rope,
|
||||||
|
)
|
||||||
|
|
||||||
|
metadata = self.forward_metadata
|
||||||
|
causal = not layer.is_cross_attention
|
||||||
|
assert causal, "NSA is causal only"
|
||||||
|
|
||||||
|
# Do absorbed multi-latent attention
|
||||||
|
kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
|
||||||
|
if q_rope is not None:
|
||||||
|
q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
|
||||||
|
q_rope = q_rope.view(
|
||||||
|
-1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
|
||||||
|
q_nope = q_all[:, :, : layer.v_head_dim]
|
||||||
|
q_rope = q_all[:, :, layer.v_head_dim :]
|
||||||
|
|
||||||
|
if NSA_FUSE_TOPK:
|
||||||
|
page_table_1 = topk_indices
|
||||||
|
else:
|
||||||
|
page_table_1 = transform_index_page_table_decode(
|
||||||
|
page_table=metadata.page_table_1,
|
||||||
|
topk_indices=topk_indices,
|
||||||
|
page_size=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
if NSA_DECODE_IMPL == "flashmla_prefill":
|
||||||
|
if q_rope is not None:
|
||||||
|
q_all = torch.cat([q_nope, q_rope], dim=-1)
|
||||||
|
return self._forward_flashmla_prefill(
|
||||||
|
q_all=q_all,
|
||||||
|
kv_cache=kv_cache,
|
||||||
|
page_table_1=page_table_1,
|
||||||
|
sm_scale=layer.scaling,
|
||||||
|
v_head_dim=layer.v_head_dim,
|
||||||
|
)
|
||||||
|
elif NSA_DECODE_IMPL == "flashmla_decode":
|
||||||
|
if q_rope is not None:
|
||||||
|
q_all = torch.cat([q_nope, q_rope], dim=-1)
|
||||||
|
return self._forward_flashmla_decode(
|
||||||
|
q_all=q_all,
|
||||||
|
kv_cache=kv_cache,
|
||||||
|
sm_scale=layer.scaling,
|
||||||
|
v_head_dim=layer.v_head_dim,
|
||||||
|
# TODO optimize args
|
||||||
|
layer=layer,
|
||||||
|
forward_batch=forward_batch,
|
||||||
|
metadata=metadata,
|
||||||
|
topk_indices=topk_indices,
|
||||||
|
block_table=metadata.real_page_table,
|
||||||
|
)
|
||||||
|
elif NSA_DECODE_IMPL == "tilelang":
|
||||||
|
if q_rope is not None:
|
||||||
|
q_all = torch.cat([q_nope, q_rope], dim=-1)
|
||||||
|
return self._forward_tilelang(
|
||||||
|
q_all=q_all,
|
||||||
|
kv_cache=kv_cache,
|
||||||
|
page_table_1=page_table_1,
|
||||||
|
sm_scale=layer.scaling,
|
||||||
|
v_head_dim=layer.v_head_dim,
|
||||||
|
)
|
||||||
|
elif NSA_DECODE_IMPL == "fa3":
|
||||||
|
return self._forward_fa3(
|
||||||
|
q_rope=q_rope,
|
||||||
|
kv_cache=kv_cache,
|
||||||
|
v_head_dim=layer.v_head_dim,
|
||||||
|
q_nope=q_nope,
|
||||||
|
page_table=page_table_1,
|
||||||
|
cache_seqlens=metadata.nsa_cache_seqlens_int32,
|
||||||
|
cu_seqlens_q=metadata.nsa_cu_seqlens_q,
|
||||||
|
cu_seqlens_k=metadata.nsa_cu_seqlens_k,
|
||||||
|
max_seqlen_q=metadata.nsa_max_seqlen_q,
|
||||||
|
sm_scale=layer.scaling,
|
||||||
|
logit_cap=layer.logit_cap,
|
||||||
|
page_size=1,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
assert False, f"Unsupported {NSA_DECODE_IMPL = }"
|
||||||
|
|
||||||
|
def _forward_fa3(
|
||||||
|
self,
|
||||||
|
q_rope: torch.Tensor,
|
||||||
|
kv_cache: torch.Tensor,
|
||||||
|
v_head_dim: int,
|
||||||
|
q_nope: torch.Tensor,
|
||||||
|
page_table: torch.Tensor,
|
||||||
|
cache_seqlens: torch.Tensor,
|
||||||
|
cu_seqlens_q: torch.Tensor,
|
||||||
|
cu_seqlens_k: torch.Tensor,
|
||||||
|
max_seqlen_q: int,
|
||||||
|
sm_scale: float,
|
||||||
|
logit_cap: float,
|
||||||
|
page_size: int,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
k_rope_cache = kv_cache[:, :, v_head_dim:]
|
||||||
|
c_kv_cache = kv_cache[:, :, :v_head_dim]
|
||||||
|
qk_rope_dim = k_rope_cache.shape[-1]
|
||||||
|
k_rope_cache = k_rope_cache.view(-1, page_size, 1, qk_rope_dim)
|
||||||
|
c_kv_cache = c_kv_cache.view(-1, page_size, 1, v_head_dim)
|
||||||
|
o = flash_attn_with_kvcache(
|
||||||
|
q=q_rope,
|
||||||
|
k_cache=k_rope_cache,
|
||||||
|
v_cache=c_kv_cache,
|
||||||
|
qv=q_nope,
|
||||||
|
page_table=page_table,
|
||||||
|
cache_seqlens=cache_seqlens,
|
||||||
|
cu_seqlens_q=cu_seqlens_q,
|
||||||
|
cu_seqlens_k_new=cu_seqlens_k,
|
||||||
|
max_seqlen_q=max_seqlen_q,
|
||||||
|
softmax_scale=sm_scale,
|
||||||
|
causal=True,
|
||||||
|
softcap=logit_cap,
|
||||||
|
return_softmax_lse=False,
|
||||||
|
num_splits=self.num_splits,
|
||||||
|
)
|
||||||
|
return o # type: ignore
|
||||||
|
|
||||||
|
def _forward_flashmla_prefill(
|
||||||
|
self,
|
||||||
|
q_all: torch.Tensor,
|
||||||
|
kv_cache: torch.Tensor,
|
||||||
|
v_head_dim: int,
|
||||||
|
page_table_1: torch.Tensor,
|
||||||
|
sm_scale: float,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
#from flash_mla import flash_mla_sparse_fwd
|
||||||
|
from sglang.srt.layers.attention.native_mla import native_mla_sparse_fwd
|
||||||
|
_, _, o = native_mla_sparse_fwd(
|
||||||
|
q=q_all,
|
||||||
|
kv=kv_cache,
|
||||||
|
indices=page_table_1.unsqueeze(1),
|
||||||
|
sm_scale=sm_scale,
|
||||||
|
d_v=v_head_dim,
|
||||||
|
)
|
||||||
|
return o
|
||||||
|
|
||||||
|
def _forward_flashmla_decode(
|
||||||
|
self,
|
||||||
|
q_all: torch.Tensor,
|
||||||
|
kv_cache: torch.Tensor,
|
||||||
|
v_head_dim: int,
|
||||||
|
sm_scale: float,
|
||||||
|
layer,
|
||||||
|
forward_batch: ForwardBatch,
|
||||||
|
metadata: NSAMetadata,
|
||||||
|
topk_indices,
|
||||||
|
block_table,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
#from flash_mla import flash_mla_with_kvcache
|
||||||
|
from sglang.srt.layers.attention.native_mla import native_mla_with_kvcache
|
||||||
|
cache_seqlens = metadata.nsa_cache_seqlens_int32
|
||||||
|
|
||||||
|
# TODO the 2nd dim is seq_len_q, need to be >1 when MTP
|
||||||
|
q_all = q_all.view(-1, 1, layer.tp_q_head_num, layer.head_dim)
|
||||||
|
kv_cache = kv_cache.view(-1, self.real_page_size, 1, self.kv_cache_dim)
|
||||||
|
assert self.real_page_size == 64, "only page size 64 is supported"
|
||||||
|
|
||||||
|
if NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8 and not NSA_KV_CACHE_STORE_FP8:
|
||||||
|
# inefficiently quantize the whole cache
|
||||||
|
kv_cache = quantize_k_cache(kv_cache)
|
||||||
|
|
||||||
|
o, _ = native_mla_with_kvcache(
|
||||||
|
q=q_all,
|
||||||
|
k_cache=kv_cache,
|
||||||
|
cache_seqlens=cache_seqlens,
|
||||||
|
head_dim_v=v_head_dim,
|
||||||
|
tile_scheduler_metadata=metadata.flashmla_metadata.flashmla_metadata,
|
||||||
|
num_splits=metadata.flashmla_metadata.num_splits,
|
||||||
|
softmax_scale=sm_scale,
|
||||||
|
# TODO improve
|
||||||
|
indices=_compute_indices_in_kvcache(
|
||||||
|
block_table=block_table,
|
||||||
|
topk_indices=topk_indices.to(torch.int32),
|
||||||
|
page_size=self.real_page_size,
|
||||||
|
nsa_index_topk=self.nsa_index_topk,
|
||||||
|
),
|
||||||
|
# doc says it is not used, but if pass in None then error
|
||||||
|
block_table=torch.empty(
|
||||||
|
(q_all.shape[0], 0), dtype=torch.int32, device=q_all.device
|
||||||
|
),
|
||||||
|
is_fp8_kvcache=NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8,
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO shape correct?
|
||||||
|
return o
|
||||||
|
|
||||||
|
def _forward_tilelang(
|
||||||
|
self,
|
||||||
|
q_all: torch.Tensor,
|
||||||
|
kv_cache: torch.Tensor,
|
||||||
|
v_head_dim: int,
|
||||||
|
page_table_1: torch.Tensor,
|
||||||
|
sm_scale: float,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
from sglang.srt.layers.attention.nsa.tilelang_kernel import tilelang_sparse_fwd
|
||||||
|
|
||||||
|
return tilelang_sparse_fwd(
|
||||||
|
q=q_all,
|
||||||
|
kv=kv_cache,
|
||||||
|
indices=page_table_1.unsqueeze(1),
|
||||||
|
sm_scale=sm_scale,
|
||||||
|
d_v=v_head_dim,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_cuda_graph_seq_len_fill_value(self):
|
||||||
|
"""Get the fill value for sequence length in CUDA graph."""
|
||||||
|
return 1
|
||||||
|
|
||||||
|
def get_indexer_metadata(
|
||||||
|
self, layer_id: int, forward_batch: ForwardBatch
|
||||||
|
) -> NSAIndexerMetadata:
|
||||||
|
return NSAIndexerMetadata(attn_metadata=self.forward_metadata)
|
||||||
|
|
||||||
|
def _compute_flashmla_metadata(self, cache_seqlens: torch.Tensor, seq_len_q: int):
|
||||||
|
from flash_mla import get_mla_metadata
|
||||||
|
|
||||||
|
flashmla_metadata, num_splits = get_mla_metadata(
|
||||||
|
cache_seqlens=cache_seqlens,
|
||||||
|
# TODO doc says `num_q_tokens_per_q_seq * num_heads_q // num_heads_k`
|
||||||
|
# but the name looks like need seq_len_q?
|
||||||
|
num_q_tokens_per_head_k=seq_len_q * self.num_q_heads // 1,
|
||||||
|
num_heads_k=1,
|
||||||
|
num_heads_q=self.num_q_heads,
|
||||||
|
is_fp8_kvcache=NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8,
|
||||||
|
topk=self.nsa_index_topk,
|
||||||
|
)
|
||||||
|
|
||||||
|
return NSAFlashMLAMetadata(
|
||||||
|
flashmla_metadata=flashmla_metadata,
|
||||||
|
num_splits=num_splits,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# TODO speedup
|
||||||
|
def _compute_indices_in_kvcache(block_table, topk_indices, page_size, nsa_index_topk):
|
||||||
|
topk_indices_safe = topk_indices.masked_fill(topk_indices == -1, 0)
|
||||||
|
|
||||||
|
idx0 = torch.arange(block_table.size(0), device=topk_indices_safe.device).unsqueeze(
|
||||||
|
1
|
||||||
|
)
|
||||||
|
block_idx = block_table[idx0, topk_indices_safe // page_size]
|
||||||
|
offset = topk_indices_safe % page_size
|
||||||
|
indices_in_kvcache = block_idx * page_size + offset
|
||||||
|
|
||||||
|
# the kernel requires invalid entry to be -1
|
||||||
|
assert indices_in_kvcache.shape == topk_indices.shape
|
||||||
|
indices_in_kvcache[topk_indices == -1] = -1
|
||||||
|
|
||||||
|
# return: (batch_size, seqlen_q_ori, topk)
|
||||||
|
indices_in_kvcache = indices_in_kvcache[:, None, :]
|
||||||
|
|
||||||
|
indices_in_kvcache = torch.nn.functional.pad(
|
||||||
|
indices_in_kvcache,
|
||||||
|
(0, nsa_index_topk - indices_in_kvcache.shape[-1]),
|
||||||
|
"constant",
|
||||||
|
-1,
|
||||||
|
)
|
||||||
|
assert indices_in_kvcache.shape[-1] == nsa_index_topk
|
||||||
|
|
||||||
|
return indices_in_kvcache
|
||||||
@@ -127,8 +127,6 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
|||||||
"disable_chunked_prefix_cache"
|
"disable_chunked_prefix_cache"
|
||||||
]
|
]
|
||||||
|
|
||||||
self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
|
|
||||||
|
|
||||||
def _calc_padded_blocks(self, max_seq_len: int) -> int:
|
def _calc_padded_blocks(self, max_seq_len: int) -> int:
|
||||||
"""
|
"""
|
||||||
Calculate padded block count that satisfies both TRT-LLM and Triton constraints.
|
Calculate padded block count that satisfies both TRT-LLM and Triton constraints.
|
||||||
@@ -219,7 +217,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
|||||||
"""Initialize metadata for CUDA graph capture."""
|
"""Initialize metadata for CUDA graph capture."""
|
||||||
|
|
||||||
# Delegate to parent for non-decode modes.
|
# Delegate to parent for non-decode modes.
|
||||||
if not forward_mode.is_decode_or_idle() and not forward_mode.is_target_verify():
|
if not forward_mode.is_decode_or_idle():
|
||||||
return super().init_forward_metadata_capture_cuda_graph(
|
return super().init_forward_metadata_capture_cuda_graph(
|
||||||
bs,
|
bs,
|
||||||
num_tokens,
|
num_tokens,
|
||||||
@@ -230,9 +228,6 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
|||||||
spec_info,
|
spec_info,
|
||||||
)
|
)
|
||||||
|
|
||||||
if forward_mode.is_target_verify():
|
|
||||||
seq_lens = seq_lens + self.num_draft_tokens
|
|
||||||
|
|
||||||
# Custom fast-path for decode/idle.
|
# Custom fast-path for decode/idle.
|
||||||
# Capture with full width so future longer sequences are safe during replay
|
# Capture with full width so future longer sequences are safe during replay
|
||||||
max_blocks_per_seq = self._calc_padded_blocks(self.max_context_len)
|
max_blocks_per_seq = self._calc_padded_blocks(self.max_context_len)
|
||||||
@@ -275,7 +270,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
|||||||
):
|
):
|
||||||
"""Replay CUDA graph with new inputs."""
|
"""Replay CUDA graph with new inputs."""
|
||||||
# Delegate to parent for non-decode modes.
|
# Delegate to parent for non-decode modes.
|
||||||
if not forward_mode.is_decode_or_idle() and not forward_mode.is_target_verify():
|
if not forward_mode.is_decode_or_idle():
|
||||||
return super().init_forward_metadata_replay_cuda_graph(
|
return super().init_forward_metadata_replay_cuda_graph(
|
||||||
bs,
|
bs,
|
||||||
req_pool_indices,
|
req_pool_indices,
|
||||||
@@ -287,10 +282,6 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
|||||||
seq_lens_cpu,
|
seq_lens_cpu,
|
||||||
)
|
)
|
||||||
|
|
||||||
if forward_mode.is_target_verify():
|
|
||||||
seq_lens = seq_lens + self.num_draft_tokens
|
|
||||||
del seq_lens_sum # not handle "num_draft_tokens" but we do not need it
|
|
||||||
|
|
||||||
metadata = self.decode_cuda_graph_metadata[bs]
|
metadata = self.decode_cuda_graph_metadata[bs]
|
||||||
|
|
||||||
# Update block indices for new sequences.
|
# Update block indices for new sequences.
|
||||||
@@ -341,10 +332,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
|||||||
cum_seq_lens_q,
|
cum_seq_lens_q,
|
||||||
seq_lens,
|
seq_lens,
|
||||||
)
|
)
|
||||||
elif (
|
elif forward_batch.forward_mode.is_decode_or_idle():
|
||||||
forward_batch.forward_mode.is_decode_or_idle()
|
|
||||||
or forward_batch.forward_mode.is_target_verify()
|
|
||||||
):
|
|
||||||
bs = forward_batch.batch_size
|
bs = forward_batch.batch_size
|
||||||
|
|
||||||
# Get maximum sequence length.
|
# Get maximum sequence length.
|
||||||
@@ -353,19 +341,13 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
|||||||
else:
|
else:
|
||||||
max_seq = forward_batch.seq_lens.max().item()
|
max_seq = forward_batch.seq_lens.max().item()
|
||||||
|
|
||||||
seq_lens = forward_batch.seq_lens
|
|
||||||
|
|
||||||
if forward_batch.forward_mode.is_target_verify():
|
|
||||||
max_seq = max_seq + self.num_draft_tokens
|
|
||||||
seq_lens = seq_lens + self.num_draft_tokens
|
|
||||||
|
|
||||||
max_seqlen_pad = self._calc_padded_blocks(max_seq)
|
max_seqlen_pad = self._calc_padded_blocks(max_seq)
|
||||||
block_kv_indices = self._create_block_kv_indices(
|
block_kv_indices = self._create_block_kv_indices(
|
||||||
bs,
|
bs,
|
||||||
max_seqlen_pad,
|
max_seqlen_pad,
|
||||||
forward_batch.req_pool_indices,
|
forward_batch.req_pool_indices,
|
||||||
seq_lens,
|
forward_batch.seq_lens,
|
||||||
seq_lens.device,
|
forward_batch.seq_lens.device,
|
||||||
)
|
)
|
||||||
|
|
||||||
max_seq_len_val = int(max_seq)
|
max_seq_len_val = int(max_seq)
|
||||||
@@ -505,7 +487,10 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
|||||||
q_rope_reshaped = q_rope.view(
|
q_rope_reshaped = q_rope.view(
|
||||||
-1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
|
-1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
|
||||||
)
|
)
|
||||||
query = _concat_mla_absorb_q_general(q_nope, q_rope_reshaped)
|
if _is_cuda and q_nope.shape[-1] == 512 and q_rope_reshaped.shape[-1] == 64:
|
||||||
|
query = concat_mla_absorb_q(q_nope, q_rope_reshaped)
|
||||||
|
else:
|
||||||
|
query = torch.cat([q_nope, q_rope_reshaped], dim=-1)
|
||||||
else:
|
else:
|
||||||
# For FP8 path, we already have the query and rope parts merged because of the quantize_and_rope_for_fp8 function
|
# For FP8 path, we already have the query and rope parts merged because of the quantize_and_rope_for_fp8 function
|
||||||
query = q.view(-1, layer.tp_q_head_num, layer.head_dim)
|
query = q.view(-1, layer.tp_q_head_num, layer.head_dim)
|
||||||
@@ -568,134 +553,84 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
|||||||
save_kv_cache: bool = True,
|
save_kv_cache: bool = True,
|
||||||
q_rope: Optional[torch.Tensor] = None,
|
q_rope: Optional[torch.Tensor] = None,
|
||||||
k_rope: Optional[torch.Tensor] = None,
|
k_rope: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
):
|
||||||
if forward_batch.forward_mode.is_draft_extend():
|
if (
|
||||||
|
forward_batch.forward_mode.is_target_verify()
|
||||||
|
or forward_batch.forward_mode.is_draft_extend()
|
||||||
|
):
|
||||||
|
return super().forward_extend(
|
||||||
|
q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
|
||||||
|
)
|
||||||
|
# chunked prefix cache is not enabled, use Flashinfer MLA prefill kernel
|
||||||
|
if forward_batch.attn_attend_prefix_cache is None:
|
||||||
return super().forward_extend(
|
return super().forward_extend(
|
||||||
q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
|
q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
|
||||||
)
|
)
|
||||||
|
|
||||||
# Save KV cache if requested
|
if not forward_batch.attn_attend_prefix_cache:
|
||||||
if save_kv_cache:
|
q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
|
||||||
assert (
|
k = k.view(-1, layer.tp_k_head_num, layer.head_dim)
|
||||||
k is not None and k_rope is not None
|
v = v.view(-1, layer.tp_k_head_num, layer.v_head_dim)
|
||||||
), "For populating trtllm_mla kv cache, both k_nope and k_rope should be not None."
|
output = flashinfer.prefill.trtllm_ragged_attention_deepseek(
|
||||||
forward_batch.token_to_kv_pool.set_mla_kv_buffer(
|
|
||||||
layer, forward_batch.out_cache_loc, k, k_rope
|
|
||||||
)
|
|
||||||
|
|
||||||
if q_rope is not None:
|
|
||||||
q = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
|
|
||||||
q_rope = q_rope.view(
|
|
||||||
-1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
|
|
||||||
)
|
|
||||||
q = _concat_mla_absorb_q_general(q, q_rope)
|
|
||||||
|
|
||||||
q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
|
|
||||||
|
|
||||||
if k_rope is not None:
|
|
||||||
k = torch.cat([k, k_rope], dim=-1)
|
|
||||||
k = k.view(-1, layer.tp_k_head_num, layer.head_dim)
|
|
||||||
|
|
||||||
v = v.view(-1, layer.tp_k_head_num, layer.v_head_dim)
|
|
||||||
|
|
||||||
if forward_batch.forward_mode.is_target_verify():
|
|
||||||
metadata = (
|
|
||||||
getattr(forward_batch, "decode_trtllm_mla_metadata", None)
|
|
||||||
or self.forward_decode_metadata
|
|
||||||
)
|
|
||||||
|
|
||||||
# Ensure query has shape [bs, num_draft_tokens, num_q_heads, head_dim]
|
|
||||||
bs = forward_batch.batch_size
|
|
||||||
q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)
|
|
||||||
|
|
||||||
k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
|
|
||||||
kv_cache = k_cache.view(-1, self.page_size, self.kv_cache_dim).unsqueeze(1)
|
|
||||||
|
|
||||||
q_scale = 1.0
|
|
||||||
k_scale = (
|
|
||||||
layer.k_scale_float
|
|
||||||
if getattr(layer, "k_scale_float", None) is not None
|
|
||||||
else 1.0
|
|
||||||
)
|
|
||||||
|
|
||||||
bmm1_scale = q_scale * k_scale * layer.scaling
|
|
||||||
|
|
||||||
seq_lens = (
|
|
||||||
forward_batch.seq_lens.to(torch.int32)
|
|
||||||
+ forward_batch.spec_info.draft_token_num
|
|
||||||
)
|
|
||||||
max_seq_len = metadata.max_seq_len + forward_batch.spec_info.draft_token_num
|
|
||||||
|
|
||||||
# TODO may use `mla_rope_quantize_fp8` fusion
|
|
||||||
q = q.to(self.data_type)
|
|
||||||
assert kv_cache.dtype == self.data_type
|
|
||||||
|
|
||||||
raw_out = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(
|
|
||||||
query=q,
|
|
||||||
kv_cache=kv_cache,
|
|
||||||
workspace_buffer=self.workspace_buffer,
|
|
||||||
qk_nope_head_dim=self.qk_nope_head_dim,
|
|
||||||
kv_lora_rank=self.kv_lora_rank,
|
|
||||||
qk_rope_head_dim=self.qk_rope_head_dim,
|
|
||||||
block_tables=metadata.block_kv_indices,
|
|
||||||
seq_lens=seq_lens,
|
|
||||||
max_seq_len=max_seq_len,
|
|
||||||
bmm1_scale=bmm1_scale,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Reshape output directly without slicing
|
|
||||||
output = raw_out.view(-1, layer.tp_q_head_num * layer.v_head_dim)
|
|
||||||
return output
|
|
||||||
|
|
||||||
if forward_batch.attn_attend_prefix_cache:
|
|
||||||
# MHA for chunked prefix kv cache when running model with MLA
|
|
||||||
assert forward_batch.prefix_chunk_idx is not None
|
|
||||||
assert forward_batch.prefix_chunk_cu_seq_lens is not None
|
|
||||||
assert q_rope is None
|
|
||||||
assert k_rope is None
|
|
||||||
chunk_idx = forward_batch.prefix_chunk_idx
|
|
||||||
|
|
||||||
output_shape = (q.shape[0], layer.tp_q_head_num, layer.v_head_dim)
|
|
||||||
return flashinfer.prefill.trtllm_ragged_attention_deepseek(
|
|
||||||
query=q,
|
query=q,
|
||||||
key=k,
|
key=k,
|
||||||
value=v,
|
value=v,
|
||||||
workspace_buffer=self.workspace_buffer,
|
workspace_buffer=self.workspace_buffer,
|
||||||
seq_lens=forward_batch.prefix_chunk_seq_lens[chunk_idx],
|
seq_lens=self.forward_prefill_metadata.seq_lens,
|
||||||
max_q_len=self.forward_prefill_metadata.max_seq_len,
|
max_q_len=self.forward_prefill_metadata.max_seq_len,
|
||||||
max_kv_len=forward_batch.prefix_chunk_max_seq_lens[chunk_idx],
|
max_kv_len=self.forward_prefill_metadata.max_seq_len,
|
||||||
bmm1_scale=layer.scaling,
|
bmm1_scale=layer.scaling,
|
||||||
bmm2_scale=1.0,
|
bmm2_scale=1.0,
|
||||||
o_sf_scale=-1.0,
|
o_sf_scale=1.0,
|
||||||
batch_size=forward_batch.batch_size,
|
batch_size=forward_batch.batch_size,
|
||||||
window_left=-1,
|
window_left=-1,
|
||||||
cum_seq_lens_q=self.forward_prefill_metadata.cum_seq_lens,
|
cum_seq_lens_q=self.forward_prefill_metadata.cum_seq_lens,
|
||||||
cum_seq_lens_kv=forward_batch.prefix_chunk_cu_seq_lens[chunk_idx],
|
cum_seq_lens_kv=self.forward_prefill_metadata.cum_seq_lens,
|
||||||
enable_pdl=False,
|
enable_pdl=False,
|
||||||
is_causal=False,
|
is_causal=True,
|
||||||
return_lse=True,
|
return_lse=forward_batch.mha_return_lse,
|
||||||
out=torch.zeros(*output_shape, dtype=q.dtype, device=q.device),
|
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
if not (
|
||||||
|
forward_batch.attn_attend_prefix_cache is not None
|
||||||
|
and forward_batch.mha_return_lse
|
||||||
|
):
|
||||||
|
output = super().forward_extend(
|
||||||
|
q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# MHA for chunked prefix kv cache when running model with MLA
|
||||||
|
assert forward_batch.prefix_chunk_idx is not None
|
||||||
|
assert forward_batch.prefix_chunk_cu_seq_lens is not None
|
||||||
|
assert q_rope is None
|
||||||
|
assert k_rope is None
|
||||||
|
chunk_idx = forward_batch.prefix_chunk_idx
|
||||||
|
|
||||||
return flashinfer.prefill.trtllm_ragged_attention_deepseek(
|
q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
|
||||||
query=q,
|
k = k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype)
|
||||||
key=k,
|
v = v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype)
|
||||||
value=v,
|
output_shape = (q.shape[0], layer.tp_q_head_num, layer.v_head_dim)
|
||||||
workspace_buffer=self.workspace_buffer,
|
output = flashinfer.prefill.trtllm_ragged_attention_deepseek(
|
||||||
seq_lens=self.forward_prefill_metadata.seq_lens,
|
query=q,
|
||||||
max_q_len=self.forward_prefill_metadata.max_seq_len,
|
key=k,
|
||||||
max_kv_len=self.forward_prefill_metadata.max_seq_len,
|
value=v,
|
||||||
bmm1_scale=layer.scaling,
|
workspace_buffer=self.workspace_buffer,
|
||||||
bmm2_scale=1.0,
|
seq_lens=forward_batch.prefix_chunk_seq_lens[chunk_idx],
|
||||||
o_sf_scale=1.0,
|
max_q_len=self.forward_prefill_metadata.max_seq_len,
|
||||||
batch_size=forward_batch.batch_size,
|
max_kv_len=forward_batch.prefix_chunk_max_seq_lens[chunk_idx],
|
||||||
window_left=-1,
|
bmm1_scale=layer.scaling,
|
||||||
cum_seq_lens_q=self.forward_prefill_metadata.cum_seq_lens,
|
bmm2_scale=1.0,
|
||||||
cum_seq_lens_kv=self.forward_prefill_metadata.cum_seq_lens,
|
o_sf_scale=-1.0,
|
||||||
enable_pdl=False,
|
batch_size=forward_batch.batch_size,
|
||||||
is_causal=True,
|
window_left=-1,
|
||||||
return_lse=forward_batch.mha_return_lse,
|
cum_seq_lens_q=self.forward_prefill_metadata.cum_seq_lens,
|
||||||
)
|
cum_seq_lens_kv=forward_batch.prefix_chunk_cu_seq_lens[chunk_idx],
|
||||||
|
enable_pdl=False,
|
||||||
|
is_causal=False,
|
||||||
|
return_lse=True,
|
||||||
|
out=torch.zeros(*output_shape, dtype=q.dtype, device=q.device),
|
||||||
|
)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
class TRTLLMMLAMultiStepDraftBackend(FlashInferMLAMultiStepDraftBackend):
|
class TRTLLMMLAMultiStepDraftBackend(FlashInferMLAMultiStepDraftBackend):
|
||||||
@@ -713,10 +648,3 @@ class TRTLLMMLAMultiStepDraftBackend(FlashInferMLAMultiStepDraftBackend):
|
|||||||
kv_indptr_buf=self.kv_indptr[i],
|
kv_indptr_buf=self.kv_indptr[i],
|
||||||
q_indptr_decode_buf=self.q_indptr_decode,
|
q_indptr_decode_buf=self.q_indptr_decode,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _concat_mla_absorb_q_general(q_nope, q_rope):
|
|
||||||
if _is_cuda and q_nope.shape[-1] == 512 and q_rope.shape[-1] == 64:
|
|
||||||
return concat_mla_absorb_q(q_nope, q_rope)
|
|
||||||
else:
|
|
||||||
return torch.cat([q_nope, q_rope], dim=-1)
|
|
||||||
|
|||||||
@@ -16,19 +16,14 @@ from sglang.srt.utils import (
|
|||||||
get_device_capability,
|
get_device_capability,
|
||||||
is_blackwell,
|
is_blackwell,
|
||||||
is_cuda,
|
is_cuda,
|
||||||
is_npu,
|
|
||||||
print_info_once,
|
print_info_once,
|
||||||
)
|
)
|
||||||
|
|
||||||
_is_cuda = is_cuda()
|
_is_cuda = is_cuda()
|
||||||
_is_npu = is_npu()
|
|
||||||
|
|
||||||
if _is_cuda:
|
if _is_cuda:
|
||||||
from sgl_kernel.flash_attn import flash_attn_varlen_func
|
from sgl_kernel.flash_attn import flash_attn_varlen_func
|
||||||
|
|
||||||
if _is_npu:
|
|
||||||
import torch_npu
|
|
||||||
|
|
||||||
from sglang.srt.distributed import (
|
from sglang.srt.distributed import (
|
||||||
split_tensor_along_last_dim,
|
split_tensor_along_last_dim,
|
||||||
tensor_model_parallel_all_gather,
|
tensor_model_parallel_all_gather,
|
||||||
@@ -336,63 +331,10 @@ class VisionFlash3Attention(nn.Module):
|
|||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
class VisionAscendAttention(nn.Module):
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
if not _is_npu:
|
|
||||||
raise Exception("VisionAscendAttention is only available for ascend npu")
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
q: torch.Tensor,
|
|
||||||
k: torch.Tensor,
|
|
||||||
v: torch.Tensor,
|
|
||||||
cu_seqlens: Optional[Union[SingletonCache, torch.Tensor]],
|
|
||||||
bsz: int,
|
|
||||||
seq_len: int,
|
|
||||||
**kwargs,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
r"""
|
|
||||||
Args:
|
|
||||||
cu_seqlens: [b]
|
|
||||||
Returns:
|
|
||||||
[b * s, h, head_size]
|
|
||||||
"""
|
|
||||||
if cu_seqlens is None:
|
|
||||||
cu_seqlens = _get_cu_seqlens_for_shape(bsz, seq_len, device=q.device)
|
|
||||||
|
|
||||||
seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
|
|
||||||
if seq_lens.is_npu:
|
|
||||||
# cu_seqlens must be on cpu because of operator restriction
|
|
||||||
seq_lens = seq_lens.to("cpu")
|
|
||||||
_, num_heads, head_size = q.shape
|
|
||||||
num_kv_heads = k.shape[1]
|
|
||||||
output = torch.empty_like(q)
|
|
||||||
|
|
||||||
# operator requires pta version >= 2.5.1
|
|
||||||
torch_npu._npu_flash_attention_unpad(
|
|
||||||
query=q,
|
|
||||||
key=k,
|
|
||||||
value=v,
|
|
||||||
seq_len=seq_lens.to(torch.int32),
|
|
||||||
scale_value=head_size**-0.5,
|
|
||||||
num_heads=num_heads,
|
|
||||||
num_kv_heads=num_kv_heads,
|
|
||||||
out=output,
|
|
||||||
)
|
|
||||||
|
|
||||||
return output
|
|
||||||
|
|
||||||
|
|
||||||
QKV_BACKEND_IMPL = {
|
QKV_BACKEND_IMPL = {
|
||||||
"triton_attn": VisionTritonAttention,
|
"triton_attn": VisionTritonAttention,
|
||||||
"sdpa": VisionSdpaAttention,
|
"sdpa": VisionSdpaAttention,
|
||||||
"fa3": VisionFlash3Attention,
|
"fa3": VisionFlash3Attention,
|
||||||
"ascend_attn": VisionAscendAttention,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -50,7 +50,6 @@ from sglang.srt.utils import (
|
|||||||
is_hip,
|
is_hip,
|
||||||
is_sm90_supported,
|
is_sm90_supported,
|
||||||
is_sm100_supported,
|
is_sm100_supported,
|
||||||
prepare_weight_cache,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
_is_flashinfer_available = is_flashinfer_available()
|
_is_flashinfer_available = is_flashinfer_available()
|
||||||
@@ -276,11 +275,7 @@ class LayerCommunicator:
|
|||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
residual: torch.Tensor,
|
residual: torch.Tensor,
|
||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
cache=None,
|
|
||||||
):
|
):
|
||||||
if cache is not None:
|
|
||||||
self._context.cache = cache
|
|
||||||
|
|
||||||
return self._communicate_with_all_reduce_and_layer_norm_fn(
|
return self._communicate_with_all_reduce_and_layer_norm_fn(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
residual=residual,
|
residual=residual,
|
||||||
@@ -354,7 +349,6 @@ class CommunicateContext:
|
|||||||
attn_tp_size: int
|
attn_tp_size: int
|
||||||
attn_dp_size: int
|
attn_dp_size: int
|
||||||
tp_size: int
|
tp_size: int
|
||||||
cache = None
|
|
||||||
|
|
||||||
def is_same_group_size(self, a: ScatterMode, b: ScatterMode):
|
def is_same_group_size(self, a: ScatterMode, b: ScatterMode):
|
||||||
return self.process_group_sizes[a] == self.process_group_sizes[b]
|
return self.process_group_sizes[a] == self.process_group_sizes[b]
|
||||||
@@ -539,8 +533,6 @@ class CommunicateWithAllReduceAndLayerNormFn:
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
hidden_states = tensor_model_parallel_all_reduce(hidden_states)
|
hidden_states = tensor_model_parallel_all_reduce(hidden_states)
|
||||||
if context.cache is not None:
|
|
||||||
_ = prepare_weight_cache(hidden_states, context.cache)
|
|
||||||
hidden_states, residual = layernorm(hidden_states, residual)
|
hidden_states, residual = layernorm(hidden_states, residual)
|
||||||
return hidden_states, residual
|
return hidden_states, residual
|
||||||
|
|
||||||
|
|||||||
@@ -187,9 +187,7 @@ fused_dual_residual_rmsnorm_kernel_autotune = rmsnorm_autotune(
|
|||||||
|
|
||||||
def fused_dual_residual_rmsnorm(x, residual, weight1, weight2, eps, autotune=False):
|
def fused_dual_residual_rmsnorm(x, residual, weight1, weight2, eps, autotune=False):
|
||||||
assert len(x.shape) == 2
|
assert len(x.shape) == 2
|
||||||
assert (
|
assert x.shape == residual.shape and x.dtype == residual.dtype
|
||||||
x.shape == residual.shape and x.dtype == residual.dtype
|
|
||||||
), f"{x.shape=} {residual.shape=} {x.dtype=} {residual.dtype=}"
|
|
||||||
output, mid = torch.empty_like(x), torch.empty_like(x)
|
output, mid = torch.empty_like(x), torch.empty_like(x)
|
||||||
bs, hidden_dim = x.shape
|
bs, hidden_dim = x.shape
|
||||||
if autotune:
|
if autotune:
|
||||||
|
|||||||
@@ -136,21 +136,21 @@ class RMSNorm(CustomOp):
|
|||||||
# NOTE: Remove this if aiter kernel supports discontinuous input
|
# NOTE: Remove this if aiter kernel supports discontinuous input
|
||||||
x = x.contiguous()
|
x = x.contiguous()
|
||||||
if residual is not None:
|
if residual is not None:
|
||||||
if _vllm_version < Version("0.9"):
|
#if _vllm_version < Version("0.9"):
|
||||||
fused_add_rms_norm(x, residual, self.weight.data, self.variance_epsilon)
|
fused_add_rms_norm(x, residual, self.weight.data, self.variance_epsilon)
|
||||||
return x, residual
|
return x, residual
|
||||||
else:
|
# else:
|
||||||
residual_out = torch.empty_like(x)
|
# residual_out = torch.empty_like(x)
|
||||||
output = torch.empty_like(x)
|
# output = torch.empty_like(x)
|
||||||
fused_add_rms_norm(
|
# fused_add_rms_norm(
|
||||||
output,
|
# output,
|
||||||
x,
|
# x,
|
||||||
residual_out,
|
# residual_out,
|
||||||
residual,
|
# residual,
|
||||||
self.weight.data,
|
# self.weight.data,
|
||||||
self.variance_epsilon,
|
# self.variance_epsilon,
|
||||||
)
|
# )
|
||||||
return output, residual_out
|
# return output, residual_out
|
||||||
out = torch.empty_like(x)
|
out = torch.empty_like(x)
|
||||||
rms_norm(out, x, self.weight.data, self.variance_epsilon)
|
rms_norm(out, x, self.weight.data, self.variance_epsilon)
|
||||||
return out
|
return out
|
||||||
|
|||||||
@@ -31,7 +31,6 @@ from sglang.srt.layers.parameter import (
|
|||||||
_ColumnvLLMParameter,
|
_ColumnvLLMParameter,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
|
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
|
||||||
from sglang.srt.layers.utils import pad_or_narrow_weight
|
|
||||||
from sglang.srt.utils import is_cpu, is_npu, set_weight_attrs
|
from sglang.srt.utils import is_cpu, is_npu, set_weight_attrs
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -626,16 +625,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
|||||||
# bitsandbytes loads the weights of the specific portion
|
# bitsandbytes loads the weights of the specific portion
|
||||||
# no need to narrow here
|
# no need to narrow here
|
||||||
if not use_bitsandbytes_4bit and not self.use_presharded_weights:
|
if not use_bitsandbytes_4bit and not self.use_presharded_weights:
|
||||||
# Padding for special case like qwen2_5_VL's mlp which is not 8-aligned
|
loaded_weight = loaded_weight.narrow(
|
||||||
end_idx = start_idx + shard_size
|
output_dim, start_idx, shard_size
|
||||||
if end_idx > loaded_weight.shape[output_dim]:
|
)
|
||||||
loaded_weight = pad_or_narrow_weight(
|
|
||||||
loaded_weight, output_dim, start_idx, shard_size
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
loaded_weight = loaded_weight.narrow(
|
|
||||||
output_dim, start_idx, shard_size
|
|
||||||
)
|
|
||||||
|
|
||||||
# Special case for AQLM codebooks.
|
# Special case for AQLM codebooks.
|
||||||
elif is_metadata:
|
elif is_metadata:
|
||||||
@@ -1310,16 +1302,7 @@ class RowParallelLinear(LinearBase):
|
|||||||
shard_size,
|
shard_size,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Padding for special case like qwen2_5_VL's mlp which is not 8-aligned
|
loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size)
|
||||||
end_idx = start_idx + shard_size
|
|
||||||
if end_idx > loaded_weight.shape[input_dim]:
|
|
||||||
loaded_weight = pad_or_narrow_weight(
|
|
||||||
loaded_weight, input_dim, start_idx, shard_size
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
loaded_weight = loaded_weight.narrow(
|
|
||||||
input_dim, start_idx, shard_size
|
|
||||||
)
|
|
||||||
|
|
||||||
# Special case for loading scales off disk, which often do not
|
# Special case for loading scales off disk, which often do not
|
||||||
# have a shape (such as in the case of AutoFP8).
|
# have a shape (such as in the case of AutoFP8).
|
||||||
|
|||||||
@@ -220,7 +220,6 @@ class LogitsProcessor(nn.Module):
|
|||||||
self.config = config
|
self.config = config
|
||||||
self.logit_scale = logit_scale
|
self.logit_scale = logit_scale
|
||||||
self.use_attn_tp_group = global_server_args_dict["enable_dp_lm_head"]
|
self.use_attn_tp_group = global_server_args_dict["enable_dp_lm_head"]
|
||||||
self.use_fp32_lm_head = global_server_args_dict["enable_fp32_lm_head"]
|
|
||||||
if self.use_attn_tp_group:
|
if self.use_attn_tp_group:
|
||||||
self.attn_tp_size = get_attention_tp_size()
|
self.attn_tp_size = get_attention_tp_size()
|
||||||
self.do_tensor_parallel_all_gather = (
|
self.do_tensor_parallel_all_gather = (
|
||||||
@@ -462,11 +461,7 @@ class LogitsProcessor(nn.Module):
|
|||||||
dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata)
|
dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata)
|
||||||
|
|
||||||
if hasattr(lm_head, "weight"):
|
if hasattr(lm_head, "weight"):
|
||||||
if self.use_fp32_lm_head:
|
if use_intel_amx_backend(lm_head):
|
||||||
logits = torch.matmul(
|
|
||||||
hidden_states.to(torch.float32), lm_head.weight.to(torch.float32).T
|
|
||||||
)
|
|
||||||
elif use_intel_amx_backend(lm_head):
|
|
||||||
logits = torch.ops.sgl_kernel.weight_packed_linear(
|
logits = torch.ops.sgl_kernel.weight_packed_linear(
|
||||||
hidden_states.to(lm_head.weight.dtype),
|
hidden_states.to(lm_head.weight.dtype),
|
||||||
lm_head.weight,
|
lm_head.weight,
|
||||||
@@ -480,15 +475,7 @@ class LogitsProcessor(nn.Module):
|
|||||||
else:
|
else:
|
||||||
# GGUF models
|
# GGUF models
|
||||||
# TODO: use weight_packed_linear for GGUF models
|
# TODO: use weight_packed_linear for GGUF models
|
||||||
if self.use_fp32_lm_head:
|
logits = lm_head.quant_method.apply(lm_head, hidden_states, embedding_bias)
|
||||||
with torch.cuda.amp.autocast(enabled=False):
|
|
||||||
logits = lm_head.quant_method.apply(
|
|
||||||
lm_head, hidden_states.to(torch.float32), embedding_bias
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logits = lm_head.quant_method.apply(
|
|
||||||
lm_head, hidden_states, embedding_bias
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.logit_scale is not None:
|
if self.logit_scale is not None:
|
||||||
logits.mul_(self.logit_scale)
|
logits.mul_(self.logit_scale)
|
||||||
|
|||||||
@@ -789,45 +789,69 @@ class DeepEPMoE(EPMoE):
|
|||||||
if isinstance(hidden_states, tuple):
|
if isinstance(hidden_states, tuple):
|
||||||
per_token_scale = hidden_states[1]
|
per_token_scale = hidden_states[1]
|
||||||
hidden_states = hidden_states[0]
|
hidden_states = hidden_states[0]
|
||||||
else:
|
|
||||||
# dynamic quant
|
|
||||||
hidden_states, per_token_scale = torch_npu.npu_dynamic_quant(
|
|
||||||
hidden_states
|
|
||||||
)
|
|
||||||
|
|
||||||
group_list = torch.tensor(num_recv_tokens_per_expert, dtype=torch.int64).to(
|
group_list = torch.tensor(num_recv_tokens_per_expert, dtype=torch.int64).to(
|
||||||
hidden_states.device
|
hidden_states.device
|
||||||
)
|
)
|
||||||
|
if self.w13_weight.dtype != torch.int8:
|
||||||
|
# gmm1: gate_up_proj
|
||||||
|
hidden_states = torch_npu.npu_grouped_matmul(
|
||||||
|
x=[hidden_states],
|
||||||
|
weight=[self.w13_weight.permute(0, 2, 1)],
|
||||||
|
# per_token_scale=[per_token_scale],
|
||||||
|
split_item=2,
|
||||||
|
group_list_type=group_list_type,
|
||||||
|
group_type=0,
|
||||||
|
group_list=group_list,
|
||||||
|
output_dtype=output_dtype,
|
||||||
|
)[0]
|
||||||
|
hidden_states = torch_npu.npu_swiglu(hidden_states)
|
||||||
|
# gmm2: down_proj
|
||||||
|
hidden_states = torch_npu.npu_grouped_matmul(
|
||||||
|
x=[hidden_states],
|
||||||
|
weight=[self.w2_weight.permute(0, 2, 1)],
|
||||||
|
split_item=2,
|
||||||
|
group_list_type=group_list_type,
|
||||||
|
group_type=0,
|
||||||
|
group_list=group_list,
|
||||||
|
output_dtype=output_dtype,
|
||||||
|
)[0]
|
||||||
|
else:
|
||||||
|
if not get_bool_env_var("DEEP_NORMAL_MODE_USE_INT8_QUANT"):
|
||||||
|
hidden_states, per_token_scale = torch_npu.npu_dynamic_quant(
|
||||||
|
hidden_states
|
||||||
|
)
|
||||||
|
# gmm1: gate_up_proj
|
||||||
|
hidden_states = torch_npu.npu_grouped_matmul(
|
||||||
|
x=[hidden_states],
|
||||||
|
weight=[self.w13_weight],
|
||||||
|
scale=[self.w13_weight_scale.to(output_dtype)],
|
||||||
|
per_token_scale=[per_token_scale],
|
||||||
|
split_item=2,
|
||||||
|
group_list_type=group_list_type,
|
||||||
|
group_type=0,
|
||||||
|
group_list=group_list,
|
||||||
|
output_dtype=output_dtype,
|
||||||
|
)[0]
|
||||||
|
|
||||||
# gmm1: gate_up_proj
|
# act_fn: swiglu
|
||||||
hidden_states = torch_npu.npu_grouped_matmul(
|
hidden_states = torch_npu.npu_swiglu(hidden_states)
|
||||||
x=[hidden_states],
|
hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(
|
||||||
weight=[self.w13_weight],
|
hidden_states
|
||||||
scale=[self.w13_weight_scale.to(output_dtype)],
|
)
|
||||||
per_token_scale=[per_token_scale],
|
|
||||||
split_item=2,
|
|
||||||
group_list_type=group_list_type,
|
|
||||||
group_type=0,
|
|
||||||
group_list=group_list,
|
|
||||||
output_dtype=output_dtype,
|
|
||||||
)[0]
|
|
||||||
|
|
||||||
# act_fn: swiglu
|
# gmm2: down_proj
|
||||||
hidden_states = torch_npu.npu_swiglu(hidden_states)
|
hidden_states = torch_npu.npu_grouped_matmul(
|
||||||
hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(hidden_states)
|
x=[hidden_states],
|
||||||
|
weight=[self.w2_weight],
|
||||||
# gmm2: down_proj
|
scale=[self.w2_weight_scale.to(output_dtype)],
|
||||||
hidden_states = torch_npu.npu_grouped_matmul(
|
per_token_scale=[swiglu_out_scale],
|
||||||
x=[hidden_states],
|
split_item=2,
|
||||||
weight=[self.w2_weight],
|
group_list_type=group_list_type,
|
||||||
scale=[self.w2_weight_scale.to(output_dtype)],
|
group_type=0,
|
||||||
per_token_scale=[swiglu_out_scale],
|
group_list=group_list,
|
||||||
split_item=2,
|
output_dtype=output_dtype,
|
||||||
group_list_type=group_list_type,
|
)[0]
|
||||||
group_type=0,
|
|
||||||
group_list=group_list,
|
|
||||||
output_dtype=output_dtype,
|
|
||||||
)[0]
|
|
||||||
|
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
@@ -836,47 +860,72 @@ class DeepEPMoE(EPMoE):
|
|||||||
assert isinstance(dispatch_output, DeepEPLLOutput)
|
assert isinstance(dispatch_output, DeepEPLLOutput)
|
||||||
hidden_states, topk_idx, topk_weights, group_list, _ = dispatch_output
|
hidden_states, topk_idx, topk_weights, group_list, _ = dispatch_output
|
||||||
|
|
||||||
per_token_scale = hidden_states[1]
|
if isinstance(hidden_states, tuple):
|
||||||
hidden_states = hidden_states[0]
|
per_token_scale = hidden_states[1]
|
||||||
|
hidden_states = hidden_states[0]
|
||||||
|
|
||||||
group_list = group_list.to(torch.int64)
|
group_list = group_list.to(torch.int64)
|
||||||
|
|
||||||
# gmm1: gate_up_proj
|
if self.w13_weight.dtype != torch.int8:
|
||||||
hidden_states = torch_npu.npu_grouped_matmul(
|
# gmm1: gate_up_proj
|
||||||
x=[hidden_states],
|
hidden_states = torch_npu.npu_grouped_matmul(
|
||||||
weight=[self.w13_weight],
|
x=[hidden_states],
|
||||||
split_item=2,
|
weight=[self.w13_weight.permute(0, 2, 1)],
|
||||||
group_list_type=group_list_type,
|
# per_token_scale=[per_token_scale],
|
||||||
group_type=0,
|
split_item=2,
|
||||||
group_list=group_list,
|
group_list_type=group_list_type,
|
||||||
output_dtype=torch.int32,
|
group_type=0,
|
||||||
)[0]
|
group_list=group_list,
|
||||||
|
output_dtype=output_dtype,
|
||||||
|
)[0]
|
||||||
|
hidden_states = torch_npu.npu_swiglu(hidden_states)
|
||||||
|
# gmm2: down_proj
|
||||||
|
hidden_states = torch_npu.npu_grouped_matmul(
|
||||||
|
x=[hidden_states],
|
||||||
|
weight=[self.w2_weight.permute(0, 2, 1)],
|
||||||
|
split_item=2,
|
||||||
|
group_list_type=group_list_type,
|
||||||
|
group_type=0,
|
||||||
|
group_list=group_list,
|
||||||
|
output_dtype=output_dtype,
|
||||||
|
)[0]
|
||||||
|
else:
|
||||||
|
# gmm1: gate_up_proj
|
||||||
|
hidden_states = torch_npu.npu_grouped_matmul(
|
||||||
|
x=[hidden_states],
|
||||||
|
weight=[self.w13_weight],
|
||||||
|
split_item=2,
|
||||||
|
group_list_type=group_list_type,
|
||||||
|
group_type=0,
|
||||||
|
group_list=group_list,
|
||||||
|
output_dtype=torch.int32,
|
||||||
|
)[0]
|
||||||
|
|
||||||
# act_fn: swiglu
|
# act_fn: swiglu
|
||||||
hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
|
hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
|
||||||
x=hidden_states,
|
x=hidden_states,
|
||||||
weight_scale=self.w13_weight_scale.to(torch.float32),
|
weight_scale=self.w13_weight_scale.to(torch.float32),
|
||||||
activation_scale=per_token_scale,
|
activation_scale=per_token_scale,
|
||||||
bias=None,
|
bias=None,
|
||||||
quant_scale=None,
|
quant_scale=None,
|
||||||
quant_offset=None,
|
quant_offset=None,
|
||||||
group_index=group_list,
|
group_index=group_list,
|
||||||
activate_left=True,
|
activate_left=True,
|
||||||
quant_mode=1,
|
quant_mode=1,
|
||||||
)
|
)
|
||||||
|
|
||||||
# gmm2: down_proj
|
# gmm2: down_proj
|
||||||
hidden_states = torch_npu.npu_grouped_matmul(
|
hidden_states = torch_npu.npu_grouped_matmul(
|
||||||
x=[hidden_states],
|
x=[hidden_states],
|
||||||
weight=[self.w2_weight],
|
weight=[self.w2_weight],
|
||||||
scale=[self.w2_weight_scale.to(output_dtype)],
|
scale=[self.w2_weight_scale.to(output_dtype)],
|
||||||
per_token_scale=[swiglu_out_scale],
|
per_token_scale=[swiglu_out_scale],
|
||||||
split_item=2,
|
split_item=2,
|
||||||
group_list_type=group_list_type,
|
group_list_type=group_list_type,
|
||||||
group_type=0,
|
group_type=0,
|
||||||
group_list=group_list,
|
group_list=group_list,
|
||||||
output_dtype=output_dtype,
|
output_dtype=output_dtype,
|
||||||
)[0]
|
)[0]
|
||||||
|
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|||||||
@@ -1,146 +0,0 @@
|
|||||||
{
|
|
||||||
"1": {
|
|
||||||
"BLOCK_SIZE_M": 16,
|
|
||||||
"BLOCK_SIZE_N": 64,
|
|
||||||
"BLOCK_SIZE_K": 128,
|
|
||||||
"GROUP_SIZE_M": 1,
|
|
||||||
"num_warps": 4,
|
|
||||||
"num_stages": 4
|
|
||||||
},
|
|
||||||
"2": {
|
|
||||||
"BLOCK_SIZE_M": 64,
|
|
||||||
"BLOCK_SIZE_N": 128,
|
|
||||||
"BLOCK_SIZE_K": 128,
|
|
||||||
"GROUP_SIZE_M": 16,
|
|
||||||
"num_warps": 4,
|
|
||||||
"num_stages": 3
|
|
||||||
},
|
|
||||||
"4": {
|
|
||||||
"BLOCK_SIZE_M": 64,
|
|
||||||
"BLOCK_SIZE_N": 64,
|
|
||||||
"BLOCK_SIZE_K": 128,
|
|
||||||
"GROUP_SIZE_M": 64,
|
|
||||||
"num_warps": 4,
|
|
||||||
"num_stages": 4
|
|
||||||
},
|
|
||||||
"8": {
|
|
||||||
"BLOCK_SIZE_M": 64,
|
|
||||||
"BLOCK_SIZE_N": 128,
|
|
||||||
"BLOCK_SIZE_K": 128,
|
|
||||||
"GROUP_SIZE_M": 16,
|
|
||||||
"num_warps": 4,
|
|
||||||
"num_stages": 3
|
|
||||||
},
|
|
||||||
"16": {
|
|
||||||
"BLOCK_SIZE_M": 64,
|
|
||||||
"BLOCK_SIZE_N": 128,
|
|
||||||
"BLOCK_SIZE_K": 128,
|
|
||||||
"GROUP_SIZE_M": 16,
|
|
||||||
"num_warps": 4,
|
|
||||||
"num_stages": 3
|
|
||||||
},
|
|
||||||
"24": {
|
|
||||||
"BLOCK_SIZE_M": 64,
|
|
||||||
"BLOCK_SIZE_N": 128,
|
|
||||||
"BLOCK_SIZE_K": 128,
|
|
||||||
"GROUP_SIZE_M": 16,
|
|
||||||
"num_warps": 4,
|
|
||||||
"num_stages": 3
|
|
||||||
},
|
|
||||||
"32": {
|
|
||||||
"BLOCK_SIZE_M": 64,
|
|
||||||
"BLOCK_SIZE_N": 128,
|
|
||||||
"BLOCK_SIZE_K": 128,
|
|
||||||
"GROUP_SIZE_M": 16,
|
|
||||||
"num_warps": 4,
|
|
||||||
"num_stages": 3
|
|
||||||
},
|
|
||||||
"48": {
|
|
||||||
"BLOCK_SIZE_M": 64,
|
|
||||||
"BLOCK_SIZE_N": 128,
|
|
||||||
"BLOCK_SIZE_K": 128,
|
|
||||||
"GROUP_SIZE_M": 1,
|
|
||||||
"num_warps": 4,
|
|
||||||
"num_stages": 3
|
|
||||||
},
|
|
||||||
"64": {
|
|
||||||
"BLOCK_SIZE_M": 64,
|
|
||||||
"BLOCK_SIZE_N": 128,
|
|
||||||
"BLOCK_SIZE_K": 128,
|
|
||||||
"GROUP_SIZE_M": 1,
|
|
||||||
"num_warps": 4,
|
|
||||||
"num_stages": 3
|
|
||||||
},
|
|
||||||
"96": {
|
|
||||||
"BLOCK_SIZE_M": 64,
|
|
||||||
"BLOCK_SIZE_N": 128,
|
|
||||||
"BLOCK_SIZE_K": 128,
|
|
||||||
"GROUP_SIZE_M": 16,
|
|
||||||
"num_warps": 4,
|
|
||||||
"num_stages": 3
|
|
||||||
},
|
|
||||||
"128": {
|
|
||||||
"BLOCK_SIZE_M": 64,
|
|
||||||
"BLOCK_SIZE_N": 128,
|
|
||||||
"BLOCK_SIZE_K": 128,
|
|
||||||
"GROUP_SIZE_M": 16,
|
|
||||||
"num_warps": 4,
|
|
||||||
"num_stages": 3
|
|
||||||
},
|
|
||||||
"256": {
|
|
||||||
"BLOCK_SIZE_M": 64,
|
|
||||||
"BLOCK_SIZE_N": 128,
|
|
||||||
"BLOCK_SIZE_K": 128,
|
|
||||||
"GROUP_SIZE_M": 32,
|
|
||||||
"num_warps": 4,
|
|
||||||
"num_stages": 3
|
|
||||||
},
|
|
||||||
"512": {
|
|
||||||
"BLOCK_SIZE_M": 64,
|
|
||||||
"BLOCK_SIZE_N": 128,
|
|
||||||
"BLOCK_SIZE_K": 128,
|
|
||||||
"GROUP_SIZE_M": 16,
|
|
||||||
"num_warps": 4,
|
|
||||||
"num_stages": 3
|
|
||||||
},
|
|
||||||
"1024": {
|
|
||||||
"BLOCK_SIZE_M": 64,
|
|
||||||
"BLOCK_SIZE_N": 128,
|
|
||||||
"BLOCK_SIZE_K": 128,
|
|
||||||
"GROUP_SIZE_M": 32,
|
|
||||||
"num_warps": 4,
|
|
||||||
"num_stages": 3
|
|
||||||
},
|
|
||||||
"1536": {
|
|
||||||
"BLOCK_SIZE_M": 64,
|
|
||||||
"BLOCK_SIZE_N": 128,
|
|
||||||
"BLOCK_SIZE_K": 128,
|
|
||||||
"GROUP_SIZE_M": 32,
|
|
||||||
"num_warps": 4,
|
|
||||||
"num_stages": 3
|
|
||||||
},
|
|
||||||
"2048": {
|
|
||||||
"BLOCK_SIZE_M": 64,
|
|
||||||
"BLOCK_SIZE_N": 128,
|
|
||||||
"BLOCK_SIZE_K": 128,
|
|
||||||
"GROUP_SIZE_M": 32,
|
|
||||||
"num_warps": 4,
|
|
||||||
"num_stages": 3
|
|
||||||
},
|
|
||||||
"3072": {
|
|
||||||
"BLOCK_SIZE_M": 128,
|
|
||||||
"BLOCK_SIZE_N": 64,
|
|
||||||
"BLOCK_SIZE_K": 128,
|
|
||||||
"GROUP_SIZE_M": 16,
|
|
||||||
"num_warps": 4,
|
|
||||||
"num_stages": 3
|
|
||||||
},
|
|
||||||
"4096": {
|
|
||||||
"BLOCK_SIZE_M": 64,
|
|
||||||
"BLOCK_SIZE_N": 128,
|
|
||||||
"BLOCK_SIZE_K": 128,
|
|
||||||
"GROUP_SIZE_M": 32,
|
|
||||||
"num_warps": 4,
|
|
||||||
"num_stages": 3
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,146 +0,0 @@
|
|||||||
{
|
|
||||||
"1": {
|
|
||||||
"BLOCK_SIZE_M": 16,
|
|
||||||
"BLOCK_SIZE_N": 128,
|
|
||||||
"BLOCK_SIZE_K": 128,
|
|
||||||
"GROUP_SIZE_M": 1,
|
|
||||||
"num_warps": 4,
|
|
||||||
"num_stages": 3
|
|
||||||
},
|
|
||||||
"2": {
|
|
||||||
"BLOCK_SIZE_M": 16,
|
|
||||||
"BLOCK_SIZE_N": 128,
|
|
||||||
"BLOCK_SIZE_K": 128,
|
|
||||||
"GROUP_SIZE_M": 1,
|
|
||||||
"num_warps": 4,
|
|
||||||
"num_stages": 3
|
|
||||||
},
|
|
||||||
"4": {
|
|
||||||
"BLOCK_SIZE_M": 16,
|
|
||||||
"BLOCK_SIZE_N": 128,
|
|
||||||
"BLOCK_SIZE_K": 128,
|
|
||||||
"GROUP_SIZE_M": 1,
|
|
||||||
"num_warps": 4,
|
|
||||||
"num_stages": 4
|
|
||||||
},
|
|
||||||
"8": {
|
|
||||||
"BLOCK_SIZE_M": 16,
|
|
||||||
"BLOCK_SIZE_N": 128,
|
|
||||||
"BLOCK_SIZE_K": 128,
|
|
||||||
"GROUP_SIZE_M": 1,
|
|
||||||
"num_warps": 4,
|
|
||||||
"num_stages": 3
|
|
||||||
},
|
|
||||||
"16": {
|
|
||||||
"BLOCK_SIZE_M": 16,
|
|
||||||
"BLOCK_SIZE_N": 128,
|
|
||||||
"BLOCK_SIZE_K": 128,
|
|
||||||
"GROUP_SIZE_M": 1,
|
|
||||||
"num_warps": 4,
|
|
||||||
"num_stages": 3
|
|
||||||
},
|
|
||||||
"24": {
|
|
||||||
"BLOCK_SIZE_M": 16,
|
|
||||||
"BLOCK_SIZE_N": 128,
|
|
||||||
"BLOCK_SIZE_K": 128,
|
|
||||||
"GROUP_SIZE_M": 1,
|
|
||||||
"num_warps": 4,
|
|
||||||
"num_stages": 3
|
|
||||||
},
|
|
||||||
"32": {
|
|
||||||
"BLOCK_SIZE_M": 16,
|
|
||||||
"BLOCK_SIZE_N": 128,
|
|
||||||
"BLOCK_SIZE_K": 128,
|
|
||||||
"GROUP_SIZE_M": 1,
|
|
||||||
"num_warps": 4,
|
|
||||||
"num_stages": 3
|
|
||||||
},
|
|
||||||
"48": {
|
|
||||||
"BLOCK_SIZE_M": 16,
|
|
||||||
"BLOCK_SIZE_N": 128,
|
|
||||||
"BLOCK_SIZE_K": 128,
|
|
||||||
"GROUP_SIZE_M": 1,
|
|
||||||
"num_warps": 4,
|
|
||||||
"num_stages": 3
|
|
||||||
},
|
|
||||||
"64": {
|
|
||||||
"BLOCK_SIZE_M": 16,
|
|
||||||
"BLOCK_SIZE_N": 128,
|
|
||||||
"BLOCK_SIZE_K": 128,
|
|
||||||
"GROUP_SIZE_M": 1,
|
|
||||||
"num_warps": 4,
|
|
||||||
"num_stages": 3
|
|
||||||
},
|
|
||||||
"96": {
|
|
||||||
"BLOCK_SIZE_M": 16,
|
|
||||||
"BLOCK_SIZE_N": 128,
|
|
||||||
"BLOCK_SIZE_K": 128,
|
|
||||||
"GROUP_SIZE_M": 1,
|
|
||||||
"num_warps": 4,
|
|
||||||
"num_stages": 3
|
|
||||||
},
|
|
||||||
"128": {
|
|
||||||
"BLOCK_SIZE_M": 16,
|
|
||||||
"BLOCK_SIZE_N": 128,
|
|
||||||
"BLOCK_SIZE_K": 128,
|
|
||||||
"GROUP_SIZE_M": 1,
|
|
||||||
"num_warps": 4,
|
|
||||||
"num_stages": 2
|
|
||||||
},
|
|
||||||
"256": {
|
|
||||||
"BLOCK_SIZE_M": 16,
|
|
||||||
"BLOCK_SIZE_N": 128,
|
|
||||||
"BLOCK_SIZE_K": 128,
|
|
||||||
"GROUP_SIZE_M": 1,
|
|
||||||
"num_warps": 4,
|
|
||||||
"num_stages": 2
|
|
||||||
},
|
|
||||||
"512": {
|
|
||||||
"BLOCK_SIZE_M": 16,
|
|
||||||
"BLOCK_SIZE_N": 128,
|
|
||||||
"BLOCK_SIZE_K": 128,
|
|
||||||
"GROUP_SIZE_M": 1,
|
|
||||||
"num_warps": 4,
|
|
||||||
"num_stages": 3
|
|
||||||
},
|
|
||||||
"1024": {
|
|
||||||
"BLOCK_SIZE_M": 32,
|
|
||||||
"BLOCK_SIZE_N": 128,
|
|
||||||
"BLOCK_SIZE_K": 128,
|
|
||||||
"GROUP_SIZE_M": 1,
|
|
||||||
"num_warps": 4,
|
|
||||||
"num_stages": 3
|
|
||||||
},
|
|
||||||
"1536": {
|
|
||||||
"BLOCK_SIZE_M": 32,
|
|
||||||
"BLOCK_SIZE_N": 128,
|
|
||||||
"BLOCK_SIZE_K": 128,
|
|
||||||
"GROUP_SIZE_M": 1,
|
|
||||||
"num_warps": 4,
|
|
||||||
"num_stages": 3
|
|
||||||
},
|
|
||||||
"2048": {
|
|
||||||
"BLOCK_SIZE_M": 64,
|
|
||||||
"BLOCK_SIZE_N": 128,
|
|
||||||
"BLOCK_SIZE_K": 128,
|
|
||||||
"GROUP_SIZE_M": 32,
|
|
||||||
"num_warps": 4,
|
|
||||||
"num_stages": 3
|
|
||||||
},
|
|
||||||
"3072": {
|
|
||||||
"BLOCK_SIZE_M": 64,
|
|
||||||
"BLOCK_SIZE_N": 128,
|
|
||||||
"BLOCK_SIZE_K": 128,
|
|
||||||
"GROUP_SIZE_M": 32,
|
|
||||||
"num_warps": 4,
|
|
||||||
"num_stages": 3
|
|
||||||
},
|
|
||||||
"4096": {
|
|
||||||
"BLOCK_SIZE_M": 64,
|
|
||||||
"BLOCK_SIZE_N": 128,
|
|
||||||
"BLOCK_SIZE_K": 128,
|
|
||||||
"GROUP_SIZE_M": 64,
|
|
||||||
"num_warps": 4,
|
|
||||||
"num_stages": 4
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -61,7 +61,7 @@ def inplace_fused_experts(
|
|||||||
topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
b1: Optional[torch.Tensor] = None,
|
b1: Optional[torch.Tensor] = None,
|
||||||
b2: Optional[torch.Tensor] = None,
|
b2: Optional[torch.Tensor] = None,
|
||||||
activation: str = "silu",
|
activation: int = 0,#0 silu 1 gelu
|
||||||
apply_router_weight_on_input: bool = False,
|
apply_router_weight_on_input: bool = False,
|
||||||
use_fp8_w8a8: bool = False,
|
use_fp8_w8a8: bool = False,
|
||||||
use_int8_w8a8: bool = False,
|
use_int8_w8a8: bool = False,
|
||||||
@@ -79,6 +79,8 @@ def inplace_fused_experts(
|
|||||||
gemm1_alpha: Optional[float] = None,
|
gemm1_alpha: Optional[float] = None,
|
||||||
gemm1_limit: Optional[float] = None,
|
gemm1_limit: Optional[float] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
if isinstance(activation, int):
|
||||||
|
activation = "silu" if activation == 0 else "gelu"
|
||||||
fused_experts_impl(
|
fused_experts_impl(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
w1,
|
w1,
|
||||||
@@ -117,7 +119,7 @@ def inplace_fused_experts_fake(
|
|||||||
topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
b1: Optional[torch.Tensor] = None,
|
b1: Optional[torch.Tensor] = None,
|
||||||
b2: Optional[torch.Tensor] = None,
|
b2: Optional[torch.Tensor] = None,
|
||||||
activation: str = "silu",
|
activation: int = 0,#0 silu 1 gelu
|
||||||
apply_router_weight_on_input: bool = False,
|
apply_router_weight_on_input: bool = False,
|
||||||
use_fp8_w8a8: bool = False,
|
use_fp8_w8a8: bool = False,
|
||||||
use_int8_w8a8: bool = False,
|
use_int8_w8a8: bool = False,
|
||||||
@@ -154,7 +156,7 @@ def outplace_fused_experts(
|
|||||||
topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
b1: Optional[torch.Tensor] = None,
|
b1: Optional[torch.Tensor] = None,
|
||||||
b2: Optional[torch.Tensor] = None,
|
b2: Optional[torch.Tensor] = None,
|
||||||
activation: str = "silu",
|
activation: int = 0,#0 silu 1 gelu
|
||||||
apply_router_weight_on_input: bool = False,
|
apply_router_weight_on_input: bool = False,
|
||||||
use_fp8_w8a8: bool = False,
|
use_fp8_w8a8: bool = False,
|
||||||
use_int8_w8a8: bool = False,
|
use_int8_w8a8: bool = False,
|
||||||
@@ -173,6 +175,8 @@ def outplace_fused_experts(
|
|||||||
gemm1_alpha: Optional[float] = None,
|
gemm1_alpha: Optional[float] = None,
|
||||||
gemm1_limit: Optional[float] = None,
|
gemm1_limit: Optional[float] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
if isinstance(activation, int):
|
||||||
|
activation = "silu" if activation == 0 else "gelu"
|
||||||
return fused_experts_impl(
|
return fused_experts_impl(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
w1,
|
w1,
|
||||||
@@ -211,7 +215,7 @@ def outplace_fused_experts_fake(
|
|||||||
topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
b1: Optional[torch.Tensor] = None,
|
b1: Optional[torch.Tensor] = None,
|
||||||
b2: Optional[torch.Tensor] = None,
|
b2: Optional[torch.Tensor] = None,
|
||||||
activation: str = "silu",
|
activation: int = 0,#0 silu 1 gelu
|
||||||
apply_router_weight_on_input: bool = False,
|
apply_router_weight_on_input: bool = False,
|
||||||
use_fp8_w8a8: bool = False,
|
use_fp8_w8a8: bool = False,
|
||||||
use_int8_w8a8: bool = False,
|
use_int8_w8a8: bool = False,
|
||||||
@@ -263,6 +267,13 @@ def fused_experts(
|
|||||||
block_shape: Optional[List[int]] = None,
|
block_shape: Optional[List[int]] = None,
|
||||||
):
|
):
|
||||||
topk_weights, topk_ids, _ = topk_output
|
topk_weights, topk_ids, _ = topk_output
|
||||||
|
act_id = (
|
||||||
|
0 if (
|
||||||
|
moe_runner_config.activation == 0
|
||||||
|
or (isinstance(moe_runner_config.activation, str)
|
||||||
|
and moe_runner_config.activation.lower() == "silu")
|
||||||
|
) else 1
|
||||||
|
)
|
||||||
if moe_runner_config.inplace:
|
if moe_runner_config.inplace:
|
||||||
assert not moe_runner_config.no_combine, "no combine + inplace makes no sense"
|
assert not moe_runner_config.no_combine, "no combine + inplace makes no sense"
|
||||||
torch.ops.sglang.inplace_fused_experts(
|
torch.ops.sglang.inplace_fused_experts(
|
||||||
@@ -273,7 +284,7 @@ def fused_experts(
|
|||||||
topk_ids,
|
topk_ids,
|
||||||
b1,
|
b1,
|
||||||
b2,
|
b2,
|
||||||
moe_runner_config.activation,
|
act_id,
|
||||||
moe_runner_config.apply_router_weight_on_input,
|
moe_runner_config.apply_router_weight_on_input,
|
||||||
use_fp8_w8a8,
|
use_fp8_w8a8,
|
||||||
use_int8_w8a8,
|
use_int8_w8a8,
|
||||||
@@ -301,7 +312,7 @@ def fused_experts(
|
|||||||
topk_ids,
|
topk_ids,
|
||||||
b1,
|
b1,
|
||||||
b2,
|
b2,
|
||||||
moe_runner_config.activation,
|
act_id,
|
||||||
moe_runner_config.apply_router_weight_on_input,
|
moe_runner_config.apply_router_weight_on_input,
|
||||||
use_fp8_w8a8,
|
use_fp8_w8a8,
|
||||||
use_int8_w8a8,
|
use_int8_w8a8,
|
||||||
@@ -345,7 +356,7 @@ def fused_experts_impl(
|
|||||||
b1: Optional[torch.Tensor] = None,
|
b1: Optional[torch.Tensor] = None,
|
||||||
b2: Optional[torch.Tensor] = None,
|
b2: Optional[torch.Tensor] = None,
|
||||||
inplace: bool = False,
|
inplace: bool = False,
|
||||||
activation: str = "silu",
|
activation: int = 0,#0 silu 1 gelu
|
||||||
apply_router_weight_on_input: bool = False,
|
apply_router_weight_on_input: bool = False,
|
||||||
use_fp8_w8a8: bool = False,
|
use_fp8_w8a8: bool = False,
|
||||||
use_int8_w8a8: bool = False,
|
use_int8_w8a8: bool = False,
|
||||||
@@ -364,6 +375,9 @@ def fused_experts_impl(
|
|||||||
gemm1_alpha: Optional[float] = None,
|
gemm1_alpha: Optional[float] = None,
|
||||||
gemm1_limit: Optional[float] = None,
|
gemm1_limit: Optional[float] = None,
|
||||||
):
|
):
|
||||||
|
if isinstance(activation, int):
|
||||||
|
activation = "silu" if activation == 0 else "gelu"
|
||||||
|
|
||||||
padded_size = padding_size
|
padded_size = padding_size
|
||||||
if not (use_fp8_w8a8 or use_int8_w8a8) or block_shape is not None or _use_aiter:
|
if not (use_fp8_w8a8 or use_int8_w8a8) or block_shape is not None or _use_aiter:
|
||||||
padded_size = 0
|
padded_size = 0
|
||||||
|
|||||||
@@ -51,14 +51,10 @@ def get_moe_configs(
|
|||||||
|
|
||||||
# We found that using the fused_moe_kernel config from Triton 3.1.0 with Triton 3.2.0 results in negative performance gains,
|
# We found that using the fused_moe_kernel config from Triton 3.1.0 with Triton 3.2.0 results in negative performance gains,
|
||||||
# so we also include the Triton version as a key for finding the fused_moe_kernel config to achieve the best performance.
|
# so we also include the Triton version as a key for finding the fused_moe_kernel config to achieve the best performance.
|
||||||
config_dir = os.environ.get(
|
|
||||||
"SGLANG_MOE_CONFIG_DIR", os.path.dirname(os.path.realpath(__file__))
|
|
||||||
)
|
|
||||||
|
|
||||||
triton_version = triton.__version__
|
triton_version = triton.__version__
|
||||||
version_dir = f"triton_{triton_version.replace('.', '_')}"
|
version_dir = f"triton_{triton_version.replace('.', '_')}"
|
||||||
config_file_path = os.path.join(
|
config_file_path = os.path.join(
|
||||||
config_dir,
|
os.path.dirname(os.path.realpath(__file__)),
|
||||||
"configs",
|
"configs",
|
||||||
version_dir,
|
version_dir,
|
||||||
json_file_name,
|
json_file_name,
|
||||||
@@ -79,7 +75,7 @@ def get_moe_configs(
|
|||||||
if try_triton_version == triton_version:
|
if try_triton_version == triton_version:
|
||||||
continue
|
continue
|
||||||
try_config_file_path = os.path.join(
|
try_config_file_path = os.path.join(
|
||||||
config_dir,
|
os.path.dirname(os.path.realpath(__file__)),
|
||||||
"configs",
|
"configs",
|
||||||
f"triton_{try_triton_version.replace('.', '_')}",
|
f"triton_{try_triton_version.replace('.', '_')}",
|
||||||
json_file_name,
|
json_file_name,
|
||||||
|
|||||||
@@ -575,10 +575,7 @@ class FusedMoE(torch.nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Flashinfer assumes w31 format for w13_weight. Same for the scales.
|
# Flashinfer assumes w31 format for w13_weight. Same for the scales.
|
||||||
if (
|
if should_use_flashinfer_trtllm_moe():
|
||||||
should_use_flashinfer_trtllm_moe()
|
|
||||||
and self.quant_method.__class__.__name__ == "ModelOptNvFp4FusedMoEMethod"
|
|
||||||
):
|
|
||||||
shard_id = {"w1": "w3", "w3": "w1", "w2": "w2"}[shard_id]
|
shard_id = {"w1": "w3", "w3": "w1", "w2": "w2"}[shard_id]
|
||||||
|
|
||||||
WEIGHT_SCALE_SUPPORTED = [e.value for e in FusedMoeWeightScaleSupported]
|
WEIGHT_SCALE_SUPPORTED = [e.value for e in FusedMoeWeightScaleSupported]
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ from typing import Callable, Optional, Union
|
|||||||
import torch
|
import torch
|
||||||
from torch.nn import Parameter
|
from torch.nn import Parameter
|
||||||
|
|
||||||
from sglang.srt.layers.utils import pad_or_narrow_weight
|
|
||||||
from sglang.srt.utils import is_cpu
|
from sglang.srt.utils import is_cpu
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@@ -157,17 +156,9 @@ class _ColumnvLLMParameter(BasevLLMParameter):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if not use_presharded_weights:
|
if not use_presharded_weights:
|
||||||
# Padding for special case like qwen2_5_VL's mlp which is not 8-aligned
|
loaded_weight = loaded_weight.narrow(
|
||||||
start_idx = tp_rank * shard_size
|
self.output_dim, tp_rank * shard_size, shard_size
|
||||||
end_idx = start_idx + shard_size
|
)
|
||||||
if end_idx > loaded_weight.shape[self.output_dim]:
|
|
||||||
loaded_weight = pad_or_narrow_weight(
|
|
||||||
loaded_weight, self.output_dim, start_idx, shard_size
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
loaded_weight = loaded_weight.narrow(
|
|
||||||
self.output_dim, start_idx, shard_size
|
|
||||||
)
|
|
||||||
|
|
||||||
assert param_data.shape == loaded_weight.shape
|
assert param_data.shape == loaded_weight.shape
|
||||||
param_data.copy_(loaded_weight)
|
param_data.copy_(loaded_weight)
|
||||||
@@ -267,17 +258,9 @@ class RowvLLMParameter(BasevLLMParameter):
|
|||||||
|
|
||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
# Padding for special case like qwen2_5_VL's mlp which is not 8-aligned
|
loaded_weight = loaded_weight.narrow(
|
||||||
start_idx = tp_rank * shard_size
|
self.input_dim, tp_rank * shard_size, shard_size
|
||||||
end_idx = start_idx + shard_size
|
)
|
||||||
if end_idx > loaded_weight.shape[self.input_dim]:
|
|
||||||
loaded_weight = pad_or_narrow_weight(
|
|
||||||
loaded_weight, self.input_dim, start_idx, shard_size
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
loaded_weight = loaded_weight.narrow(
|
|
||||||
self.input_dim, start_idx, shard_size
|
|
||||||
)
|
|
||||||
|
|
||||||
if len(loaded_weight.shape) == 0:
|
if len(loaded_weight.shape) == 0:
|
||||||
loaded_weight = loaded_weight.reshape(1)
|
loaded_weight = loaded_weight.reshape(1)
|
||||||
|
|||||||
@@ -30,7 +30,6 @@ from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors_moe im
|
|||||||
from sglang.srt.layers.quantization.compressed_tensors.schemes import (
|
from sglang.srt.layers.quantization.compressed_tensors.schemes import (
|
||||||
CompressedTensorsScheme,
|
CompressedTensorsScheme,
|
||||||
CompressedTensorsW8A8Fp8,
|
CompressedTensorsW8A8Fp8,
|
||||||
CompressedTensorsW8A8Int8,
|
|
||||||
CompressedTensorsW8A16Fp8,
|
CompressedTensorsW8A16Fp8,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.quantization.compressed_tensors.utils import (
|
from sglang.srt.layers.quantization.compressed_tensors.utils import (
|
||||||
|
|||||||
@@ -2,12 +2,10 @@
|
|||||||
|
|
||||||
from .compressed_tensors_scheme import CompressedTensorsScheme
|
from .compressed_tensors_scheme import CompressedTensorsScheme
|
||||||
from .compressed_tensors_w8a8_fp8 import CompressedTensorsW8A8Fp8
|
from .compressed_tensors_w8a8_fp8 import CompressedTensorsW8A8Fp8
|
||||||
from .compressed_tensors_w8a8_int8 import CompressedTensorsW8A8Int8
|
|
||||||
from .compressed_tensors_w8a16_fp8 import CompressedTensorsW8A16Fp8
|
from .compressed_tensors_w8a16_fp8 import CompressedTensorsW8A16Fp8
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"CompressedTensorsScheme",
|
"CompressedTensorsScheme",
|
||||||
"CompressedTensorsW8A8Fp8",
|
"CompressedTensorsW8A8Fp8",
|
||||||
"CompressedTensorsW8A16Fp8",
|
"CompressedTensorsW8A16Fp8",
|
||||||
"CompressedTensorsW8A8Int8",
|
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -1,173 +0,0 @@
|
|||||||
# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors
|
|
||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
|
|
||||||
from typing import Callable, Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from compressed_tensors.quantization import QuantizationStrategy
|
|
||||||
from torch.nn import Parameter
|
|
||||||
|
|
||||||
from sglang.srt.layers.parameter import (
|
|
||||||
ChannelQuantScaleParameter,
|
|
||||||
ModelWeightParameter,
|
|
||||||
PerTensorScaleParameter,
|
|
||||||
)
|
|
||||||
from sglang.srt.layers.quantization.compressed_tensors.schemes import (
|
|
||||||
CompressedTensorsScheme,
|
|
||||||
)
|
|
||||||
from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
|
|
||||||
from sglang.srt.layers.quantization.utils import requantize_with_max_scale
|
|
||||||
from sglang.srt.utils import is_cuda
|
|
||||||
|
|
||||||
_is_cuda = is_cuda()
|
|
||||||
if _is_cuda:
|
|
||||||
from sgl_kernel import int8_scaled_mm
|
|
||||||
|
|
||||||
|
|
||||||
class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self, strategy: str, is_static_input_scheme: bool, input_symmetric: bool
|
|
||||||
):
|
|
||||||
self.strategy = strategy
|
|
||||||
self.is_static_input_scheme = is_static_input_scheme
|
|
||||||
self.input_symmetric = input_symmetric
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_min_capability(cls) -> int:
|
|
||||||
# lovelace and up
|
|
||||||
return 89
|
|
||||||
|
|
||||||
def process_weights_after_loading(self, layer) -> None:
|
|
||||||
# If per tensor, when we have a fused module (e.g. QKV) with per
|
|
||||||
# tensor scales (thus N scales being passed to the kernel),
|
|
||||||
# requantize so we can always run per channel
|
|
||||||
if self.strategy == QuantizationStrategy.TENSOR:
|
|
||||||
max_w_scale, weight = requantize_with_max_scale(
|
|
||||||
weight=layer.weight,
|
|
||||||
weight_scale=layer.weight_scale,
|
|
||||||
logical_widths=layer.logical_widths,
|
|
||||||
)
|
|
||||||
|
|
||||||
layer.weight = Parameter(weight.t(), requires_grad=False)
|
|
||||||
layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
|
|
||||||
|
|
||||||
# If channelwise, scales are already lined up, so just transpose.
|
|
||||||
elif self.strategy == QuantizationStrategy.CHANNEL:
|
|
||||||
weight = layer.weight
|
|
||||||
weight_scale = layer.weight_scale.data
|
|
||||||
|
|
||||||
layer.weight = Parameter(weight.t(), requires_grad=False)
|
|
||||||
# required by torch.compile to be torch.nn.Parameter
|
|
||||||
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
|
|
||||||
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unknown quantization strategy {self.strategy}")
|
|
||||||
|
|
||||||
# INPUT SCALE
|
|
||||||
if self.is_static_input_scheme and hasattr(layer, "input_scale"):
|
|
||||||
if self.input_symmetric:
|
|
||||||
layer.input_scale = Parameter(
|
|
||||||
layer.input_scale.max(), requires_grad=False
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
input_scale = layer.input_scale
|
|
||||||
input_zero_point = layer.input_zero_point
|
|
||||||
|
|
||||||
# reconstruct the ranges
|
|
||||||
int8_traits = torch.iinfo(torch.int8)
|
|
||||||
azps = input_zero_point.to(dtype=torch.int32)
|
|
||||||
range_max = (input_scale * (int8_traits.max - azps)).max()
|
|
||||||
range_min = (input_scale * (int8_traits.min - azps)).min()
|
|
||||||
|
|
||||||
scale = (range_max - range_min) / (int8_traits.max - int8_traits.min)
|
|
||||||
|
|
||||||
# AZP loaded as int8 but used as int32
|
|
||||||
azp = (int8_traits.min - range_min / scale).to(dtype=torch.int32)
|
|
||||||
|
|
||||||
layer.input_scale = Parameter(scale, requires_grad=False)
|
|
||||||
layer.input_zero_point = Parameter(azp, requires_grad=False)
|
|
||||||
else:
|
|
||||||
layer.input_scale = None
|
|
||||||
layer.input_zero_point = None
|
|
||||||
|
|
||||||
# azp_adj is the AZP adjustment term, used to account for weights.
|
|
||||||
# It does not depend on scales or azp, so it is the same for
|
|
||||||
# static and dynamic quantization.
|
|
||||||
# For more details, see csrc/quantization/cutlass_w8a8/Epilogues.md
|
|
||||||
# https://github.com/vllm-project/vllm/blob/8d59dbb00044a588cab96bcdc028006ed922eb06/csrc/quantization/cutlass_w8a8/Epilogues.md
|
|
||||||
if not self.input_symmetric:
|
|
||||||
weight = layer.weight
|
|
||||||
azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.int32)
|
|
||||||
if self.is_static_input_scheme:
|
|
||||||
# cutlass_w8a8 requires azp to be folded into azp_adj
|
|
||||||
# in the per-tensor case
|
|
||||||
azp_adj = layer.input_zero_point * azp_adj
|
|
||||||
layer.azp_adj = Parameter(azp_adj, requires_grad=False)
|
|
||||||
else:
|
|
||||||
layer.azp_adj = None
|
|
||||||
|
|
||||||
def create_weights(
|
|
||||||
self,
|
|
||||||
layer: torch.nn.Module,
|
|
||||||
output_partition_sizes: list[int],
|
|
||||||
input_size_per_partition: int,
|
|
||||||
params_dtype: torch.dtype,
|
|
||||||
weight_loader: Callable,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
output_size_per_partition = sum(output_partition_sizes)
|
|
||||||
layer.logical_widths = output_partition_sizes
|
|
||||||
|
|
||||||
# WEIGHT
|
|
||||||
weight = ModelWeightParameter(
|
|
||||||
data=torch.empty(
|
|
||||||
output_size_per_partition, input_size_per_partition, dtype=torch.int8
|
|
||||||
),
|
|
||||||
input_dim=1,
|
|
||||||
output_dim=0,
|
|
||||||
weight_loader=weight_loader,
|
|
||||||
)
|
|
||||||
|
|
||||||
layer.register_parameter("weight", weight)
|
|
||||||
|
|
||||||
# WEIGHT SCALE
|
|
||||||
if self.strategy == QuantizationStrategy.CHANNEL:
|
|
||||||
weight_scale = ChannelQuantScaleParameter(
|
|
||||||
data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32),
|
|
||||||
output_dim=0,
|
|
||||||
weight_loader=weight_loader,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
assert self.strategy == QuantizationStrategy.TENSOR
|
|
||||||
weight_scale = PerTensorScaleParameter(
|
|
||||||
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
|
|
||||||
weight_loader=weight_loader,
|
|
||||||
)
|
|
||||||
layer.register_parameter("weight_scale", weight_scale)
|
|
||||||
|
|
||||||
# INPUT SCALE
|
|
||||||
if self.is_static_input_scheme:
|
|
||||||
input_scale = PerTensorScaleParameter(
|
|
||||||
data=torch.empty(1, dtype=torch.float32), weight_loader=weight_loader
|
|
||||||
)
|
|
||||||
layer.register_parameter("input_scale", input_scale)
|
|
||||||
|
|
||||||
if not self.input_symmetric:
|
|
||||||
# Note: compressed-tensors stores the zp using the same dtype
|
|
||||||
# as the weights
|
|
||||||
# AZP loaded as int8 but used as int32
|
|
||||||
input_zero_point = PerTensorScaleParameter(
|
|
||||||
data=torch.empty(1, dtype=torch.int8), weight_loader=weight_loader
|
|
||||||
)
|
|
||||||
layer.register_parameter("input_zero_point", input_zero_point)
|
|
||||||
|
|
||||||
def apply_weights(
|
|
||||||
self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor]
|
|
||||||
) -> torch.Tensor:
|
|
||||||
# TODO: add cutlass_scaled_mm_azp support
|
|
||||||
x_q, x_scale = per_token_quant_int8(x)
|
|
||||||
|
|
||||||
return int8_scaled_mm(
|
|
||||||
x_q, layer.weight, x_scale, layer.weight_scale, out_dtype=x.dtype, bias=bias
|
|
||||||
)
|
|
||||||
@@ -1,5 +1,7 @@
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
from sglang.srt.utils import get_bool_env_var, get_device_sm, is_blackwell
|
from sglang.srt.utils import get_bool_env_var, get_device_sm, is_blackwell
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -13,6 +15,7 @@ def _compute_enable_deep_gemm():
|
|||||||
try:
|
try:
|
||||||
import deep_gemm
|
import deep_gemm
|
||||||
except ImportError:
|
except ImportError:
|
||||||
|
logger.warning("Failed to import deep_gemm, disable ENABLE_JIT_DEEPGEMM.")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="true")
|
return get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="true")
|
||||||
|
|||||||
@@ -843,18 +843,10 @@ class Mxfp4DynamicQuantMoEMethod(FusedMoEMethodBase):
|
|||||||
topk_weights = topk_weights.to(
|
topk_weights = topk_weights.to(
|
||||||
torch.float32
|
torch.float32
|
||||||
) # aiter's moe_sorting requires topk_weights to be FP32
|
) # aiter's moe_sorting requires topk_weights to be FP32
|
||||||
|
|
||||||
if hasattr(torch, "float4_e2m1fn_x2"):
|
|
||||||
w13_weight = layer.w13_weight.view(torch.float4_e2m1fn_x2)
|
|
||||||
w2_weight = layer.w2_weight.view(torch.float4_e2m1fn_x2)
|
|
||||||
else:
|
|
||||||
w13_weight = layer.w13_weight
|
|
||||||
w2_weight = layer.w2_weight
|
|
||||||
|
|
||||||
output = fused_moe(
|
output = fused_moe(
|
||||||
x,
|
x,
|
||||||
w13_weight,
|
layer.w13_weight,
|
||||||
w2_weight,
|
layer.w2_weight,
|
||||||
topk_weights,
|
topk_weights,
|
||||||
topk_ids,
|
topk_ids,
|
||||||
quant_type=QuantType.per_1x32,
|
quant_type=QuantType.per_1x32,
|
||||||
|
|||||||
@@ -183,17 +183,10 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
|
|||||||
moe_runner_config = self.moe_runner_config
|
moe_runner_config = self.moe_runner_config
|
||||||
topk_weights, topk_ids, _ = topk_output
|
topk_weights, topk_ids, _ = topk_output
|
||||||
|
|
||||||
if hasattr(torch, "float4_e2m1fn_x2"):
|
|
||||||
w13_weight = layer.w13_weight.view(torch.float4_e2m1fn_x2)
|
|
||||||
w2_weight = layer.w2_weight.view(torch.float4_e2m1fn_x2)
|
|
||||||
else:
|
|
||||||
w13_weight = layer.w13_weight
|
|
||||||
w2_weight = layer.w2_weight
|
|
||||||
|
|
||||||
output = fused_moe(
|
output = fused_moe(
|
||||||
x,
|
x,
|
||||||
w13_weight,
|
layer.w13_weight,
|
||||||
w2_weight,
|
layer.w2_weight,
|
||||||
topk_weights,
|
topk_weights,
|
||||||
topk_ids,
|
topk_ids,
|
||||||
quant_type=QuantType.per_1x32,
|
quant_type=QuantType.per_1x32,
|
||||||
|
|||||||
@@ -19,6 +19,10 @@ from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
|
|||||||
from sglang.srt.layers.quantization.utils import is_layer_skipped
|
from sglang.srt.layers.quantization.utils import is_layer_skipped
|
||||||
from sglang.srt.utils import is_npu, set_weight_attrs
|
from sglang.srt.utils import is_npu, set_weight_attrs
|
||||||
|
|
||||||
|
_is_npu = is_npu()
|
||||||
|
if not _is_npu:
|
||||||
|
from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.layers.moe import MoeRunnerConfig
|
from sglang.srt.layers.moe import MoeRunnerConfig
|
||||||
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
|
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
|
||||||
|
|||||||
@@ -393,23 +393,13 @@ class W8A8Int8LinearMethod(LinearMethodBase):
|
|||||||
x.dtype,
|
x.dtype,
|
||||||
True, # is_vnni
|
True, # is_vnni
|
||||||
)
|
)
|
||||||
|
|
||||||
x_q, x_scale = per_token_quant_int8(x)
|
x_q, x_scale = per_token_quant_int8(x)
|
||||||
|
|
||||||
x_q_2d = x_q.view(-1, x_q.shape[-1])
|
return int8_scaled_mm(
|
||||||
x_scale_2d = x_scale.view(-1, x_scale.shape[-1])
|
x_q, layer.weight, x_scale, layer.weight_scale, out_dtype=x.dtype, bias=bias
|
||||||
output_shape = [*x_q.shape[:-1], layer.weight.shape[1]]
|
|
||||||
|
|
||||||
output = int8_scaled_mm(
|
|
||||||
x_q_2d,
|
|
||||||
layer.weight,
|
|
||||||
x_scale_2d,
|
|
||||||
layer.weight_scale,
|
|
||||||
out_dtype=x.dtype,
|
|
||||||
bias=bias,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return output.view(output_shape)
|
|
||||||
|
|
||||||
|
|
||||||
class W8A8Int8MoEMethod(FusedMoEMethodBase):
|
class W8A8Int8MoEMethod(FusedMoEMethodBase):
|
||||||
"""MoE method for INT8.
|
"""MoE method for INT8.
|
||||||
@@ -648,7 +638,6 @@ class NPU_W8A8LinearMethodImpl:
|
|||||||
layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
|
layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
|
||||||
layer.weight_scale.data = torch.flatten(layer.weight_scale.data)
|
layer.weight_scale.data = torch.flatten(layer.weight_scale.data)
|
||||||
layer.weight_offset.data = torch.flatten(layer.weight_offset.data)
|
layer.weight_offset.data = torch.flatten(layer.weight_offset.data)
|
||||||
layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, 29)
|
|
||||||
|
|
||||||
|
|
||||||
class NPU_W8A8LinearMethodMTImpl:
|
class NPU_W8A8LinearMethodMTImpl:
|
||||||
@@ -841,7 +830,6 @@ class NPU_W8A8DynamicLinearMethodImpl:
|
|||||||
layer.weight_scale.data = layer.weight_scale.data.flatten()
|
layer.weight_scale.data = layer.weight_scale.data.flatten()
|
||||||
layer.weight_scale_fp32 = layer.weight_scale.data.to(torch.float32)
|
layer.weight_scale_fp32 = layer.weight_scale.data.to(torch.float32)
|
||||||
layer.weight_offset.data = layer.weight_offset.data.flatten()
|
layer.weight_offset.data = layer.weight_offset.data.flatten()
|
||||||
layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, 29)
|
|
||||||
|
|
||||||
|
|
||||||
class NPU_W8A8DynamicLinearMethod(LinearMethodBase):
|
class NPU_W8A8DynamicLinearMethod(LinearMethodBase):
|
||||||
|
|||||||
@@ -12,7 +12,6 @@ from sglang.srt.custom_op import CustomOp
|
|||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
cpu_has_amx_support,
|
cpu_has_amx_support,
|
||||||
get_bool_env_var,
|
get_bool_env_var,
|
||||||
get_compiler_backend,
|
|
||||||
is_cpu,
|
is_cpu,
|
||||||
is_cuda,
|
is_cuda,
|
||||||
is_hip,
|
is_hip,
|
||||||
@@ -27,19 +26,13 @@ _is_cpu_amx_available = cpu_has_amx_support()
|
|||||||
_is_cpu = is_cpu()
|
_is_cpu = is_cpu()
|
||||||
|
|
||||||
if _is_cuda:
|
if _is_cuda:
|
||||||
from sgl_kernel import FusedSetKVBufferArg, apply_rope_with_cos_sin_cache_inplace
|
from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
|
||||||
else:
|
|
||||||
FusedSetKVBufferArg = None
|
|
||||||
|
|
||||||
if _use_aiter:
|
if _use_aiter:
|
||||||
from aiter.rotary_embedding import get_rope as aiter_get_rope
|
from aiter.rotary_embedding import get_rope as aiter_get_rope
|
||||||
|
|
||||||
if is_npu():
|
if is_npu():
|
||||||
import torch_npu
|
import torch_npu
|
||||||
|
|
||||||
NPU_ROTARY_MUL_MAX_NUM_HEADS = 1000
|
|
||||||
NPU_ROTARY_MUL_MAX_HEAD_SIZE = 896
|
|
||||||
|
|
||||||
|
|
||||||
def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
|
def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
|
||||||
x1 = x[..., : x.shape[-1] // 2]
|
x1 = x[..., : x.shape[-1] // 2]
|
||||||
@@ -149,13 +142,8 @@ class RotaryEmbedding(CustomOp):
|
|||||||
query: torch.Tensor,
|
query: torch.Tensor,
|
||||||
key: torch.Tensor,
|
key: torch.Tensor,
|
||||||
offsets: Optional[torch.Tensor] = None,
|
offsets: Optional[torch.Tensor] = None,
|
||||||
fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
|
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
"""A PyTorch-native implementation of forward()."""
|
"""A PyTorch-native implementation of forward()."""
|
||||||
assert (
|
|
||||||
fused_set_kv_buffer_arg is None
|
|
||||||
), "fused_set_kv_buffer_arg is not supported for native implementation"
|
|
||||||
|
|
||||||
if offsets is not None:
|
if offsets is not None:
|
||||||
positions = positions + offsets
|
positions = positions + offsets
|
||||||
positions = positions.flatten()
|
positions = positions.flatten()
|
||||||
@@ -184,17 +172,12 @@ class RotaryEmbedding(CustomOp):
|
|||||||
query: torch.Tensor,
|
query: torch.Tensor,
|
||||||
key: torch.Tensor,
|
key: torch.Tensor,
|
||||||
offsets: Optional[torch.Tensor] = None,
|
offsets: Optional[torch.Tensor] = None,
|
||||||
fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
|
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
"""A PyTorch-npu implementation of forward()."""
|
"""A PyTorch-npu implementation of forward()."""
|
||||||
assert (
|
import os
|
||||||
fused_set_kv_buffer_arg is None
|
|
||||||
), "fused_set_kv_buffer_arg is not supported for npu implementation"
|
|
||||||
|
|
||||||
if get_bool_env_var("SGLANG_ENABLE_TORCH_COMPILE"):
|
if get_bool_env_var("SGLANG_ENABLE_TORCH_COMPILE"):
|
||||||
return self.forward_native(
|
return self.forward_native(positions, query, key, offsets)
|
||||||
positions, query, key, offsets, fused_set_kv_buffer_arg
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
rotary_mode = "half"
|
rotary_mode = "half"
|
||||||
if self.is_neox_style:
|
if self.is_neox_style:
|
||||||
@@ -219,12 +202,7 @@ class RotaryEmbedding(CustomOp):
|
|||||||
query: torch.Tensor,
|
query: torch.Tensor,
|
||||||
key: torch.Tensor,
|
key: torch.Tensor,
|
||||||
offsets: Optional[torch.Tensor] = None,
|
offsets: Optional[torch.Tensor] = None,
|
||||||
fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
|
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
assert (
|
|
||||||
fused_set_kv_buffer_arg is None
|
|
||||||
), "fused_set_kv_buffer_arg is not supported for cpu implementation"
|
|
||||||
|
|
||||||
positions = torch.add(positions, offsets) if offsets is not None else positions
|
positions = torch.add(positions, offsets) if offsets is not None else positions
|
||||||
if _is_cpu_amx_available:
|
if _is_cpu_amx_available:
|
||||||
return torch.ops.sgl_kernel.rotary_embedding_cpu(
|
return torch.ops.sgl_kernel.rotary_embedding_cpu(
|
||||||
@@ -236,9 +214,7 @@ class RotaryEmbedding(CustomOp):
|
|||||||
self.is_neox_style,
|
self.is_neox_style,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return self.forward_native(
|
return self.forward_native(positions, query, key, offsets)
|
||||||
positions, query, key, offsets, fused_set_kv_buffer_arg
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward_cuda(
|
def forward_cuda(
|
||||||
self,
|
self,
|
||||||
@@ -246,7 +222,7 @@ class RotaryEmbedding(CustomOp):
|
|||||||
query: torch.Tensor,
|
query: torch.Tensor,
|
||||||
key: torch.Tensor,
|
key: torch.Tensor,
|
||||||
offsets: Optional[torch.Tensor] = None,
|
offsets: Optional[torch.Tensor] = None,
|
||||||
fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
|
fused_set_kv_buffer_arg=None, # Optional[FusedSetKVBufferArg]
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
if _is_cuda and (self.head_size in [64, 128, 256, 512]):
|
if _is_cuda and (self.head_size in [64, 128, 256, 512]):
|
||||||
apply_rope_with_cos_sin_cache_inplace(
|
apply_rope_with_cos_sin_cache_inplace(
|
||||||
@@ -789,7 +765,10 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
|
|||||||
|
|
||||||
rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj
|
rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj
|
||||||
query_rot = query_rot * cos + rotate_fn(query_rot) * sin
|
query_rot = query_rot * cos + rotate_fn(query_rot) * sin
|
||||||
key_rot = key_rot * cos + rotate_fn(key_rot) * sin
|
cos_for_key = cos[:, 0, ...]
|
||||||
|
sin_for_key = sin[:, 0, ...]
|
||||||
|
key_rot = key_rot * cos_for_key + rotate_fn(key_rot) * sin_for_key
|
||||||
|
#key_rot = key_rot * cos + rotate_fn(key_rot) * sin
|
||||||
|
|
||||||
if self.rotary_dim < self.head_size:
|
if self.rotary_dim < self.head_size:
|
||||||
query = torch.cat((query_rot, query_pass), dim=-1)
|
query = torch.cat((query_rot, query_pass), dim=-1)
|
||||||
@@ -1059,7 +1038,7 @@ class MRotaryEmbedding(RotaryEmbedding):
|
|||||||
f"Corrected mrope_section: {self.mrope_section} (sum={sum(self.mrope_section)})"
|
f"Corrected mrope_section: {self.mrope_section} (sum={sum(self.mrope_section)})"
|
||||||
)
|
)
|
||||||
|
|
||||||
@torch.compile(dynamic=True, backend=get_compiler_backend())
|
@torch.compile(dynamic=True)
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
@@ -1207,7 +1186,7 @@ class MRotaryEmbedding(RotaryEmbedding):
|
|||||||
|
|
||||||
time_tensor_long = time_tensor.long()
|
time_tensor_long = time_tensor.long()
|
||||||
t_index = time_tensor_long.flatten()
|
t_index = time_tensor_long.flatten()
|
||||||
elif model_type in ("qwen2_vl", "qwen3_vl", "qwen3_vl_moe"):
|
elif model_type == "qwen2_vl":
|
||||||
t_index = (
|
t_index = (
|
||||||
torch.arange(llm_grid_t)
|
torch.arange(llm_grid_t)
|
||||||
.view(-1, 1)
|
.view(-1, 1)
|
||||||
@@ -1918,30 +1897,17 @@ def apply_rotary_pos_emb_npu(
|
|||||||
sin: torch.Tensor,
|
sin: torch.Tensor,
|
||||||
unsqueeze_dim=1,
|
unsqueeze_dim=1,
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
"""Ascend implementation equivalent to apply_rotary_pos_emb_native.
|
if q.shape[1] != 128:
|
||||||
|
|
||||||
Args:
|
|
||||||
q: [num_tokens, num_heads, head_size]
|
|
||||||
k: [num_tokens, num_kv_heads, head_size]
|
|
||||||
cos: [num_tokens, head_size]
|
|
||||||
sin: [num_tokens, head_size]
|
|
||||||
"""
|
|
||||||
if (
|
|
||||||
cos.dim() != 2
|
|
||||||
or q.dim() != 3
|
|
||||||
or q.shape[1] >= NPU_ROTARY_MUL_MAX_NUM_HEADS
|
|
||||||
or q.shape[2] >= NPU_ROTARY_MUL_MAX_HEAD_SIZE
|
|
||||||
):
|
|
||||||
# Note: num_heads and head_size of q must be less than 1000 and 896, respectively
|
|
||||||
return apply_rotary_pos_emb_native(q, k, cos, sin, unsqueeze_dim)
|
return apply_rotary_pos_emb_native(q, k, cos, sin, unsqueeze_dim)
|
||||||
cos = cos.unsqueeze(unsqueeze_dim).unsqueeze(0)
|
cos = cos.unsqueeze(unsqueeze_dim)
|
||||||
sin = sin.unsqueeze(unsqueeze_dim).unsqueeze(0)
|
cos = torch.transpose(cos, 1, 2)
|
||||||
q = q.unsqueeze(0)
|
sin = sin.unsqueeze(unsqueeze_dim)
|
||||||
k = k.unsqueeze(0)
|
sin = torch.transpose(sin, 1, 2)
|
||||||
q_embed = torch_npu.npu_rotary_mul(q, cos, sin)
|
q = torch.transpose(q, 1, 2)
|
||||||
k_embed = torch_npu.npu_rotary_mul(k, cos, sin)
|
k = torch.transpose(k, 1, 2)
|
||||||
q_embed = q_embed.squeeze(0)
|
q_embed, k_embed = torch_npu.npu_apply_rotary_pos_emb(q, k, cos, sin)
|
||||||
k_embed = k_embed.squeeze(0)
|
q_embed = torch.transpose(q_embed, 1, 2)
|
||||||
|
k_embed = torch.transpose(k_embed, 1, 2)
|
||||||
return q_embed, k_embed
|
return q_embed, k_embed
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -15,29 +15,6 @@ def get_layer_id(weight_name):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def pad_or_narrow_weight(
|
|
||||||
loaded_weight: torch.Tensor, input_dim: int, start_idx: int, shard_size: int
|
|
||||||
) -> torch.Tensor:
|
|
||||||
# Padding with zeros for special case such as qwen2_5_VL's mlp which is not 8-aligned
|
|
||||||
valid_size = max(loaded_weight.shape[input_dim] - start_idx, 0)
|
|
||||||
|
|
||||||
if valid_size > 0:
|
|
||||||
loaded_slice = loaded_weight.narrow(input_dim, start_idx, valid_size)
|
|
||||||
pad_shape = list(loaded_weight.shape)
|
|
||||||
pad_shape[input_dim] = shard_size - valid_size
|
|
||||||
pad = torch.zeros(
|
|
||||||
pad_shape, dtype=loaded_weight.dtype, device=loaded_weight.device
|
|
||||||
)
|
|
||||||
return torch.cat([loaded_slice, pad], dim=input_dim)
|
|
||||||
|
|
||||||
# All padding
|
|
||||||
pad_shape = list(loaded_weight.shape)
|
|
||||||
pad_shape[input_dim] = shard_size
|
|
||||||
return torch.zeros(
|
|
||||||
pad_shape, dtype=loaded_weight.dtype, device=loaded_weight.device
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class PPMissingLayer(torch.nn.Identity):
|
class PPMissingLayer(torch.nn.Identity):
|
||||||
# Adapted from
|
# Adapted from
|
||||||
# https://github.com/vllm-project/vllm/blob/18ed3132d2bfe1df9a74729457b69243955221e8/vllm/model_executor/models/utils.py#L468C1-L486C1
|
# https://github.com/vllm-project/vllm/blob/18ed3132d2bfe1df9a74729457b69243955221e8/vllm/model_executor/models/utils.py#L468C1-L486C1
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import triton
|
|||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
|
|
||||||
from sglang.srt.lora.utils import LoRABatchInfo
|
from sglang.srt.lora.utils import LoRABatchInfo
|
||||||
from sglang.srt.utils import cached_triton_kernel
|
from sglang.utils import cached_triton_kernel
|
||||||
|
|
||||||
|
|
||||||
@cached_triton_kernel(lambda _, kwargs: (kwargs["NUM_SLICES"], kwargs["BLOCK_M"]))
|
@cached_triton_kernel(lambda _, kwargs: (kwargs["NUM_SLICES"], kwargs["BLOCK_M"]))
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ import triton
|
|||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
|
|
||||||
from sglang.srt.lora.utils import LoRABatchInfo
|
from sglang.srt.lora.utils import LoRABatchInfo
|
||||||
from sglang.srt.utils import cached_triton_kernel
|
from sglang.utils import cached_triton_kernel
|
||||||
|
|
||||||
|
|
||||||
@cached_triton_kernel(lambda _, kwargs: (kwargs["NUM_SLICES"], kwargs["BLOCK_M"]))
|
@cached_triton_kernel(lambda _, kwargs: (kwargs["NUM_SLICES"], kwargs["BLOCK_M"]))
|
||||||
|
|||||||
@@ -275,17 +275,43 @@ class HiCacheController:
|
|||||||
and self.storage_config.tp_rank != 0
|
and self.storage_config.tp_rank != 0
|
||||||
)
|
)
|
||||||
|
|
||||||
# Use storage backend factory for dynamic backend creation
|
if storage_backend == "file":
|
||||||
from sglang.srt.mem_cache.storage import StorageBackendFactory
|
from sglang.srt.mem_cache.hicache_storage import HiCacheFile
|
||||||
|
|
||||||
try:
|
self.storage_backend = HiCacheFile(self.storage_config)
|
||||||
self.storage_backend = StorageBackendFactory.create_backend(
|
elif storage_backend == "nixl":
|
||||||
storage_backend, self.storage_config, self.mem_pool_host
|
from sglang.srt.mem_cache.storage.nixl.hicache_nixl import HiCacheNixl
|
||||||
|
|
||||||
|
self.storage_backend = HiCacheNixl()
|
||||||
|
elif storage_backend == "mooncake":
|
||||||
|
from sglang.srt.mem_cache.storage.mooncake_store.mooncake_store import (
|
||||||
|
MooncakeStore,
|
||||||
)
|
)
|
||||||
except ValueError as e:
|
|
||||||
raise ValueError(f"Failed to create storage backend: {e}") from e
|
|
||||||
|
|
||||||
self.storage_backend.register_mem_pool_host(self.mem_pool_host)
|
self.storage_backend = MooncakeStore(self.storage_config)
|
||||||
|
self.storage_backend.register_buffer(self.mem_pool_host.kv_buffer)
|
||||||
|
assert self.mem_pool_host.layout == "page_first"
|
||||||
|
elif storage_backend == "hf3fs":
|
||||||
|
from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import (
|
||||||
|
HiCacheHF3FS,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.mem_pool_host.layout == "page_first":
|
||||||
|
bytes_per_page = (
|
||||||
|
mem_pool_host.get_ksize_per_token() * mem_pool_host.page_size
|
||||||
|
)
|
||||||
|
elif self.mem_pool_host.layout == "layer_first":
|
||||||
|
bytes_per_page = (
|
||||||
|
mem_pool_host.get_size_per_token() * mem_pool_host.page_size
|
||||||
|
)
|
||||||
|
dtype = mem_pool_host.dtype
|
||||||
|
self.storage_backend = HiCacheHF3FS.from_env_config(
|
||||||
|
bytes_per_page, dtype, self.storage_config
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"Unsupported storage backend: {storage_backend}"
|
||||||
|
)
|
||||||
|
|
||||||
self.enable_storage = True
|
self.enable_storage = True
|
||||||
# todo: threshold policy for prefetching
|
# todo: threshold policy for prefetching
|
||||||
@@ -309,10 +335,18 @@ class HiCacheController:
|
|||||||
# Select the get and set functions
|
# Select the get and set functions
|
||||||
self.page_get_func = self._generic_page_get
|
self.page_get_func = self._generic_page_get
|
||||||
self.page_set_func = self._generic_page_set
|
self.page_set_func = self._generic_page_set
|
||||||
|
self.batch_exists_func = self.storage_backend.batch_exists
|
||||||
if self.storage_backend_type in ["hf3fs", "mooncake"]:
|
self.is_3fs_zerocopy = (
|
||||||
self.page_get_func = self._page_get_zero_copy
|
self.storage_backend_type == "hf3fs"
|
||||||
self.page_set_func = self._page_set_zero_copy
|
and self.mem_pool_host.layout == "page_first"
|
||||||
|
)
|
||||||
|
if self.storage_backend_type == "mooncake":
|
||||||
|
self.page_get_func = self._mooncake_page_get
|
||||||
|
self.page_set_func = self._mooncake_page_set
|
||||||
|
elif self.is_3fs_zerocopy:
|
||||||
|
self.page_get_func = self._3fs_zero_copy_page_get
|
||||||
|
self.page_set_func = self._3fs_zero_copy_page_set
|
||||||
|
self.batch_exists_func = self._3fs_zero_copy_batch_exists
|
||||||
|
|
||||||
self.device = self.mem_pool_device.device
|
self.device = self.mem_pool_device.device
|
||||||
self.layer_num = self.mem_pool_device.layer_num
|
self.layer_num = self.mem_pool_device.layer_num
|
||||||
@@ -436,6 +470,7 @@ class HiCacheController:
|
|||||||
host_indices = self.mem_pool_host.alloc(len(device_indices))
|
host_indices = self.mem_pool_host.alloc(len(device_indices))
|
||||||
if host_indices is None:
|
if host_indices is None:
|
||||||
return None
|
return None
|
||||||
|
self.mem_pool_host.protect_write(host_indices)
|
||||||
self.write_queue.append(
|
self.write_queue.append(
|
||||||
CacheOperation(host_indices, device_indices, node_id, priority)
|
CacheOperation(host_indices, device_indices, node_id, priority)
|
||||||
)
|
)
|
||||||
@@ -459,6 +494,7 @@ class HiCacheController:
|
|||||||
self.mem_pool_host.backup_from_device_all_layer(
|
self.mem_pool_host.backup_from_device_all_layer(
|
||||||
self.mem_pool_device, host_indices, device_indices, self.io_backend
|
self.mem_pool_device, host_indices, device_indices, self.io_backend
|
||||||
)
|
)
|
||||||
|
self.mem_pool_host.complete_io(op.host_indices)
|
||||||
finish_event.record()
|
finish_event.record()
|
||||||
# NOTE: We must save the host indices and device indices here,
|
# NOTE: We must save the host indices and device indices here,
|
||||||
# this is because we need to guarantee that these tensors are
|
# this is because we need to guarantee that these tensors are
|
||||||
@@ -482,6 +518,7 @@ class HiCacheController:
|
|||||||
device_indices = self.mem_pool_device_allocator.alloc(len(host_indices))
|
device_indices = self.mem_pool_device_allocator.alloc(len(host_indices))
|
||||||
if device_indices is None:
|
if device_indices is None:
|
||||||
return None
|
return None
|
||||||
|
self.mem_pool_host.protect_load(host_indices)
|
||||||
self.load_queue.append(
|
self.load_queue.append(
|
||||||
CacheOperation(host_indices, device_indices, node_id, priority)
|
CacheOperation(host_indices, device_indices, node_id, priority)
|
||||||
)
|
)
|
||||||
@@ -526,6 +563,7 @@ class HiCacheController:
|
|||||||
self.io_backend,
|
self.io_backend,
|
||||||
)
|
)
|
||||||
producer_event.complete(i)
|
producer_event.complete(i)
|
||||||
|
self.mem_pool_host.complete_io(op.host_indices)
|
||||||
# NOTE: We must save the host indices and device indices here,
|
# NOTE: We must save the host indices and device indices here,
|
||||||
# this is because we need to guarantee that these tensors are
|
# this is because we need to guarantee that these tensors are
|
||||||
# still alive when the load stream is executing.
|
# still alive when the load stream is executing.
|
||||||
@@ -543,16 +581,29 @@ class HiCacheController:
|
|||||||
)
|
)
|
||||||
return producer_id
|
return producer_id
|
||||||
|
|
||||||
def evict_device(self, device_indices: torch.Tensor) -> int:
|
def evict_device(
|
||||||
self.mem_pool_device_allocator.free(device_indices)
|
self, device_indices: torch.Tensor, host_indices: torch.Tensor
|
||||||
return len(device_indices)
|
) -> int:
|
||||||
|
if self.mem_pool_host.is_synced(host_indices):
|
||||||
|
self.mem_pool_device_allocator.free(device_indices)
|
||||||
|
self.mem_pool_host.update_backup(host_indices)
|
||||||
|
return len(device_indices)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Inconsistent states: {self.mem_pool_host.get_state(host_indices)}"
|
||||||
|
)
|
||||||
|
|
||||||
def evict_host(self, host_indices: torch.Tensor, backup_only: bool = True) -> int:
|
def evict_host(self, host_indices: torch.Tensor, backup_only: bool = True) -> int:
|
||||||
if not backup_only:
|
if not backup_only:
|
||||||
raise ValueError("Other eviction policies are not supported yet.")
|
raise ValueError("Other eviction policies are not supported yet.")
|
||||||
|
|
||||||
self.mem_pool_host.free(host_indices)
|
if self.mem_pool_host.is_backup(host_indices):
|
||||||
return len(host_indices)
|
self.mem_pool_host.free(host_indices)
|
||||||
|
return len(host_indices)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Inconsistent states: {self.mem_pool_host.get_state(host_indices)}"
|
||||||
|
)
|
||||||
|
|
||||||
def prefetch(
|
def prefetch(
|
||||||
self,
|
self,
|
||||||
@@ -579,19 +630,42 @@ class HiCacheController:
|
|||||||
for chunk in chunks:
|
for chunk in chunks:
|
||||||
self.host_mem_release_queue.put(chunk)
|
self.host_mem_release_queue.put(chunk)
|
||||||
|
|
||||||
def _page_get_zero_copy(self, operation, hash_values, host_indices):
|
def _3fs_zero_copy_batch_exists(self, batch_hashes):
|
||||||
results = self.storage_backend.batch_get_v1(hash_values, host_indices)
|
_batch_hashes, _, factor = self.mem_pool_host.get_buffer_with_hash(batch_hashes)
|
||||||
inc = 0
|
hit_page_num = self.storage_backend.batch_exists(_batch_hashes) // factor
|
||||||
for i in range(len(hash_values)):
|
return hit_page_num
|
||||||
if not results[i]:
|
|
||||||
logger.warning(
|
def _3fs_zero_copy_page_get(self, operation, hash_values, host_indices):
|
||||||
f"Prefetch operation {operation.request_id} failed to retrieve page {hash_values[i]}."
|
hashes, dsts, factor = self.mem_pool_host.get_buffer_with_hash(
|
||||||
)
|
hash_values, host_indices
|
||||||
break
|
)
|
||||||
inc += self.page_size
|
page_data = self.storage_backend.batch_get(hashes, dsts)
|
||||||
operation.increment(inc)
|
if page_data:
|
||||||
|
inc = self.page_size * len(hashes) // factor
|
||||||
|
operation.increment(inc)
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
f"Prefetch operation {operation.request_id} failed to retrieve page {hashes}."
|
||||||
|
)
|
||||||
|
|
||||||
|
def _mooncake_page_get(self, operation, hash_values, host_indices):
|
||||||
|
key_strs, buffer_ptrs, buffer_sizes = self.mem_pool_host.get_buffer_meta(
|
||||||
|
hash_values,
|
||||||
|
host_indices,
|
||||||
|
self.storage_config.tp_rank,
|
||||||
|
)
|
||||||
|
get_result = self.storage_backend.batch_get(
|
||||||
|
key_strs,
|
||||||
|
target_locations=buffer_ptrs,
|
||||||
|
target_sizes=buffer_sizes,
|
||||||
|
)
|
||||||
|
if get_result != len(hash_values):
|
||||||
|
logger.warning(
|
||||||
|
f"Prefetch operation {operation.request_id} failed or partially failed."
|
||||||
|
)
|
||||||
|
if get_result != 0:
|
||||||
|
operation.increment(get_result * self.page_size)
|
||||||
|
|
||||||
# todo: deprecate
|
|
||||||
def _generic_page_get(self, operation, hash_values, host_indices):
|
def _generic_page_get(self, operation, hash_values, host_indices):
|
||||||
dummy_page_dst = [
|
dummy_page_dst = [
|
||||||
self.mem_pool_host.get_dummy_flat_data_page() for _ in hash_values
|
self.mem_pool_host.get_dummy_flat_data_page() for _ in hash_values
|
||||||
@@ -681,7 +755,7 @@ class HiCacheController:
|
|||||||
batch_tokens[i : i + self.page_size], last_hash
|
batch_tokens[i : i + self.page_size], last_hash
|
||||||
)
|
)
|
||||||
batch_hashes.append(last_hash)
|
batch_hashes.append(last_hash)
|
||||||
hit_page_num = self.storage_backend.batch_exists(batch_hashes)
|
hit_page_num = self.batch_exists_func(batch_hashes)
|
||||||
hash_value.extend(batch_hashes[:hit_page_num])
|
hash_value.extend(batch_hashes[:hit_page_num])
|
||||||
storage_query_count += hit_page_num * self.page_size
|
storage_query_count += hit_page_num * self.page_size
|
||||||
if hit_page_num < len(batch_hashes):
|
if hit_page_num < len(batch_hashes):
|
||||||
@@ -750,16 +824,34 @@ class HiCacheController:
|
|||||||
self.backup_queue.put(operation)
|
self.backup_queue.put(operation)
|
||||||
return operation.id
|
return operation.id
|
||||||
|
|
||||||
# todo: deprecate
|
# non-zero copy
|
||||||
def _generic_page_set(self, hash_values, host_indices) -> bool:
|
def _generic_page_set(self, hash_values, host_indices) -> bool:
|
||||||
data = [
|
data = [
|
||||||
self.mem_pool_host.get_data_page(host_indices[i * self.page_size])
|
self.mem_pool_host.get_flat_data_page(host_indices[i * self.page_size])
|
||||||
for i in range(len(hash_values))
|
for i in range(len(hash_values))
|
||||||
]
|
]
|
||||||
return self.storage_backend.batch_set(hash_values, data)
|
return self.storage_backend.batch_set(hash_values, data)
|
||||||
|
|
||||||
def _page_set_zero_copy(self, hash_values, host_indices) -> bool:
|
# zero copy
|
||||||
return all(self.storage_backend.batch_set_v1(hash_values, host_indices))
|
def _mooncake_page_set(self, hash_values, host_indices) -> bool:
|
||||||
|
key_strs, buffer_ptrs, buffer_sizes = self.mem_pool_host.get_buffer_meta(
|
||||||
|
hash_values,
|
||||||
|
host_indices,
|
||||||
|
self.storage_config.tp_rank,
|
||||||
|
)
|
||||||
|
success = self.storage_backend.batch_set(
|
||||||
|
key_strs,
|
||||||
|
target_locations=buffer_ptrs,
|
||||||
|
target_sizes=buffer_sizes,
|
||||||
|
)
|
||||||
|
return success
|
||||||
|
|
||||||
|
# zero copy
|
||||||
|
def _3fs_zero_copy_page_set(self, hash_values, host_indices) -> bool:
|
||||||
|
hashes, dsts, _ = self.mem_pool_host.get_buffer_with_hash(
|
||||||
|
hash_values, host_indices
|
||||||
|
)
|
||||||
|
return self.storage_backend.batch_set(hashes, dsts)
|
||||||
|
|
||||||
# Backup batch by batch
|
# Backup batch by batch
|
||||||
def _page_backup(self, operation):
|
def _page_backup(self, operation):
|
||||||
|
|||||||
@@ -35,7 +35,6 @@ else:
|
|||||||
Image = Any
|
Image = Any
|
||||||
|
|
||||||
|
|
||||||
# Parameters for a session
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class SessionParams:
|
class SessionParams:
|
||||||
id: Optional[str] = None
|
id: Optional[str] = None
|
||||||
@@ -133,23 +132,18 @@ class GenerateReqInput:
|
|||||||
# Conversation id used for tracking requests
|
# Conversation id used for tracking requests
|
||||||
conversation_id: Optional[str] = None
|
conversation_id: Optional[str] = None
|
||||||
|
|
||||||
|
# Label for the request
|
||||||
|
label: Optional[str] = None
|
||||||
|
|
||||||
# Priority for the request
|
# Priority for the request
|
||||||
priority: Optional[int] = None
|
priority: Optional[int] = None
|
||||||
|
|
||||||
# Extra key for classifying the request (e.g. cache_salt)
|
# Image gen grpc migration
|
||||||
extra_key: Optional[Union[List[str], str]] = None
|
|
||||||
|
|
||||||
# Whether to disallow logging for this request (e.g. due to ZDR)
|
|
||||||
no_logs: bool = False
|
|
||||||
|
|
||||||
# For custom metric labels
|
|
||||||
custom_labels: Optional[Dict[str, str]] = None
|
|
||||||
|
|
||||||
# (Deprecated, please use custom_labels) Label for the request
|
|
||||||
label: Optional[str] = None
|
|
||||||
# (Internal) Whether to return bytes for image generation
|
|
||||||
return_bytes: bool = False
|
return_bytes: bool = False
|
||||||
|
|
||||||
|
# For customer metric labels
|
||||||
|
customer_labels: Optional[Dict[str, str]] = None
|
||||||
|
|
||||||
def contains_mm_input(self) -> bool:
|
def contains_mm_input(self) -> bool:
|
||||||
return (
|
return (
|
||||||
has_valid_data(self.image_data)
|
has_valid_data(self.image_data)
|
||||||
@@ -548,11 +542,8 @@ class GenerateReqInput:
|
|||||||
self.data_parallel_rank if self.data_parallel_rank is not None else None
|
self.data_parallel_rank if self.data_parallel_rank is not None else None
|
||||||
),
|
),
|
||||||
conversation_id=self.conversation_id,
|
conversation_id=self.conversation_id,
|
||||||
priority=self.priority,
|
|
||||||
extra_key=self.extra_key,
|
|
||||||
no_logs=self.no_logs,
|
|
||||||
custom_labels=self.custom_labels,
|
|
||||||
label=self.label,
|
label=self.label,
|
||||||
|
priority=self.priority,
|
||||||
return_bytes=self.return_bytes,
|
return_bytes=self.return_bytes,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -609,23 +600,18 @@ class TokenizedGenerateReqInput:
|
|||||||
# For dp balance
|
# For dp balance
|
||||||
dp_balance_id: int = -1
|
dp_balance_id: int = -1
|
||||||
|
|
||||||
|
# Label for the request
|
||||||
|
label: Optional[str] = None
|
||||||
|
|
||||||
# Priority for the request
|
# Priority for the request
|
||||||
priority: Optional[int] = None
|
priority: Optional[int] = None
|
||||||
|
|
||||||
# Extra key for classifying the request (e.g. cache_salt)
|
# Image gen grpc migration
|
||||||
extra_key: Optional[str] = None
|
return_bytes: bool = False
|
||||||
|
|
||||||
# Whether to disallow logging for this request (e.g. due to ZDR)
|
|
||||||
no_logs: bool = False
|
|
||||||
|
|
||||||
# tracing context
|
# tracing context
|
||||||
trace_context: Optional[Dict] = None
|
trace_context: Optional[Dict] = None
|
||||||
|
|
||||||
# (Deprecated, please use custom_labels) Label for the request
|
|
||||||
label: Optional[str] = None
|
|
||||||
# (Internal) Whether to return bytes for image generation
|
|
||||||
return_bytes: bool = False
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class BatchTokenizedGenerateReqInput:
|
class BatchTokenizedGenerateReqInput:
|
||||||
|
|||||||
@@ -507,7 +507,6 @@ def embed_mm_inputs(
|
|||||||
Modality, Callable[[List[MultimodalDataItem]], torch.Tensor]
|
Modality, Callable[[List[MultimodalDataItem]], torch.Tensor]
|
||||||
] = None,
|
] = None,
|
||||||
placeholder_tokens: dict[Modality, List[int]] = None,
|
placeholder_tokens: dict[Modality, List[int]] = None,
|
||||||
use_deepstack: bool = False,
|
|
||||||
) -> Optional[torch.Tensor]:
|
) -> Optional[torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
Embed multimodal inputs and integrate them with text token embeddings.
|
Embed multimodal inputs and integrate them with text token embeddings.
|
||||||
@@ -523,7 +522,7 @@ def embed_mm_inputs(
|
|||||||
Returns:
|
Returns:
|
||||||
Combined embedding tensor with multimodal content integrated
|
Combined embedding tensor with multimodal content integrated
|
||||||
"""
|
"""
|
||||||
other_info = {}
|
|
||||||
if mm_inputs_list is None:
|
if mm_inputs_list is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -533,7 +532,7 @@ def embed_mm_inputs(
|
|||||||
for mm_inputs in mm_inputs_list:
|
for mm_inputs in mm_inputs_list:
|
||||||
item_flatten_list += [item for item in mm_inputs.mm_items if item is not None]
|
item_flatten_list += [item for item in mm_inputs.mm_items if item is not None]
|
||||||
|
|
||||||
embeddings, masks, deepstack_embeddings = [], [], []
|
embeddings, masks = [], []
|
||||||
# 2. Get multimodal embedding separately
|
# 2. Get multimodal embedding separately
|
||||||
# Try get mm embedding if any
|
# Try get mm embedding if any
|
||||||
for modality in Modality.all():
|
for modality in Modality.all():
|
||||||
@@ -579,12 +578,6 @@ def embed_mm_inputs(
|
|||||||
extend_length=extend_seq_lens,
|
extend_length=extend_seq_lens,
|
||||||
items_offset_list=items_offsets,
|
items_offset_list=items_offsets,
|
||||||
)
|
)
|
||||||
|
|
||||||
if use_deepstack and embedding is not None:
|
|
||||||
embedding, deepstack_embedding = (
|
|
||||||
multimodal_model.separate_deepstack_embeds(embedding)
|
|
||||||
)
|
|
||||||
deepstack_embeddings += [deepstack_embedding]
|
|
||||||
embeddings += [embedding]
|
embeddings += [embedding]
|
||||||
masks += [mask]
|
masks += [mask]
|
||||||
|
|
||||||
@@ -598,37 +591,13 @@ def embed_mm_inputs(
|
|||||||
inputs_embeds = input_embedding(input_ids)
|
inputs_embeds = input_embedding(input_ids)
|
||||||
|
|
||||||
# 4. scatter embeddings into input embedding
|
# 4. scatter embeddings into input embedding
|
||||||
|
for embedding, mask in zip(embeddings, masks):
|
||||||
# deepstack embedding
|
|
||||||
if use_deepstack:
|
|
||||||
num_deepstack_embeddings = (
|
|
||||||
len(multimodal_model.deepstack_visual_indexes) if use_deepstack else 0
|
|
||||||
)
|
|
||||||
deepstack_embedding_shape = inputs_embeds.shape[:-1] + (
|
|
||||||
inputs_embeds.shape[-1] * num_deepstack_embeddings,
|
|
||||||
)
|
|
||||||
|
|
||||||
input_deepstack_embeds = torch.zeros(
|
|
||||||
deepstack_embedding_shape,
|
|
||||||
device=inputs_embeds.device,
|
|
||||||
dtype=inputs_embeds.dtype,
|
|
||||||
)
|
|
||||||
|
|
||||||
other_info["input_deepstack_embeds"] = input_deepstack_embeds
|
|
||||||
|
|
||||||
for i, embedding, mask in zip(range(len(embeddings)), embeddings, masks):
|
|
||||||
if embedding is None or mask is None:
|
if embedding is None or mask is None:
|
||||||
continue
|
continue
|
||||||
# in-place update
|
# in-place update
|
||||||
indices = torch.where(mask.squeeze(dim=-1))[0]
|
indices = torch.where(mask.squeeze(dim=-1))[0]
|
||||||
inputs_embeds[indices] = embedding.to(inputs_embeds.device, inputs_embeds.dtype)
|
inputs_embeds[indices] = embedding.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||||
|
return inputs_embeds
|
||||||
if use_deepstack:
|
|
||||||
input_deepstack_embeds[indices] = deepstack_embeddings[i].to(
|
|
||||||
inputs_embeds.device, inputs_embeds.dtype
|
|
||||||
)
|
|
||||||
|
|
||||||
return inputs_embeds, other_info
|
|
||||||
|
|
||||||
|
|
||||||
def general_mm_embed_routine(
|
def general_mm_embed_routine(
|
||||||
@@ -640,7 +609,6 @@ def general_mm_embed_routine(
|
|||||||
Modality, Callable[[List[MultimodalDataItem]], torch.Tensor]
|
Modality, Callable[[List[MultimodalDataItem]], torch.Tensor]
|
||||||
] = None,
|
] = None,
|
||||||
placeholder_tokens: Optional[dict[Modality, List[int]]] = None,
|
placeholder_tokens: Optional[dict[Modality, List[int]]] = None,
|
||||||
use_deepstack: bool = False,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
@@ -652,7 +620,6 @@ def general_mm_embed_routine(
|
|||||||
language_model: Base language model to use
|
language_model: Base language model to use
|
||||||
data_embedding_funcs: A dictionary mapping from modality type to the corresponding embedding function.
|
data_embedding_funcs: A dictionary mapping from modality type to the corresponding embedding function.
|
||||||
placeholder_tokens: Token IDs for multimodal placeholders
|
placeholder_tokens: Token IDs for multimodal placeholders
|
||||||
use_deepstack: Whether to use deepstack embeddings
|
|
||||||
**kwargs: Additional arguments passed to language model
|
**kwargs: Additional arguments passed to language model
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -678,20 +645,16 @@ def general_mm_embed_routine(
|
|||||||
for i, seq_len in enumerate(forward_batch.extend_seq_lens_cpu)
|
for i, seq_len in enumerate(forward_batch.extend_seq_lens_cpu)
|
||||||
if forward_batch.mm_inputs[i] is not None
|
if forward_batch.mm_inputs[i] is not None
|
||||||
]
|
]
|
||||||
inputs_embeds, other_info = embed_mm_inputs(
|
inputs_embeds = embed_mm_inputs(
|
||||||
mm_inputs_list=mm_inputs_list,
|
mm_inputs_list=mm_inputs_list,
|
||||||
extend_prefix_lens=extend_prefix_lens,
|
extend_prefix_lens=extend_prefix_lens,
|
||||||
extend_seq_lens=extend_seq_lens,
|
extend_seq_lens=extend_seq_lens,
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
multimodal_model=multimodal_model,
|
|
||||||
input_embedding=embed_tokens,
|
input_embedding=embed_tokens,
|
||||||
|
multimodal_model=multimodal_model,
|
||||||
data_embedding_func_mapping=data_embedding_funcs,
|
data_embedding_func_mapping=data_embedding_funcs,
|
||||||
placeholder_tokens=placeholder_tokens,
|
placeholder_tokens=placeholder_tokens,
|
||||||
use_deepstack=use_deepstack,
|
|
||||||
)
|
)
|
||||||
# add for qwen3_vl deepstack
|
|
||||||
if use_deepstack:
|
|
||||||
kwargs["input_deepstack_embeds"] = other_info["input_deepstack_embeds"]
|
|
||||||
# once used, mm_inputs is useless, considering chunked-prefill is disabled for multimodal models
|
# once used, mm_inputs is useless, considering chunked-prefill is disabled for multimodal models
|
||||||
# just being defensive here
|
# just being defensive here
|
||||||
forward_batch.mm_inputs = None
|
forward_batch.mm_inputs = None
|
||||||
|
|||||||
@@ -12,7 +12,8 @@ logger = logging.getLogger(__name__)
|
|||||||
PROCESSOR_MAPPING = {}
|
PROCESSOR_MAPPING = {}
|
||||||
|
|
||||||
|
|
||||||
def import_processors(package_name: str):
|
def import_processors():
|
||||||
|
package_name = "sglang.srt.multimodal.processors"
|
||||||
package = importlib.import_module(package_name)
|
package = importlib.import_module(package_name)
|
||||||
for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
|
for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
|
||||||
if not ispkg:
|
if not ispkg:
|
||||||
|
|||||||
@@ -1,53 +0,0 @@
|
|||||||
import torch
|
|
||||||
|
|
||||||
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
|
|
||||||
from sglang.srt.utils import get_compiler_backend
|
|
||||||
|
|
||||||
|
|
||||||
@torch.compile(dynamic=True, backend=get_compiler_backend())
|
|
||||||
def _resolve_future_token_ids(input_ids, future_token_ids_map):
|
|
||||||
input_ids[:] = torch.where(
|
|
||||||
input_ids < 0,
|
|
||||||
future_token_ids_map[torch.clamp(-input_ids, min=0)],
|
|
||||||
input_ids,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class FutureMap:
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
max_running_requests: int,
|
|
||||||
device: torch.device,
|
|
||||||
):
|
|
||||||
self.future_ct = 0
|
|
||||||
# A factor of 3 is used to avoid collision in the circular buffer.
|
|
||||||
self.future_limit = max_running_requests * 3
|
|
||||||
# A factor of 5 is used to ensure the buffer is large enough.
|
|
||||||
self.future_buffer_len = max_running_requests * 5
|
|
||||||
self.device = device
|
|
||||||
|
|
||||||
self.token_ids_buf = torch.empty(
|
|
||||||
(self.future_buffer_len,), dtype=torch.int64, device=self.device
|
|
||||||
)
|
|
||||||
|
|
||||||
def update_ct(self, bs: int) -> int:
|
|
||||||
"""Update the circular buffer pointer and return the current pointer."""
|
|
||||||
cur_future_ct = self.future_ct
|
|
||||||
self.future_ct = (cur_future_ct + bs) % self.future_limit
|
|
||||||
return cur_future_ct
|
|
||||||
|
|
||||||
def resolve_future(self, model_worker_batch: ModelWorkerBatch):
|
|
||||||
input_ids = model_worker_batch.input_ids
|
|
||||||
_resolve_future_token_ids(input_ids, self.token_ids_buf)
|
|
||||||
|
|
||||||
def update_next_future(self, future_ct: int, bs: int):
|
|
||||||
return torch.arange(
|
|
||||||
-(future_ct + 1),
|
|
||||||
-(future_ct + 1 + bs),
|
|
||||||
-1,
|
|
||||||
dtype=torch.int64,
|
|
||||||
device=self.device,
|
|
||||||
)
|
|
||||||
|
|
||||||
def store_to_map(self, future_ct: int, bs: int, next_token_ids: torch.Tensor):
|
|
||||||
self.token_ids_buf[future_ct + 1 : future_ct + bs + 1] = next_token_ids
|
|
||||||
@@ -67,14 +67,14 @@ from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
|
|||||||
from sglang.srt.metrics.collector import SchedulerMetricsCollector, TimeStats
|
from sglang.srt.metrics.collector import SchedulerMetricsCollector, TimeStats
|
||||||
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
|
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
|
||||||
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
||||||
from sglang.srt.sampling.sampling_params import SamplingParams
|
from sglang.srt.sampling.sampling_params import DEFAULT_SAMPLING_SEED, SamplingParams
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
from sglang.srt.utils import flatten_nested_list, support_triton
|
from sglang.srt.utils import flatten_nested_list, support_triton
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.configs.model_config import ModelConfig
|
from sglang.srt.configs.model_config import ModelConfig
|
||||||
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
||||||
from sglang.srt.speculative.ngram_utils import NgramVerifyInput
|
from sglang.srt.speculative.lookahead_utils import LookaheadVerifyInput
|
||||||
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
||||||
|
|
||||||
INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
|
INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
|
||||||
@@ -90,7 +90,6 @@ GLOBAL_SERVER_ARGS_KEYS = [
|
|||||||
"disable_flashinfer_cutlass_moe_fp4_allgather",
|
"disable_flashinfer_cutlass_moe_fp4_allgather",
|
||||||
"disable_radix_cache",
|
"disable_radix_cache",
|
||||||
"enable_dp_lm_head",
|
"enable_dp_lm_head",
|
||||||
"enable_fp32_lm_head",
|
|
||||||
"flashinfer_mxfp4_moe_precision",
|
"flashinfer_mxfp4_moe_precision",
|
||||||
"enable_flashinfer_allreduce_fusion",
|
"enable_flashinfer_allreduce_fusion",
|
||||||
"moe_dense_tp_size",
|
"moe_dense_tp_size",
|
||||||
@@ -113,6 +112,8 @@ GLOBAL_SERVER_ARGS_KEYS = [
|
|||||||
"enable_custom_logit_processor",
|
"enable_custom_logit_processor",
|
||||||
"disaggregation_mode",
|
"disaggregation_mode",
|
||||||
"enable_deterministic_inference",
|
"enable_deterministic_inference",
|
||||||
|
"nsa_prefill",
|
||||||
|
"nsa_decode",
|
||||||
]
|
]
|
||||||
|
|
||||||
# Put some global args for easy access
|
# Put some global args for easy access
|
||||||
@@ -492,7 +493,7 @@ class Req:
|
|||||||
self.custom_logit_processor = custom_logit_processor
|
self.custom_logit_processor = custom_logit_processor
|
||||||
self.return_hidden_states = return_hidden_states
|
self.return_hidden_states = return_hidden_states
|
||||||
|
|
||||||
# extra key for classifying the request (e.g. cache_salt)
|
# extra key for classifying the request (e.g. lora_id, cache_salt)
|
||||||
if lora_id is not None:
|
if lora_id is not None:
|
||||||
extra_key = (
|
extra_key = (
|
||||||
extra_key or ""
|
extra_key or ""
|
||||||
@@ -608,8 +609,6 @@ class Req:
|
|||||||
) = None
|
) = None
|
||||||
self.hidden_states: List[List[float]] = []
|
self.hidden_states: List[List[float]] = []
|
||||||
self.hidden_states_tensor = None # Note: use tensor instead of list to transfer hidden_states when PD + MTP
|
self.hidden_states_tensor = None # Note: use tensor instead of list to transfer hidden_states when PD + MTP
|
||||||
self.output_topk_p = None
|
|
||||||
self.output_topk_index = None
|
|
||||||
|
|
||||||
# Embedding (return values)
|
# Embedding (return values)
|
||||||
self.embedding = None
|
self.embedding = None
|
||||||
@@ -954,9 +953,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|||||||
|
|
||||||
# Speculative decoding
|
# Speculative decoding
|
||||||
spec_algorithm: SpeculativeAlgorithm = None
|
spec_algorithm: SpeculativeAlgorithm = None
|
||||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]] = (
|
spec_info: Optional[
|
||||||
None
|
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
|
||||||
)
|
] = None
|
||||||
|
|
||||||
# Whether to return hidden states
|
# Whether to return hidden states
|
||||||
return_hidden_states: bool = False
|
return_hidden_states: bool = False
|
||||||
@@ -1609,7 +1608,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|||||||
if (
|
if (
|
||||||
self.spec_algorithm.is_eagle()
|
self.spec_algorithm.is_eagle()
|
||||||
or self.spec_algorithm.is_standalone()
|
or self.spec_algorithm.is_standalone()
|
||||||
or self.spec_algorithm.is_ngram()
|
or self.spec_algorithm.is_lookahead()
|
||||||
):
|
):
|
||||||
# if spec decoding is used, the decode batch is prepared inside
|
# if spec decoding is used, the decode batch is prepared inside
|
||||||
# `forward_batch_speculative_generation` after running draft models.
|
# `forward_batch_speculative_generation` after running draft models.
|
||||||
@@ -1736,14 +1735,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|||||||
|
|
||||||
self.sampling_info.filter_batch(keep_indices, keep_indices_device)
|
self.sampling_info.filter_batch(keep_indices, keep_indices_device)
|
||||||
if self.spec_info:
|
if self.spec_info:
|
||||||
if chunked_req_to_exclude is not None and len(chunked_req_to_exclude) > 0:
|
self.spec_info.filter_batch(keep_indices_device)
|
||||||
has_been_filtered = False
|
|
||||||
else:
|
|
||||||
has_been_filtered = True
|
|
||||||
self.spec_info.filter_batch(
|
|
||||||
new_indices=keep_indices_device,
|
|
||||||
has_been_filtered=has_been_filtered,
|
|
||||||
)
|
|
||||||
|
|
||||||
def merge_batch(self, other: "ScheduleBatch"):
|
def merge_batch(self, other: "ScheduleBatch"):
|
||||||
# Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
|
# Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
|
||||||
@@ -1992,9 +1984,9 @@ class ModelWorkerBatch:
|
|||||||
|
|
||||||
# Speculative decoding
|
# Speculative decoding
|
||||||
spec_algorithm: SpeculativeAlgorithm = None
|
spec_algorithm: SpeculativeAlgorithm = None
|
||||||
spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput, NgramVerifyInput]] = (
|
spec_info: Optional[
|
||||||
None
|
Union[EagleVerifyInput, EagleDraftInput, LookaheadVerifyInput]
|
||||||
)
|
] = None
|
||||||
# If set, the output of the batch contains the hidden states of the run.
|
# If set, the output of the batch contains the hidden states of the run.
|
||||||
capture_hidden_mode: CaptureHiddenMode = None
|
capture_hidden_mode: CaptureHiddenMode = None
|
||||||
hicache_consumer_index: int = -1
|
hicache_consumer_index: int = -1
|
||||||
|
|||||||
@@ -318,6 +318,7 @@ class PrefillAdder:
|
|||||||
new_token_ratio: float,
|
new_token_ratio: float,
|
||||||
rem_input_tokens: int,
|
rem_input_tokens: int,
|
||||||
rem_chunk_tokens: Optional[int],
|
rem_chunk_tokens: Optional[int],
|
||||||
|
max_prefill_bs: Optional[int],
|
||||||
mixed_with_decode_tokens: int = 0,
|
mixed_with_decode_tokens: int = 0,
|
||||||
priority_scheduling_preemption_threshold: int = 0,
|
priority_scheduling_preemption_threshold: int = 0,
|
||||||
):
|
):
|
||||||
@@ -358,6 +359,10 @@ class PrefillAdder:
|
|||||||
priority_scheduling_preemption_threshold
|
priority_scheduling_preemption_threshold
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.max_prefill_bs = (
|
||||||
|
max_prefill_bs if max_prefill_bs is not None else 2147483647
|
||||||
|
)
|
||||||
|
|
||||||
def _get_running_request_total_token_offset(self, req: Req) -> int:
|
def _get_running_request_total_token_offset(self, req: Req) -> int:
|
||||||
return (
|
return (
|
||||||
min(
|
min(
|
||||||
@@ -549,6 +554,9 @@ class PrefillAdder:
|
|||||||
def add_one_req(
|
def add_one_req(
|
||||||
self, req: Req, has_chunked_req: bool, truncation_align_size: Optional[int]
|
self, req: Req, has_chunked_req: bool, truncation_align_size: Optional[int]
|
||||||
):
|
):
|
||||||
|
if len(self.can_run_list) >= self.max_prefill_bs:
|
||||||
|
return AddReqResult.OTHER
|
||||||
|
|
||||||
if req.sampling_params.ignore_eos and getattr(self.tree_cache, "disable", True):
|
if req.sampling_params.ignore_eos and getattr(self.tree_cache, "disable", True):
|
||||||
return self.add_one_req_ignore_eos(req, has_chunked_req)
|
return self.add_one_req_ignore_eos(req, has_chunked_req)
|
||||||
|
|
||||||
|
|||||||
@@ -44,9 +44,6 @@ from sglang.srt.disaggregation.decode import (
|
|||||||
DecodeTransferQueue,
|
DecodeTransferQueue,
|
||||||
SchedulerDisaggregationDecodeMixin,
|
SchedulerDisaggregationDecodeMixin,
|
||||||
)
|
)
|
||||||
from sglang.srt.disaggregation.decode_kvcache_offload_manager import (
|
|
||||||
DecodeKVCacheOffloadManager,
|
|
||||||
)
|
|
||||||
from sglang.srt.disaggregation.prefill import (
|
from sglang.srt.disaggregation.prefill import (
|
||||||
PrefillBootstrapQueue,
|
PrefillBootstrapQueue,
|
||||||
SchedulerDisaggregationPrefillMixin,
|
SchedulerDisaggregationPrefillMixin,
|
||||||
@@ -262,7 +259,7 @@ class Scheduler(
|
|||||||
self.enable_metrics_for_all_schedulers = (
|
self.enable_metrics_for_all_schedulers = (
|
||||||
server_args.enable_metrics_for_all_schedulers
|
server_args.enable_metrics_for_all_schedulers
|
||||||
)
|
)
|
||||||
self.enable_kv_cache_events = server_args.kv_events_config and tp_rank == 0
|
self.enable_kv_cache_events = server_args.kv_events_config is not None
|
||||||
self.stream_interval = server_args.stream_interval
|
self.stream_interval = server_args.stream_interval
|
||||||
self.spec_algorithm = SpeculativeAlgorithm.from_string(
|
self.spec_algorithm = SpeculativeAlgorithm.from_string(
|
||||||
server_args.speculative_algorithm
|
server_args.speculative_algorithm
|
||||||
@@ -388,10 +385,10 @@ class Scheduler(
|
|||||||
target_worker=self.tp_worker,
|
target_worker=self.tp_worker,
|
||||||
dp_rank=dp_rank,
|
dp_rank=dp_rank,
|
||||||
)
|
)
|
||||||
elif self.spec_algorithm.is_ngram():
|
elif self.spec_algorithm.is_lookahead():
|
||||||
from sglang.srt.speculative.ngram_worker import NGRAMWorker
|
from sglang.srt.speculative.lookahead_worker import LOOKAHEADWorker
|
||||||
|
|
||||||
self.draft_worker = NGRAMWorker(
|
self.draft_worker = LOOKAHEADWorker(
|
||||||
gpu_id=gpu_id,
|
gpu_id=gpu_id,
|
||||||
tp_rank=tp_rank,
|
tp_rank=tp_rank,
|
||||||
moe_ep_rank=moe_ep_rank,
|
moe_ep_rank=moe_ep_rank,
|
||||||
@@ -556,11 +553,9 @@ class Scheduler(
|
|||||||
|
|
||||||
# Init metrics stats
|
# Init metrics stats
|
||||||
self.init_metrics(tp_rank, pp_rank, dp_rank)
|
self.init_metrics(tp_rank, pp_rank, dp_rank)
|
||||||
|
self.init_kv_events(server_args.kv_events_config)
|
||||||
self.init_dp_balance(dp_balance_meta)
|
self.init_dp_balance(dp_balance_meta)
|
||||||
|
|
||||||
if self.enable_kv_cache_events:
|
|
||||||
self.init_kv_events(server_args.kv_events_config)
|
|
||||||
|
|
||||||
# Init disaggregation
|
# Init disaggregation
|
||||||
self.disaggregation_mode = DisaggregationMode(
|
self.disaggregation_mode = DisaggregationMode(
|
||||||
self.server_args.disaggregation_mode
|
self.server_args.disaggregation_mode
|
||||||
@@ -618,6 +613,8 @@ class Scheduler(
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.max_prefill_bs = server_args.max_prefill_bs
|
||||||
|
|
||||||
def init_deterministic_inference_config(self):
|
def init_deterministic_inference_config(self):
|
||||||
"""Initialize deterministic inference configuration for different attention backends."""
|
"""Initialize deterministic inference configuration for different attention backends."""
|
||||||
if not self.server_args.enable_deterministic_inference:
|
if not self.server_args.enable_deterministic_inference:
|
||||||
@@ -758,24 +755,6 @@ class Scheduler(
|
|||||||
eviction_policy=server_args.radix_eviction_policy,
|
eviction_policy=server_args.radix_eviction_policy,
|
||||||
)
|
)
|
||||||
|
|
||||||
if (
|
|
||||||
server_args.disaggregation_mode == "decode"
|
|
||||||
and server_args.disaggregation_decode_enable_offload_kvcache
|
|
||||||
):
|
|
||||||
self.decode_offload_manager = DecodeKVCacheOffloadManager(
|
|
||||||
req_to_token_pool=self.req_to_token_pool,
|
|
||||||
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
|
|
||||||
tp_group=(
|
|
||||||
self.attn_tp_cpu_group
|
|
||||||
if self.server_args.enable_dp_attention
|
|
||||||
else self.tp_cpu_group
|
|
||||||
),
|
|
||||||
tree_cache=self.tree_cache,
|
|
||||||
server_args=self.server_args,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.decode_offload_manager = None
|
|
||||||
|
|
||||||
self.decode_mem_cache_buf_multiplier = (
|
self.decode_mem_cache_buf_multiplier = (
|
||||||
1
|
1
|
||||||
if self.spec_algorithm.is_none()
|
if self.spec_algorithm.is_none()
|
||||||
@@ -806,7 +785,7 @@ class Scheduler(
|
|||||||
self.disagg_metadata_buffers = MetadataBuffers(
|
self.disagg_metadata_buffers = MetadataBuffers(
|
||||||
buffer_size,
|
buffer_size,
|
||||||
hidden_size=self.model_config.hf_text_config.hidden_size,
|
hidden_size=self.model_config.hf_text_config.hidden_size,
|
||||||
hidden_states_dtype=self.model_config.dtype,
|
dtype=self.model_config.dtype,
|
||||||
custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
|
custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -826,7 +805,7 @@ class Scheduler(
|
|||||||
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
|
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
|
||||||
draft_token_to_kv_pool=(
|
draft_token_to_kv_pool=(
|
||||||
None
|
None
|
||||||
if self.draft_worker is None or self.spec_algorithm.is_ngram()
|
if self.draft_worker is None or self.spec_algorithm.is_lookahead()
|
||||||
else self.draft_worker.model_runner.token_to_kv_pool
|
else self.draft_worker.model_runner.token_to_kv_pool
|
||||||
),
|
),
|
||||||
req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
|
req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
|
||||||
@@ -855,7 +834,7 @@ class Scheduler(
|
|||||||
self.disagg_metadata_buffers = MetadataBuffers(
|
self.disagg_metadata_buffers = MetadataBuffers(
|
||||||
buffer_size,
|
buffer_size,
|
||||||
hidden_size=self.model_config.hf_text_config.hidden_size,
|
hidden_size=self.model_config.hf_text_config.hidden_size,
|
||||||
hidden_states_dtype=self.model_config.dtype,
|
dtype=self.model_config.dtype,
|
||||||
custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
|
custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -863,7 +842,7 @@ class Scheduler(
|
|||||||
token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),
|
token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),
|
||||||
draft_token_to_kv_pool=(
|
draft_token_to_kv_pool=(
|
||||||
None
|
None
|
||||||
if self.draft_worker is None or self.spec_algorithm.is_ngram()
|
if self.draft_worker is None or self.spec_algorithm.is_lookahead()
|
||||||
else self.draft_worker.model_runner.token_to_kv_pool
|
else self.draft_worker.model_runner.token_to_kv_pool
|
||||||
),
|
),
|
||||||
req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
|
req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
|
||||||
@@ -1832,6 +1811,7 @@ class Scheduler(
|
|||||||
self.new_token_ratio,
|
self.new_token_ratio,
|
||||||
self.max_prefill_tokens,
|
self.max_prefill_tokens,
|
||||||
self.chunked_prefill_size,
|
self.chunked_prefill_size,
|
||||||
|
self.max_prefill_bs,
|
||||||
running_bs if self.is_mixed_chunk else 0,
|
running_bs if self.is_mixed_chunk else 0,
|
||||||
self.priority_scheduling_preemption_threshold,
|
self.priority_scheduling_preemption_threshold,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -250,13 +250,7 @@ class SchedulerOutputProcessorMixin:
|
|||||||
|
|
||||||
req.check_finished()
|
req.check_finished()
|
||||||
if req.finished():
|
if req.finished():
|
||||||
if self.server_args.disaggregation_decode_enable_offload_kvcache:
|
self.tree_cache.cache_finished_req(req)
|
||||||
# Asynchronously offload KV cache; cache_finished_req will be called after Device->Host transfer completes
|
|
||||||
if not self.decode_offload_manager.offload_kv_cache(req):
|
|
||||||
self.tree_cache.cache_finished_req(req)
|
|
||||||
else:
|
|
||||||
self.tree_cache.cache_finished_req(req)
|
|
||||||
|
|
||||||
req.time_stats.completion_time = time.time()
|
req.time_stats.completion_time = time.time()
|
||||||
|
|
||||||
if req.return_logprob and batch.spec_algorithm.is_none():
|
if req.return_logprob and batch.spec_algorithm.is_none():
|
||||||
|
|||||||
@@ -97,7 +97,7 @@ class SchedulerProfilerMixin:
|
|||||||
def start_profile(
|
def start_profile(
|
||||||
self, stage: Optional[ForwardMode] = None
|
self, stage: Optional[ForwardMode] = None
|
||||||
) -> ProfileReqOutput | None:
|
) -> ProfileReqOutput | None:
|
||||||
stage_str = f" for {stage.name}" if stage else ""
|
stage_str = f" for {stage.__str__()}" if stage else ""
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Profiling starts{stage_str}. Traces will be saved to: {self.torch_profiler_output_dir} (with profile id: {self.profile_id})",
|
f"Profiling starts{stage_str}. Traces will be saved to: {self.torch_profiler_output_dir} (with profile id: {self.profile_id})",
|
||||||
)
|
)
|
||||||
@@ -181,7 +181,7 @@ class SchedulerProfilerMixin:
|
|||||||
if not Path(self.torch_profiler_output_dir).exists():
|
if not Path(self.torch_profiler_output_dir).exists():
|
||||||
Path(self.torch_profiler_output_dir).mkdir(parents=True, exist_ok=True)
|
Path(self.torch_profiler_output_dir).mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
stage_suffix = f"-{stage.name}" if stage else ""
|
stage_suffix = f"-{stage.__str__()}" if stage else ""
|
||||||
logger.info("Stop profiling" + stage_suffix + "...")
|
logger.info("Stop profiling" + stage_suffix + "...")
|
||||||
if self.torch_profiler is not None:
|
if self.torch_profiler is not None:
|
||||||
self.torch_profiler.stop()
|
self.torch_profiler.stop()
|
||||||
@@ -247,7 +247,7 @@ class SchedulerProfilerMixin:
|
|||||||
if self.profiler_decode_ct == 0:
|
if self.profiler_decode_ct == 0:
|
||||||
if self.profile_in_progress:
|
if self.profile_in_progress:
|
||||||
# force trace flush
|
# force trace flush
|
||||||
self.stop_profile(stage=ForwardMode.EXTEND)
|
self.stop_profile(ForwardMode.EXTEND)
|
||||||
self.start_profile(batch.forward_mode)
|
self.start_profile(batch.forward_mode)
|
||||||
self.profiler_decode_ct += 1
|
self.profiler_decode_ct += 1
|
||||||
if self.profiler_decode_ct > self.profiler_target_decode_ct:
|
if self.profiler_decode_ct > self.profiler_target_decode_ct:
|
||||||
@@ -294,6 +294,6 @@ class SchedulerProfilerMixin:
|
|||||||
recv_req.profile_by_stage,
|
recv_req.profile_by_stage,
|
||||||
recv_req.profile_id,
|
recv_req.profile_id,
|
||||||
)
|
)
|
||||||
return self.start_profile()
|
return self.start_profile(True)
|
||||||
else:
|
else:
|
||||||
return self.stop_profile()
|
return self.stop_profile()
|
||||||
|
|||||||
@@ -185,7 +185,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if self.model_config.is_multimodal:
|
if self.model_config.is_multimodal:
|
||||||
import_processors("sglang.srt.multimodal.processors")
|
import_processors()
|
||||||
try:
|
try:
|
||||||
_processor = get_processor(
|
_processor = get_processor(
|
||||||
server_args.tokenizer_path,
|
server_args.tokenizer_path,
|
||||||
@@ -320,8 +320,8 @@ class TokenizerManager(TokenizerCommunicatorMixin):
|
|||||||
"model_name": self.server_args.served_model_name,
|
"model_name": self.server_args.served_model_name,
|
||||||
# TODO: Add lora name/path in the future,
|
# TODO: Add lora name/path in the future,
|
||||||
}
|
}
|
||||||
if server_args.tokenizer_metrics_allowed_custom_labels:
|
if server_args.tokenizer_metrics_allowed_customer_labels:
|
||||||
for label in server_args.tokenizer_metrics_allowed_custom_labels:
|
for label in server_args.tokenizer_metrics_allowed_customer_labels:
|
||||||
labels[label] = ""
|
labels[label] = ""
|
||||||
self.metrics_collector = TokenizerMetricsCollector(
|
self.metrics_collector = TokenizerMetricsCollector(
|
||||||
server_args=server_args,
|
server_args=server_args,
|
||||||
@@ -750,7 +750,6 @@ class TokenizerManager(TokenizerCommunicatorMixin):
|
|||||||
return_hidden_states=obj.return_hidden_states,
|
return_hidden_states=obj.return_hidden_states,
|
||||||
data_parallel_rank=obj.data_parallel_rank,
|
data_parallel_rank=obj.data_parallel_rank,
|
||||||
priority=obj.priority,
|
priority=obj.priority,
|
||||||
extra_key=obj.extra_key,
|
|
||||||
)
|
)
|
||||||
elif isinstance(obj, EmbeddingReqInput):
|
elif isinstance(obj, EmbeddingReqInput):
|
||||||
tokenized_obj = TokenizedEmbeddingReqInput(
|
tokenized_obj = TokenizedEmbeddingReqInput(
|
||||||
@@ -1633,10 +1632,10 @@ class TokenizerManager(TokenizerCommunicatorMixin):
|
|||||||
else 0
|
else 0
|
||||||
)
|
)
|
||||||
|
|
||||||
custom_labels = getattr(state.obj, "custom_labels", None)
|
customer_labels = getattr(state.obj, "customer_labels", None)
|
||||||
labels = (
|
labels = (
|
||||||
{**self.metrics_collector.labels, **custom_labels}
|
{**self.metrics_collector.labels, **customer_labels}
|
||||||
if custom_labels
|
if customer_labels
|
||||||
else self.metrics_collector.labels
|
else self.metrics_collector.labels
|
||||||
)
|
)
|
||||||
if (
|
if (
|
||||||
|
|||||||
@@ -91,6 +91,7 @@ class TpModelWorker:
|
|||||||
else server_args.speculative_draft_model_revision
|
else server_args.speculative_draft_model_revision
|
||||||
),
|
),
|
||||||
is_draft_model=is_draft_worker,
|
is_draft_model=is_draft_worker,
|
||||||
|
tp_rank=tp_rank,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.model_runner = ModelRunner(
|
self.model_runner = ModelRunner(
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user