Support penalty in overlap mode; return logprob with chunked prefill; improve benchmark scripts (#3988)

Co-authored-by: SangBin Cho <rkooo567@gmail.com>
Co-authored-by: dhou-xai <dhou@x.ai>
Co-authored-by: Hanming Lu <hanming_lu@berkeley.edu>
This commit is contained in:
Lianmin Zheng
2025-03-03 00:12:04 -08:00
parent 0194948fd9
commit ac2387279e
86 changed files with 4116 additions and 2015 deletions

View File

@@ -96,7 +96,10 @@ dev_cpu = ["sglang[all_cpu]", "sglang[test]"]
"Bug Tracker" = "https://github.com/sgl-project/sglang/issues"
[tool.setuptools.package-data]
"sglang" = ["srt/layers/moe/fused_moe_triton/configs/*.json", "srt/layers/quantization/configs/*.json"]
"sglang" = [
"srt/layers/moe/fused_moe_triton/configs/*.json",
"srt/layers/quantization/configs/*.json",
]
[tool.setuptools.packages.find]
exclude = [

View File

@@ -8,8 +8,10 @@
- `bench_one_batch.py`: Benchmark the latency of running a single static batch without a server.
- `bench_one_batch_server.py`: Benchmark the latency of running a single batch with a server.
- `bench_serving.py`: Benchmark online serving with dynamic requests.
- `check_env.py`: Check the environment variables.
- `check_env.py`: Check the environment variables and dependencies.
- `global_config.py`: The global configs and constants.
- `launch_server.py`: The entry point for launching the local server.
- `llama3_eval.py`: Evaluation of Llama 3 using the Meta Llama dataset.
- `profiler.py`: Profile a running server.
- `utils.py`: Common utilities.
- `version.py`: Version info.

View File

@@ -56,6 +56,7 @@ class BenchArgs:
profile: bool = False
skip_warmup: bool = False
do_not_exit: bool = False
prompt_suffix: str = ""
@staticmethod
def add_cli_args(parser: argparse.ArgumentParser):
@@ -177,6 +178,12 @@ class BenchArgs:
action="store_true",
help="Do not exit the program. This is useful for nsys profile with --duration and --delay.",
)
parser.add_argument(
"--prompt-suffix",
type=str,
default="",
help="Suffix applied to the end of all user prompts, followed by assistant prompt suffix.",
)
@classmethod
def from_cli_args(cls, args: argparse.Namespace):
@@ -216,6 +223,10 @@ def throughput_test_once(
]
if profile:
assert (
"SGLANG_TORCH_PROFILER_DIR" in os.environ
), "Please set SGLANG_TORCH_PROFILER_DIR."
os.makedirs(os.environ["SGLANG_TORCH_PROFILER_DIR"], exist_ok=True)
backend.start_profile()
st = time.perf_counter()
@@ -229,6 +240,8 @@ def throughput_test_once(
if backend_name == "runtime":
gen_out = json.loads(gen_out)
server_info = backend.get_server_info()
measurement_results["total_latency"] = latency
measurement_results["total_output_tokens"] = sum(
o["meta_info"]["completion_tokens"] for o in gen_out
@@ -246,6 +259,7 @@ def throughput_test_once(
measurement_results["total_input_tokens"]
+ measurement_results["total_output_tokens"]
) / latency
measurement_results["last_gen_throughput"] = server_info["last_gen_throughput"]
return measurement_results
@@ -361,6 +375,11 @@ def throughput_test(
print(
"{:<40} {:<10}".format("Total generated tokens:", result["total_output_tokens"])
)
print(
"{:<40} {:<10.2f}".format(
"Last generation throughput (tok/s):", result["last_gen_throughput"]
)
)
print(
"{:<40} {:<10.2f}".format(
"Request throughput (req/s):", result["request_throughput"]

View File

@@ -8,7 +8,6 @@ Usage:
python3 -m sglang.bench_serving --backend sglang --num-prompt 10
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompts 3000 --random-input 1024 --random-output 1024 --random-range-ratio 0.5
python3 -m sglang.bench_serving --backend sglang --dataset-name random --request-rate-range 1,2,4,8,16,32 --random-input 4096 --random-output 1024 --random-range-ratio 0.125 --multi
"""
import argparse
@@ -71,6 +70,10 @@ def remove_prefix(text: str, prefix: str) -> str:
return text[len(prefix) :] if text.startswith(prefix) else text
def remove_suffix(text: str, suffix: str) -> str:
return text[: -len(suffix)] if text.endswith(suffix) else text
def get_auth_headers() -> Dict[str, str]:
api_key = os.environ.get("OPENAI_API_KEY")
if api_key:
@@ -79,7 +82,7 @@ def get_auth_headers() -> Dict[str, str]:
return {}
# trt llm not support ignore_eos
# trt llm does not support ignore_eos
# https://github.com/triton-inference-server/tensorrtllm_backend/issues/505
async def async_request_trt_llm(
request_func_input: RequestFuncInput,
@@ -179,6 +182,7 @@ async def async_request_openai_completions(
output.prompt_len = request_func_input.prompt_len
generated_text = ""
output_len = request_func_input.output_len
ttft = 0.0
st = time.perf_counter()
most_recent_timestamp = st
@@ -215,11 +219,14 @@ async def async_request_openai_completions(
most_recent_timestamp = timestamp
generated_text += data["choices"][0]["text"]
output_len = data.get("usage", {}).get(
"completion_tokens", output_len
)
output.generated_text = generated_text
output.success = True
output.latency = latency
output.output_len = request_func_input.output_len
output.output_len = output_len
else:
output.error = response.reason or ""
output.success = False
@@ -339,9 +346,11 @@ async def async_request_sglang_generate(
output.prompt_len = request_func_input.prompt_len
generated_text = ""
output_len = request_func_input.output_len
ttft = 0.0
st = time.perf_counter()
most_recent_timestamp = st
last_output_len = 0
try:
async with session.post(
url=api_url, json=payload, headers=headers
@@ -365,6 +374,9 @@ async def async_request_sglang_generate(
# want to check a token was generated
if data["text"]:
timestamp = time.perf_counter()
generated_text = data["text"]
output_len = data["meta_info"]["completion_tokens"]
# First token
if ttft == 0.0:
ttft = time.perf_counter() - st
@@ -372,7 +384,13 @@ async def async_request_sglang_generate(
# Decoding phase
else:
output.itl.append(timestamp - most_recent_timestamp)
num_new_tokens = output_len - last_output_len
if num_new_tokens == 0:
continue
adjust_itl = (
timestamp - most_recent_timestamp
) / num_new_tokens
output.itl.extend([adjust_itl] * num_new_tokens)
most_recent_timestamp = timestamp
generated_text = data["text"]
@@ -380,7 +398,7 @@ async def async_request_sglang_generate(
output.generated_text = generated_text
output.success = True
output.latency = latency
output.output_len = request_func_input.output_len
output.output_len = output_len
else:
output.error = response.reason or ""
output.success = False
@@ -388,6 +406,7 @@ async def async_request_sglang_generate(
output.success = False
exc_info = sys.exc_info()
output.error = "".join(traceback.format_exception(*exc_info))
print(f"{output.error=}")
if pbar:
pbar.update(1)
@@ -461,6 +480,7 @@ def get_dataset(args, tokenizer):
tokenizer=tokenizer,
fixed_output_len=args.sharegpt_output_len,
context_len=args.sharegpt_context_len,
prompt_suffix=args.prompt_suffix,
apply_chat_template=args.apply_chat_template,
)
elif args.dataset_name == "random":
@@ -521,7 +541,9 @@ class BenchmarkMetrics:
mean_itl_ms: float
median_itl_ms: float
std_itl_ms: float
p95_itl_ms: float
p99_itl_ms: float
max_itl_ms: float
mean_e2e_latency_ms: float
median_e2e_latency_ms: float
std_e2e_latency_ms: float
@@ -572,6 +594,7 @@ def sample_sharegpt_requests(
tokenizer: PreTrainedTokenizerBase,
fixed_output_len: Optional[int] = None,
context_len: Optional[int] = None,
prompt_suffix: Optional[str] = "",
apply_chat_template=False,
) -> List[Tuple[str, int, int]]:
if fixed_output_len is not None and fixed_output_len < 4:
@@ -584,11 +607,19 @@ def sample_sharegpt_requests(
# Load the dataset.
with open(dataset_path) as f:
dataset = json.load(f)
# Filter out the conversations with less than 2 turns.
dataset = [data for data in dataset if len(data["conversations"]) >= 2]
dataset = [
data
for data in dataset
if len(data.get("conversations", data.get("conversation", []))) >= 2
]
# Only keep the first two turns of each conversation.
dataset = [
(data["conversations"][0]["value"], data["conversations"][1]["value"])
(
data.get("conversations", data.get("conversation", []))[0]["value"],
data.get("conversations", data.get("conversation", []))[1]["value"],
)
for data in dataset
]
@@ -603,6 +634,8 @@ def sample_sharegpt_requests(
# Tokenize the prompts and completions.
prompt = dataset[i][0]
if prompt_suffix:
prompt = prompt
if apply_chat_template:
prompt = tokenizer.apply_chat_template(
@@ -666,10 +699,17 @@ def sample_random_requests(
with open(dataset_path) as f:
dataset = json.load(f)
# Filter out the conversations with less than 2 turns.
dataset = [data for data in dataset if len(data["conversations"]) >= 2]
dataset = [
data
for data in dataset
if len(data.get("conversations", data.get("conversation", []))) >= 2
]
# Only keep the first two turns of each conversation.
dataset = [
(data["conversations"][0]["value"], data["conversations"][1]["value"])
(
data.get("conversations", data.get("conversation", []))[0]["value"],
data.get("conversations", data.get("conversation", []))[1]["value"],
)
for data in dataset
]
# Shuffle the dataset.
@@ -895,7 +935,9 @@ def calculate_metrics(
mean_itl_ms=np.mean(itls or 0) * 1000,
median_itl_ms=np.median(itls or 0) * 1000,
std_itl_ms=np.std(itls or 0) * 1000,
p95_itl_ms=np.percentile(itls or 0, 95) * 1000,
p99_itl_ms=np.percentile(itls or 0, 99) * 1000,
max_itl_ms=np.max(itls or 0) * 1000,
mean_e2e_latency_ms=np.mean(e2e_latencies) * 1000,
median_e2e_latency_ms=np.median(e2e_latencies) * 1000,
std_e2e_latency_ms=np.std(e2e_latencies) * 1000,
@@ -919,6 +961,7 @@ async def benchmark(
lora_name: str,
extra_request_body: Dict[str, Any],
profile: bool,
pd_seperated: bool = False,
):
if backend in ASYNC_REQUEST_FUNCS:
request_func = ASYNC_REQUEST_FUNCS[backend]
@@ -1004,6 +1047,17 @@ async def benchmark(
if pbar is not None:
pbar.close()
if "sglang" in backend:
server_info = requests.get(base_url + "/get_server_info")
if pd_seperated:
accept_length = server_info.json()["decode"][0].get(
"avg_spec_accept_length", None
)
else:
accept_length = server_info.json().get("avg_spec_accept_length", None)
else:
accept_length = None
# Compute metrics and print results
benchmark_duration = time.perf_counter() - benchmark_start_time
metrics, output_lens = calculate_metrics(
@@ -1053,6 +1107,8 @@ async def benchmark(
)
)
print("{:<40} {:<10.2f}".format("Concurrency:", metrics.concurrency))
if accept_length:
print("{:<40} {:<10.2f}".format("Accept length:", accept_length))
print("{s:{c}^{n}}".format(s="End-to-End Latency", n=50, c="-"))
print(
"{:<40} {:<10.2f}".format("Mean E2E Latency (ms):", metrics.mean_e2e_latency_ms)
@@ -1066,16 +1122,12 @@ async def benchmark(
print("{:<40} {:<10.2f}".format("Mean TTFT (ms):", metrics.mean_ttft_ms))
print("{:<40} {:<10.2f}".format("Median TTFT (ms):", metrics.median_ttft_ms))
print("{:<40} {:<10.2f}".format("P99 TTFT (ms):", metrics.p99_ttft_ms))
print(
"{s:{c}^{n}}".format(s="Time per Output Token (excl. 1st token)", n=50, c="-")
)
print("{:<40} {:<10.2f}".format("Mean TPOT (ms):", metrics.mean_tpot_ms))
print("{:<40} {:<10.2f}".format("Median TPOT (ms):", metrics.median_tpot_ms))
print("{:<40} {:<10.2f}".format("P99 TPOT (ms):", metrics.p99_tpot_ms))
print("{s:{c}^{n}}".format(s="Inter-token Latency", n=50, c="-"))
print("{s:{c}^{n}}".format(s="Inter-Token Latency", n=50, c="-"))
print("{:<40} {:<10.2f}".format("Mean ITL (ms):", metrics.mean_itl_ms))
print("{:<40} {:<10.2f}".format("Median ITL (ms):", metrics.median_itl_ms))
print("{:<40} {:<10.2f}".format("P95 ITL (ms):", metrics.p95_itl_ms))
print("{:<40} {:<10.2f}".format("P99 ITL (ms):", metrics.p99_itl_ms))
print("{:<40} {:<10.2f}".format("Max ITL (ms):", metrics.max_itl_ms))
print("=" * 50)
if (
@@ -1117,8 +1169,10 @@ async def benchmark(
"mean_itl_ms": metrics.mean_itl_ms,
"median_itl_ms": metrics.median_itl_ms,
"std_itl_ms": metrics.std_itl_ms,
"p95_itl_ms": metrics.p95_itl_ms,
"p99_itl_ms": metrics.p99_itl_ms,
"concurrency": metrics.concurrency,
"accept_length": accept_length,
}
else:
print(f"Error running benchmark for request rate: {request_rate}")
@@ -1151,14 +1205,6 @@ async def benchmark(
return result
def parse_request_rate_range(request_rate_range):
if len(request_rate_range.split(",")) == 3:
start, stop, step = map(int, request_rate_range.split(","))
return list(range(start, stop, step))
else:
return list(map(int, request_rate_range.split(",")))
def check_chat_template(model_path):
try:
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
@@ -1168,6 +1214,12 @@ def check_chat_template(model_path):
return False
def set_global_args(args_: argparse.Namespace):
"""Set the global args."""
global args
args = args_
def run_benchmark(args_: argparse.Namespace):
global args
args = args_
@@ -1176,6 +1228,8 @@ def run_benchmark(args_: argparse.Namespace):
if not hasattr(args, "max_concurrency"):
args.max_concurrency = None
print(f"benchmark_args={args}")
# Set global environments
set_ulimit()
random.seed(args.seed)
@@ -1272,49 +1326,26 @@ def run_benchmark(args_: argparse.Namespace):
backend = args.backend
model_id = args.model
tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model
tokenizer = get_tokenizer(tokenizer_id)
input_requests = get_dataset(args, tokenizer)
if not args.multi:
return asyncio.run(
benchmark(
backend=backend,
api_url=api_url,
base_url=base_url,
model_id=model_id,
tokenizer=tokenizer,
input_requests=input_requests,
request_rate=args.request_rate,
max_concurrency=args.max_concurrency,
disable_tqdm=args.disable_tqdm,
lora_name=args.lora_name,
extra_request_body=extra_request_body,
profile=args.profile,
)
return asyncio.run(
benchmark(
backend=backend,
api_url=api_url,
base_url=base_url,
model_id=model_id,
tokenizer=tokenizer,
input_requests=input_requests,
request_rate=args.request_rate,
max_concurrency=args.max_concurrency,
disable_tqdm=args.disable_tqdm,
lora_name=args.lora_name,
extra_request_body=extra_request_body,
profile=args.profile,
pd_seperated=args.pd_seperated,
)
else:
# Benchmark multiple rps. TODO: use a fixed duration to compute num_prompts
request_rates = parse_request_rate_range(args.request_rate_range)
for rate in request_rates:
asyncio.run(
benchmark(
backend=backend,
api_url=api_url,
base_url=base_url,
model_id=model_id,
tokenizer=tokenizer,
input_requests=input_requests,
request_rate=rate,
max_concurrency=args.max_concurrency,
disable_tqdm=args.disable_tqdm,
lora_name=args.lora_name,
extra_request_body=extra_request_body,
profile=args.profile,
)
)
)
def set_ulimit(target_soft_limit=65535):
@@ -1428,17 +1459,6 @@ if __name__ == "__main__":
"actual request rate may be lower than specified with --request-rate, "
"if the server is not processing requests fast enough to keep up.",
)
parser.add_argument(
"--multi",
action="store_true",
help="Use request rate range rather than single value.",
)
parser.add_argument(
"--request-rate-range",
type=str,
default="2,34,2",
help="Range of request rates in the format start,stop,step. Default is 2,34,2. It also supports a list of request rates, requiring the parameters to not equal three.",
)
parser.add_argument("--output-file", type=str, help="Output JSONL file name.")
parser.add_argument(
"--disable-tqdm",
@@ -1485,6 +1505,17 @@ if __name__ == "__main__":
default=None,
help="The name of LoRA adapter",
)
parser.add_argument(
"--prompt-suffix",
type=str,
default="",
help="Suffix applied to the end of all user prompts, followed by assistant prompt suffix.",
)
parser.add_argument(
"--pd-seperated",
action="store_true",
help="Benchmark PD disaggregation server",
)
group = parser.add_argument_group("generated-shared-prefix dataset arguments")
group.add_argument(

View File

@@ -34,11 +34,9 @@ class GlobalConfig:
self.skip_special_tokens_in_output = True
self.spaces_between_special_tokens_in_out = True
# Interpreter optimization configs
# Language frontend interpreter optimization configs
self.enable_precache_with_tracing = True
self.enable_parallel_encoding = True
self.enable_flashinfer_mla = False
global_config = GlobalConfig()

View File

@@ -329,7 +329,12 @@ class RuntimeEndpoint(BaseBackend):
def compute_normalized_prompt_logprobs(input_logprobs):
values = [x[0] for x in input_logprobs if x[0]]
return sum(values) / len(values)
try:
return sum(values) / len(values)
except TypeError:
print(f"{input_logprobs=}", flush=True)
print(f"{input_logprobs[0]=}", flush=True)
exit(-1)
class Runtime:

View File

@@ -21,6 +21,7 @@ class LoadFormat(str, enum.Enum):
BITSANDBYTES = "bitsandbytes"
MISTRAL = "mistral"
LAYERED = "layered"
JAX = "jax"
@dataclass
@@ -42,13 +43,15 @@ class LoadConfig:
ignore_patterns: The list of patterns to ignore when loading the model.
Default to "original/**/*" to avoid repeated loading of llama's
checkpoints.
decryption_key_file: If set, decrypts the output files with a password read
from this file (after PBKDF2).
"""
load_format: Union[str, LoadFormat] = LoadFormat.AUTO
download_dir: Optional[str] = None
model_loader_extra_config: Optional[Union[str, dict]] = field(default_factory=dict)
ignore_patterns: Optional[Union[List[str], str]] = None
decryption_key_file: Optional[str] = None
def __post_init__(self):
model_loader_extra_config = self.model_loader_extra_config or {}

View File

@@ -44,6 +44,7 @@ class ModelConfig:
is_embedding: Optional[bool] = None,
dtype: str = "auto",
quantization: Optional[str] = None,
override_config_file: Optional[str] = None,
) -> None:
self.model_path = model_path
self.revision = revision
@@ -51,11 +52,16 @@ class ModelConfig:
# Parse args
self.model_override_args = json.loads(model_override_args)
kwargs = {}
if override_config_file and override_config_file.strip():
kwargs["_configuration_file"] = override_config_file.strip()
self.hf_config = get_config(
model_path,
trust_remote_code=trust_remote_code,
revision=revision,
model_override_args=self.model_override_args,
**kwargs,
)
self.hf_text_config = get_hf_text_config(self.hf_config)
@@ -64,6 +70,9 @@ class ModelConfig:
self.hf_config.architectures, is_embedding
)
self.is_multimodal = is_multimodal_model(self.hf_config.architectures)
self.is_multimodal_gen = is_multimodal_gen_model(self.hf_config.architectures)
self.is_image_gen = is_image_gen_model(self.hf_config.architectures)
self.is_audio_model = is_audio_model(self.hf_config.architectures)
self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures)
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
@@ -71,7 +80,9 @@ class ModelConfig:
derived_context_len = get_context_length(self.hf_text_config)
if context_length is not None:
if context_length > derived_context_len:
if get_bool_env_var("SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN"):
if get_bool_env_var(
"SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN", default="False"
):
logger.warning(
f"Warning: User-specified context_length ({context_length}) is greater than the derived context_length ({derived_context_len}). "
f"This may lead to incorrect model outputs or CUDA errors."
@@ -416,6 +427,8 @@ def is_multimodal_model(model_architectures: List[str]):
or "LlavaQwenForCausalLM" in model_architectures
or "LlavaMistralForCausalLM" in model_architectures
or "LlavaVidForCausalLM" in model_architectures
or "Grok1VForCausalLM" in model_architectures
or "Grok1AForCausalLM" in model_architectures
or "MllamaForConditionalGeneration" in model_architectures
or "Qwen2VLForConditionalGeneration" in model_architectures
or "Qwen2_5_VLForConditionalGeneration" in model_architectures
@@ -426,6 +439,18 @@ def is_multimodal_model(model_architectures: List[str]):
return False
def is_multimodal_gen_model(model_architectures: List[str]):
return False
def is_image_gen_model(model_architectures: List[str]):
return False
def is_audio_model(model_architectures: List[str]):
return False
def is_encoder_decoder_model(model_architectures: List[str]):
return "MllamaForConditionalGeneration" in model_architectures

View File

@@ -15,7 +15,7 @@
import json
import logging
from typing import List, Tuple
from typing import List, Optional, Tuple, Union
import torch
from xgrammar import (
@@ -42,11 +42,16 @@ MAX_ROLLBACK_TOKENS = 200
class XGrammarGrammar(BaseGrammarObject):
def __init__(
self, matcher: GrammarMatcher, vocab_size: int, ctx: CompiledGrammar
self,
matcher: GrammarMatcher,
vocab_size: int,
ctx: CompiledGrammar,
override_stop_tokens: Optional[Union[List[int], int]],
) -> None:
self.matcher = matcher
self.vocab_size = vocab_size
self.ctx = ctx
self.override_stop_tokens = override_stop_tokens
self.finished = False
def accept_token(self, token: int):
@@ -96,8 +101,14 @@ class XGrammarGrammar(BaseGrammarObject):
apply_token_bitmask_inplace(logits, vocab_mask)
def copy(self):
matcher = GrammarMatcher(self.ctx, max_rollback_tokens=MAX_ROLLBACK_TOKENS)
return XGrammarGrammar(matcher, self.vocab_size, self.ctx)
matcher = GrammarMatcher(
self.ctx,
max_rollback_tokens=MAX_ROLLBACK_TOKENS,
override_stop_tokens=self.override_stop_tokens,
)
return XGrammarGrammar(
matcher, self.vocab_size, self.ctx, self.override_stop_tokens
)
class XGrammarGrammarBackend(BaseGrammarBackend):
@@ -111,8 +122,11 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
tokenizer_info = TokenizerInfo.from_huggingface(
tokenizer, vocab_size=vocab_size
)
override_stop_tokens = None
self.grammar_compiler = GrammarCompiler(tokenizer_info=tokenizer_info)
self.vocab_size = vocab_size
self.override_stop_tokens = override_stop_tokens
def init_value_impl(self, key: Tuple[str, str]) -> XGrammarGrammar:
@@ -161,7 +175,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
raise ValueError(f"Invalid key_type: {key_type}")
matcher = GrammarMatcher(ctx, max_rollback_tokens=MAX_ROLLBACK_TOKENS)
return XGrammarGrammar(matcher, self.vocab_size, ctx)
return XGrammarGrammar(matcher, self.vocab_size, ctx, self.override_stop_tokens)
def reset(self):
if self.grammar_compiler:

View File

@@ -121,6 +121,7 @@ class Engine:
return_logprob: Optional[Union[List[bool], bool]] = False,
logprob_start_len: Optional[Union[List[int], int]] = None,
top_logprobs_num: Optional[Union[List[int], int]] = None,
token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None,
lora_path: Optional[List[Optional[str]]] = None,
custom_logit_processor: Optional[Union[List[str], str]] = None,
return_hidden_states: bool = False,
@@ -142,6 +143,7 @@ class Engine:
return_logprob=return_logprob,
logprob_start_len=logprob_start_len,
top_logprobs_num=top_logprobs_num,
token_ids_logprob=token_ids_logprob,
lora_path=lora_path,
modalities=modalities_list,
custom_logit_processor=custom_logit_processor,
@@ -179,6 +181,7 @@ class Engine:
return_logprob: Optional[Union[List[bool], bool]] = False,
logprob_start_len: Optional[Union[List[int], int]] = None,
top_logprobs_num: Optional[Union[List[int], int]] = None,
token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None,
lora_path: Optional[List[Optional[str]]] = None,
custom_logit_processor: Optional[Union[List[str], str]] = None,
stream: bool = False,
@@ -195,6 +198,7 @@ class Engine:
return_logprob=return_logprob,
logprob_start_len=logprob_start_len,
top_logprobs_num=top_logprobs_num,
token_ids_logprob=token_ids_logprob,
lora_path=lora_path,
stream=stream,
custom_logit_processor=custom_logit_processor,
@@ -226,15 +230,22 @@ class Engine:
kill_process_tree(os.getpid(), include_parent=False)
def start_profile(self):
self.tokenizer_manager.start_profile()
loop = asyncio.get_event_loop()
loop.run_until_complete(self.tokenizer_manager.start_profile())
def stop_profile(self):
self.tokenizer_manager.stop_profile()
def get_server_info(self):
loop = asyncio.get_event_loop()
internal_states = loop.run_until_complete(
self.tokenizer_manager.get_internal_state()
)
return {
**dataclasses.asdict(self.tokenizer_manager.server_args), # server args
**dataclasses.asdict(self.tokenizer_manager.server_args),
**self.scheduler_info,
**internal_states,
"version": __version__,
}
@@ -323,6 +334,7 @@ def _set_envs_and_config(server_args: ServerArgs):
os.environ["NCCL_NVLS_ENABLE"] = str(int(server_args.enable_nccl_nvls))
os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4"
os.environ["CUDA_MODULE_LOADING"] = "AUTO"
# Set prometheus env vars
if server_args.enable_metrics:
@@ -346,12 +358,23 @@ def _set_envs_and_config(server_args: ServerArgs):
"at https://docs.flashinfer.ai/installation.html.",
)
def sigchld_handler(signum, frame):
pid, exitcode = os.waitpid(0, os.WNOHANG)
if exitcode != 0:
logger.warning(
"Child process unexpectedly failed with an exit code %d. pid=%d",
exitcode,
pid,
)
signal.signal(signal.SIGCHLD, sigchld_handler)
# Register the signal handler.
# The child processes will send SIGQUIT to this process when any error happens
# This process then clean up the whole process tree
def sigquit_handler(signum, frame):
logger.error(
"Received sigquit from a child proces. It usually means the child failed."
"Received sigquit from a child process. It usually means the child failed."
)
kill_process_tree(os.getpid())

View File

@@ -25,11 +25,14 @@ import os
import threading
import time
from http import HTTPStatus
from typing import AsyncIterator, Dict, Optional
from typing import AsyncIterator, Callable, Dict, Optional
# Fix a bug of Python threading
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
from contextlib import asynccontextmanager
import numpy as np
import orjson
import requests
import uvicorn
@@ -49,8 +52,10 @@ from sglang.srt.managers.io_struct import (
InitWeightsUpdateGroupReqInput,
OpenSessionReqInput,
ParseFunctionCallReq,
ProfileReqInput,
ReleaseMemoryOccupationReqInput,
ResumeMemoryOccupationReqInput,
SetInternalStateReq,
UpdateWeightFromDiskReqInput,
UpdateWeightsFromDistributedReqInput,
VertexGenerateReqInput,
@@ -78,22 +83,13 @@ from sglang.srt.utils import (
kill_process_tree,
set_uvicorn_logging_configs,
)
from sglang.srt.warmup import execute_warmups
from sglang.utils import get_exception_traceback
from sglang.version import __version__
logger = logging.getLogger(__name__)
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
# Fast API
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Store global states
@dataclasses.dataclass
@@ -110,6 +106,34 @@ def set_global_state(global_state: _GlobalState):
_global_state = global_state
@asynccontextmanager
async def lifespan(fast_api_app: FastAPI):
server_args: ServerArgs = fast_api_app.server_args
if server_args.warmups is not None:
await execute_warmups(
server_args.warmups.split(","), _global_state.tokenizer_manager
)
logger.info("Warmup ended")
warmup_thread = getattr(fast_api_app, "warmup_thread", None)
if warmup_thread is not None:
warmup_thread.start()
yield
# Fast API
app = FastAPI(lifespan=lifespan)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
HEALTH_CHECK_TIMEOUT = int(os.getenv("SGLANG_HEALTH_CHECK_TIMEOUT", 20))
##### Native API endpoints #####
@@ -123,24 +147,48 @@ async def health() -> Response:
async def health_generate(request: Request) -> Response:
"""Check the health of the inference server by generating one token."""
sampling_params = {"max_new_tokens": 1, "temperature": 0.7}
sampling_params = {"max_new_tokens": 1, "temperature": 0.0}
rid = f"HEALTH_CHECK_{time.time()}"
if _global_state.tokenizer_manager.is_generation:
if _global_state.tokenizer_manager.is_image_gen:
raise NotImplementedError()
elif _global_state.tokenizer_manager.is_generation:
gri = GenerateReqInput(
input_ids=[0], sampling_params=sampling_params, log_metrics=False
rid=rid,
input_ids=[0],
sampling_params=sampling_params,
log_metrics=False,
)
else:
gri = EmbeddingReqInput(
input_ids=[0], sampling_params=sampling_params, log_metrics=False
rid=rid, input_ids=[0], sampling_params=sampling_params, log_metrics=False
)
try:
async def gen():
async for _ in _global_state.tokenizer_manager.generate_request(gri, request):
break
return Response(status_code=200)
except Exception as e:
logger.exception(e)
return Response(status_code=503)
tic = time.time()
task = asyncio.create_task(gen())
while time.time() < tic + HEALTH_CHECK_TIMEOUT:
await asyncio.sleep(1)
if _global_state.tokenizer_manager.last_receive_tstamp > tic:
task.cancel()
_global_state.tokenizer_manager.rid_to_state.pop(rid, None)
return Response(status_code=200)
task.cancel()
tic_time = time.strftime("%H:%M:%S", time.localtime(tic))
last_receive_time = time.strftime(
"%H:%M:%S", time.localtime(_global_state.tokenizer_manager.last_receive_tstamp)
)
logger.error(
f"Health check failed. Server couldn't get a response from detokenizer for last "
f"{HEALTH_CHECK_TIMEOUT} seconds. tic start time: {tic_time}. "
f"last_heartbeat time: {last_receive_time}"
)
_global_state.tokenizer_manager.rid_to_state.pop(rid, None)
return Response(status_code=503)
@app.get("/get_model_info")
@@ -156,13 +204,21 @@ async def get_model_info():
@app.get("/get_server_info")
async def get_server_info():
internal_states = await _global_state.tokenizer_manager.get_internal_state()
return {
**dataclasses.asdict(_global_state.tokenizer_manager.server_args),
**_global_state.scheduler_info,
**internal_states,
"version": __version__,
}
@app.api_route("/set_internal_state", methods=["POST", "PUT"])
async def set_internal_state(obj: SetInternalStateReq, request: Request):
res = await _global_state.tokenizer_manager.set_internal_state(obj)
return res
# fastapi implicitly converts json in the request to obj (dataclass)
@app.api_route("/generate", methods=["POST", "PUT"])
async def generate_request(obj: GenerateReqInput, request: Request):
@@ -179,6 +235,7 @@ async def generate_request(obj: GenerateReqInput, request: Request):
) + b"\n\n"
except ValueError as e:
out = {"error": {"message": str(e)}}
logger.error(f"Error: {e}")
yield b"data: " + orjson.dumps(
out, option=orjson.OPT_NON_STR_KEYS
) + b"\n\n"
@@ -236,9 +293,14 @@ async def flush_cache():
@app.api_route("/start_profile", methods=["GET", "POST"])
async def start_profile_async():
async def start_profile_async(obj: Optional[ProfileReqInput] = None):
"""Start profiling."""
_global_state.tokenizer_manager.start_profile()
if obj is None:
obj = ProfileReqInput()
await _global_state.tokenizer_manager.start_profile(
obj.output_dir, obj.num_steps, obj.activities
)
return Response(
content="Start profiling.\n",
status_code=200,
@@ -257,11 +319,15 @@ async def stop_profile_async():
@app.post("/update_weights_from_disk")
async def update_weights_from_disk(obj: UpdateWeightFromDiskReqInput, request: Request):
"""Update the weights from disk in-place without re-launching the server."""
success, message = await _global_state.tokenizer_manager.update_weights_from_disk(
obj, request
"""Update the weights from disk inplace without re-launching the server."""
success, message, num_paused_requests = (
await _global_state.tokenizer_manager.update_weights_from_disk(obj, request)
)
content = {"success": success, "message": message}
content = {
"success": success,
"message": message,
"num_paused_requests": num_paused_requests,
}
if success:
return ORJSONResponse(
content,
@@ -323,7 +389,7 @@ async def get_weights_by_name(obj: GetWeightsByNameReqInput, request: Request):
async def release_memory_occupation(
obj: ReleaseMemoryOccupationReqInput, request: Request
):
"""Release GPU occupation temporarily"""
"""Release GPU memory occupation temporarily."""
try:
await _global_state.tokenizer_manager.release_memory_occupation(obj, request)
except Exception as e:
@@ -334,7 +400,7 @@ async def release_memory_occupation(
async def resume_memory_occupation(
obj: ResumeMemoryOccupationReqInput, request: Request
):
"""Resume GPU occupation"""
"""Resume GPU memory occupation."""
try:
await _global_state.tokenizer_manager.resume_memory_occupation(obj, request)
except Exception as e:
@@ -357,7 +423,7 @@ async def open_session(obj: OpenSessionReqInput, request: Request):
@app.api_route("/close_session", methods=["GET", "POST"])
async def close_session(obj: CloseSessionReqInput, request: Request):
"""Close the session"""
"""Close the session."""
try:
await _global_state.tokenizer_manager.close_session(obj, request)
return Response(status_code=200)
@@ -367,7 +433,7 @@ async def close_session(obj: CloseSessionReqInput, request: Request):
@app.api_route("/configure_logging", methods=["GET", "POST"])
async def configure_logging(obj: ConfigureLoggingReq, request: Request):
"""Close the session"""
"""Configure the request logging options."""
_global_state.tokenizer_manager.configure_logging(obj)
return Response(status_code=200)
@@ -511,6 +577,7 @@ def _create_error_response(e):
def launch_server(
server_args: ServerArgs,
pipe_finish_writer: Optional[multiprocessing.connection.Connection] = None,
launch_callback: Optional[Callable[[], None]] = None,
):
"""
Launch SRT (SGLang Runtime) Server.
@@ -544,21 +611,23 @@ def launch_server(
add_prometheus_middleware(app)
enable_func_timer()
# Send a warmup request
t = threading.Thread(
# Send a warmup request - we will create the thread launch it
# in the lifespan after all other warmups have fired.
warmup_thread = threading.Thread(
target=_wait_and_warmup,
args=(
server_args,
pipe_finish_writer,
_global_state.tokenizer_manager.image_token_id,
launch_callback,
),
)
t.start()
app.warmup_thread = warmup_thread
try:
# Update logging configs
set_uvicorn_logging_configs()
app.server_args = server_args
# Listen for HTTP requests
uvicorn.run(
app,
@@ -569,10 +638,15 @@ def launch_server(
loop="uvloop",
)
finally:
t.join()
warmup_thread.join()
def _wait_and_warmup(server_args, pipe_finish_writer, image_token_text):
def _wait_and_warmup(
server_args: ServerArgs,
pipe_finish_writer: Optional[multiprocessing.connection.Connection],
image_token_text: str,
launch_callback: Optional[Callable[[], None]] = None,
):
headers = {}
url = server_args.url()
if server_args.api_key:
@@ -614,8 +688,16 @@ def _wait_and_warmup(server_args, pipe_finish_writer, image_token_text):
else:
json_data["text"] = "The capital city of France is"
# Debug dumping
if server_args.debug_tensor_dump_input_file:
json_data.pop("text", None)
json_data["input_ids"] = np.load(
server_args.debug_tensor_dump_input_file
).tolist()
json_data["sampling_params"]["max_new_tokens"] = 0
try:
for _ in range(server_args.dp_size):
for i in range(server_args.dp_size):
res = requests.post(
url + request_name,
json=json_data,
@@ -640,3 +722,9 @@ def _wait_and_warmup(server_args, pipe_finish_writer, image_token_text):
if server_args.delete_ckpt_after_loading:
delete_directory(server_args.model_path)
if server_args.debug_tensor_dump_input_file:
kill_process_tree(os.getpid())
if launch_callback is not None:
launch_callback()

View File

@@ -60,6 +60,7 @@ class VerlEngine:
return_logprob: Optional[Union[List[bool], bool]] = False,
logprob_start_len: Optional[Union[List[int], int]] = None,
top_logprobs_num: Optional[Union[List[int], int]] = None,
token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None,
lora_path: Optional[List[Optional[str]]] = None,
custom_logit_processor: Optional[Union[List[str], str]] = None,
) -> Dict:
@@ -76,6 +77,7 @@ class VerlEngine:
return_logprob=return_logprob,
logprob_start_len=logprob_start_len,
top_logprobs_num=top_logprobs_num,
token_ids_logprob=token_ids_logprob,
lora_path=lora_path,
custom_logit_processor=custom_logit_processor,
)

View File

@@ -1,14 +1,14 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Optional, Union
import torch
if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.speculative.spec_info import SpecInfo
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
class AttentionBackend(ABC):
@@ -31,7 +31,7 @@ class AttentionBackend(ABC):
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[SpecInfo],
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
):
"""Init the metadata for a forward pass for capturing a cuda graph."""
raise NotImplementedError()
@@ -44,7 +44,7 @@ class AttentionBackend(ABC):
seq_lens_sum: int,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[SpecInfo],
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
):
"""Init the metadata for a forward pass for replying a cuda graph."""
raise NotImplementedError()
@@ -64,7 +64,14 @@ class AttentionBackend(ABC):
):
"""Run forward on an attention layer."""
if forward_batch.forward_mode.is_decode():
return self.forward_decode(q, k, v, layer, forward_batch, save_kv_cache)
return self.forward_decode(
q,
k,
v,
layer,
forward_batch,
save_kv_cache=save_kv_cache,
)
else:
return self.forward_extend(
q,
@@ -72,7 +79,7 @@ class AttentionBackend(ABC):
v,
layer,
forward_batch,
save_kv_cache,
save_kv_cache=save_kv_cache,
)
def forward_decode(

View File

@@ -68,6 +68,7 @@ class FlashInferAttnBackend(AttentionBackend):
model_runner: ModelRunner,
skip_prefill: bool = False,
kv_indptr_buf: Optional[torch.Tensor] = None,
kv_last_page_len_buf: Optional[torch.Tensor] = None,
):
super().__init__()
@@ -125,9 +126,14 @@ class FlashInferAttnBackend(AttentionBackend):
assert self.num_wrappers == 1
self.kv_indptr = [kv_indptr_buf]
self.kv_last_page_len = torch.ones(
(max_bs,), dtype=torch.int32, device=model_runner.device
)
if kv_last_page_len_buf is None:
self.kv_last_page_len = torch.ones(
(max_bs,), dtype=torch.int32, device=model_runner.device
)
else:
assert self.num_wrappers == 1
self.kv_last_page_len = kv_last_page_len_buf
self.qo_indptr = [
torch.zeros((max_bs + 1,), dtype=torch.int32, device=model_runner.device)
for _ in range(self.num_wrappers)
@@ -922,6 +928,9 @@ class FlashInferMultiStepDraftBackend:
dtype=torch.int32,
device=model_runner.device,
)
self.kv_last_page_len = torch.ones(
(max_bs,), dtype=torch.int32, device=model_runner.device
)
self.attn_backends = []
for i in range(self.speculative_num_steps):
self.attn_backends.append(
@@ -929,6 +938,7 @@ class FlashInferMultiStepDraftBackend:
model_runner,
skip_prefill=True,
kv_indptr_buf=self.kv_indptr[i],
kv_last_page_len_buf=self.kv_last_page_len,
)
)
self.max_context_len = self.attn_backends[0].max_context_len

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Optional, Union
import torch
import triton
@@ -15,7 +15,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMo
if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.speculative.spec_info import SpecInfo
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
class TritonAttnBackend(AttentionBackend):
@@ -232,7 +232,7 @@ class TritonAttnBackend(AttentionBackend):
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[SpecInfo],
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
):
assert encoder_lens is None, "Not supported"
@@ -310,7 +310,7 @@ class TritonAttnBackend(AttentionBackend):
seq_lens_sum: int,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[SpecInfo],
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
):
# NOTE: encoder_lens expected to be zeros or None
if forward_mode.is_decode_or_idle():

View File

@@ -1,6 +1,21 @@
import torch
from __future__ import annotations
from sglang.srt.distributed import GroupCoordinator, get_tp_group
import functools
from typing import TYPE_CHECKING, Union
import torch
import triton
import triton.language as tl
from sglang.srt.distributed import (
GroupCoordinator,
get_tensor_model_parallel_world_size,
get_tp_group,
tensor_model_parallel_all_reduce,
)
if TYPE_CHECKING:
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
_ATTN_TP_GROUP = None
_ATTN_TP_RANK = None
@@ -69,3 +84,129 @@ def get_attention_dp_rank():
def get_attention_dp_size():
assert _DP_SIZE is not None, "dp attention not initialized!"
return _DP_SIZE
def get_dp_local_info(forward_batch: ForwardBatch):
dp_rank = get_attention_dp_rank()
if forward_batch.dp_local_start_pos is None:
cumtokens = torch.cumsum(forward_batch.global_num_tokens_gpu, dim=0)
if dp_rank == 0:
local_start_pos = torch.zeros_like(cumtokens[0])
else:
local_start_pos = cumtokens[dp_rank - 1]
local_num_tokens = forward_batch.global_num_tokens_gpu[dp_rank]
forward_batch.dp_local_start_pos = local_start_pos
forward_batch.dp_local_num_tokens = local_num_tokens
return forward_batch.dp_local_start_pos, forward_batch.dp_local_num_tokens
@triton.jit
def memcpy_triton_kernel(
dst_ptr,
src_ptr,
offset_ptr,
sz_ptr,
offset_src,
chunk_size, # multiplied for offset and sz
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(axis=0).to(tl.int64)
offset = tl.load(offset_ptr).to(tl.int64) * chunk_size
sz = tl.load(sz_ptr).to(tl.int64) * chunk_size
start_index = pid * BLOCK_SIZE
offs = tl.arange(0, BLOCK_SIZE)
mask = start_index + offs < sz
if offset_src:
data = tl.load(src_ptr + offset + start_index + offs, mask=mask)
tl.store(dst_ptr + start_index + offs, data, mask=mask)
else:
data = tl.load(src_ptr + start_index + offs, mask=mask)
tl.store(dst_ptr + offset + start_index + offs, data, mask=mask)
def prod(x):
return functools.reduce(lambda a, b: a * b, x, 1)
def memcpy_triton(dst, src, dim, offset, sz, offset_src):
max_size = min(src.numel(), dst.numel())
assert dim == 0, "dim != 0 unsupported"
assert src.shape[1:] == dst.shape[1:], "src and dst must have same shape"
chunk_size = prod(src.shape[1:])
BLOCK_SIZE = 8192
grid = (triton.cdiv(max_size, BLOCK_SIZE),)
memcpy_triton_kernel[grid](dst, src, offset, sz, offset_src, chunk_size, BLOCK_SIZE)
def dp_gather(
global_tokens: torch.Tensor,
local_tokens: torch.Tensor,
forward_batch: ForwardBatch,
layer_id: Union[str, int],
):
local_start_pos, local_num_tokens = get_dp_local_info(forward_batch)
global_tokens.fill_(0)
assert local_tokens.is_contiguous()
assert global_tokens.is_contiguous()
if local_tokens.shape[0] > 0 and (
layer_id != "embedding" or get_attention_tp_rank() == 0
):
assert (
global_tokens.storage().data_ptr() != local_tokens.storage().data_ptr()
), "aliasing between global_tokens and local_tokens not allowed"
memcpy_triton(
global_tokens, local_tokens, 0, local_start_pos, local_num_tokens, False
)
# Input IDs are in int 32. We should use inplace_all_reduce for local case becaues of custom all reduce.
NUM_GPUS_PER_NODE = 8
if (
not local_tokens.dtype.is_floating_point
and get_tensor_model_parallel_world_size() <= NUM_GPUS_PER_NODE
):
torch.ops.sglang.inplace_all_reduce(
global_tokens, group_name=get_tp_group().unique_name
)
else:
global_tokens = tensor_model_parallel_all_reduce(global_tokens)
def dp_scatter(
local_tokens: torch.Tensor, # output
global_tokens: torch.Tensor, # input
forward_batch: ForwardBatch,
):
# local_num_tokens is not necessarily the same as local_tokens.shape[0],
# since local_tokens may be padded for cuda graph
local_start_pos, local_num_tokens = get_dp_local_info(forward_batch)
local_tokens.fill_(0)
assert local_tokens.is_contiguous()
assert global_tokens.is_contiguous()
if local_tokens.shape[0] > 0:
assert (
local_tokens.untyped_storage().data_ptr()
!= global_tokens.untyped_storage().data_ptr()
), "aliasing between local_tokens and global_tokens not allowed"
memcpy_triton(
local_tokens, global_tokens, 0, local_start_pos, local_num_tokens, True
)
def get_do_logits_dp_scatter(forward_batch: ForwardBatch):
def do_logits_dp_scatter(logits: torch.Tensor):
local_logits = torch.empty(
(forward_batch.input_ids.shape[0], *logits.shape[1:]),
dtype=logits.dtype,
device=logits.device,
)
dp_scatter(local_logits, logits, forward_batch)
return local_logits
return do_logits_dp_scatter

View File

@@ -69,7 +69,7 @@ class RMSNorm(CustomOp):
variance = x.pow(2).mean(dim=-1, keepdim=True)
x = x * torch.rsqrt(variance + self.variance_epsilon)
x = x.to(orig_dtype) * self.weight
x = (x * self.weight).to(orig_dtype)
if residual is None:
return x
else:

View File

@@ -426,13 +426,14 @@ class ColumnParallelLinear(LinearBase):
from sglang.srt.layers.parameter import _ColumnvLLMParameter
if isinstance(param, _ColumnvLLMParameter):
# FIXME: why would we need this special case?
param.load_column_parallel_weight(
loaded_weight,
tp_rank=self.tp_rank,
use_presharded_weights=self.use_presharded_weights,
)
else:
# FIXME: This branch is needed to load deepseek v3 awq.
# However, we should fix this and avoid the branching here.
param.load_column_parallel_weight(loaded_weight)
def forward(self, input_):

View File

@@ -26,12 +26,19 @@ from sglang.srt.distributed import (
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather,
)
from sglang.srt.layers.dp_attention import (
dp_gather,
dp_scatter,
get_attention_dp_rank,
get_attention_dp_size,
)
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
from sglang.srt.model_executor.forward_batch_info import (
CaptureHiddenMode,
ForwardBatch,
ForwardMode,
)
from sglang.srt.utils import dump_to_file
logger = logging.getLogger(__name__)
@@ -51,6 +58,9 @@ class LogitsProcessorOutput:
# The logprobs and ids of the top-k tokens in output positions. shape: [#seq, k]
next_token_top_logprobs_val: Optional[List] = None
next_token_top_logprobs_idx: Optional[List] = None
# The logprobs and ids of the requested token ids in output positions. shape: [#seq, n] (n is the number of requested token ids)
next_token_token_ids_logprobs_val: Optional[List] = None
next_token_token_ids_logprobs_idx: Optional[List] = None
## Part 3: Prefill-only. This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor
# The logprobs of input tokens. shape: [#token]
@@ -58,6 +68,9 @@ class LogitsProcessorOutput:
# The logprobs and ids of the top-k tokens in input positions. shape: [#seq, #token, k]
input_top_logprobs_val: List = None
input_top_logprobs_idx: List = None
# The logprobs and ids of the requested token ids in input positions. shape: [#seq, n] (n is the number of requested token ids)
input_token_ids_logprobs_val: Optional[List] = None
input_token_ids_logprobs_idx: Optional[List] = None
@dataclasses.dataclass
@@ -67,43 +80,107 @@ class LogitsMetadata:
extend_return_logprob: bool = False
extend_return_top_logprob: bool = False
extend_token_ids_logprob: bool = False
extend_seq_lens: Optional[torch.Tensor] = None
extend_seq_lens_cpu: Optional[List[int]] = None
extend_logprob_start_lens_cpu: Optional[List[int]] = None
extend_logprob_pruned_lens_cpu: Optional[List[int]] = None
top_logprobs_nums: Optional[List[int]] = None
extend_input_logprob_token_ids_gpu: Optional[torch.Tensor] = None
token_ids_logprobs: Optional[List[List[int]]] = None
# logits and logprobs post processing
temp_scaled_logprobs: bool = False
temperature: torch.Tensor = None
top_p_normalized_logprobs: bool = False
top_p: torch.Tensor = None
# DP attention metadata. Not needed when DP attention is not used.
# Number of tokens in the request.
global_num_tokens_gpu: Optional[torch.Tensor] = None
# The start position of local hidden states.
dp_local_start_pos: Optional[torch.Tensor] = None
dp_local_num_tokens: Optional[torch.Tensor] = None
gathered_buffer: Optional[torch.Tensor] = None
# Buffer to gather logits from all ranks.
forward_batch_gathered_buffer: Optional[torch.Tensor] = None
# Number of tokens to sample per DP rank
global_num_tokens_for_logprob_cpu: Optional[torch.Tensor] = None
global_num_tokens_for_logprob_gpu: Optional[torch.Tensor] = None
# for padding
padded_static_len: int = -1
@classmethod
def from_forward_batch(cls, forward_batch: ForwardBatch):
if forward_batch.forward_mode.is_extend() and forward_batch.return_logprob:
extend_return_logprob = True
if (
forward_batch.forward_mode.is_extend()
and forward_batch.return_logprob
and not forward_batch.forward_mode.is_target_verify()
):
extend_return_top_logprob = any(
x > 0 for x in forward_batch.top_logprobs_nums
)
extend_logprob_pruned_lens_cpu = [
extend_len - start_len
for extend_len, start_len in zip(
forward_batch.extend_seq_lens_cpu,
forward_batch.extend_logprob_start_lens_cpu,
)
]
extend_token_ids_logprob = any(
x is not None for x in forward_batch.token_ids_logprobs
)
extend_return_logprob = False
extend_logprob_pruned_lens_cpu = []
for extend_len, start_len in zip(
forward_batch.extend_seq_lens_cpu,
forward_batch.extend_logprob_start_lens_cpu,
):
if extend_len - start_len > 0:
extend_return_logprob = True
extend_logprob_pruned_lens_cpu.append(extend_len - start_len)
else:
extend_return_logprob = extend_return_top_logprob = (
extend_logprob_pruned_lens_cpu
) = False
extend_token_ids_logprob
) = extend_logprob_pruned_lens_cpu = False
return cls(
forward_mode=forward_batch.forward_mode,
capture_hidden_mode=forward_batch.capture_hidden_mode,
extend_return_logprob=extend_return_logprob,
extend_return_top_logprob=extend_return_top_logprob,
extend_token_ids_logprob=extend_token_ids_logprob,
extend_seq_lens=forward_batch.extend_seq_lens,
extend_seq_lens_cpu=forward_batch.extend_seq_lens_cpu,
extend_logprob_start_lens_cpu=forward_batch.extend_logprob_start_lens_cpu,
extend_logprob_pruned_lens_cpu=extend_logprob_pruned_lens_cpu,
top_logprobs_nums=forward_batch.top_logprobs_nums,
token_ids_logprobs=forward_batch.token_ids_logprobs,
extend_input_logprob_token_ids_gpu=forward_batch.extend_input_logprob_token_ids_gpu,
padded_static_len=forward_batch.padded_static_len,
)
def compute_dp_attention_metadata(self, hidden_states: torch.Tensor):
if self.global_num_tokens_for_logprob_cpu is None:
# we are capturing cuda graph
return
cumtokens = torch.cumsum(self.global_num_tokens_for_logprob_gpu, dim=0)
dp_rank = get_attention_dp_rank()
if dp_rank == 0:
dp_local_start_pos = torch.zeros_like(
self.global_num_tokens_for_logprob_gpu[0]
)
else:
dp_local_start_pos = cumtokens[dp_rank - 1]
dp_local_num_tokens = self.global_num_tokens_for_logprob_gpu[dp_rank]
gathered_buffer = torch.zeros(
(
sum(self.global_num_tokens_for_logprob_cpu),
hidden_states.shape[1],
),
dtype=hidden_states.dtype,
device=hidden_states.device,
)
self.dp_local_start_pos = dp_local_start_pos
self.dp_local_num_tokens = dp_local_num_tokens
self.gathered_buffer = gathered_buffer
class LogitsProcessor(nn.Module):
def __init__(
@@ -115,6 +192,9 @@ class LogitsProcessor(nn.Module):
self.do_tensor_parallel_all_gather = (
not skip_all_gather and get_tensor_model_parallel_world_size() > 1
)
self.do_tensor_parallel_all_gather_dp_attn = (
self.do_tensor_parallel_all_gather and get_attention_dp_size() != 1
)
self.final_logit_softcapping = getattr(
self.config, "final_logit_softcapping", None
)
@@ -124,6 +204,12 @@ class LogitsProcessor(nn.Module):
):
self.final_logit_softcapping = None
from sglang.srt.managers.schedule_batch import global_server_args_dict
self.debug_tensor_dump_output_folder = global_server_args_dict[
"debug_tensor_dump_output_folder"
]
def forward(
self,
input_ids,
@@ -141,30 +227,74 @@ class LogitsProcessor(nn.Module):
):
pruned_states = hidden_states
sample_indices = None
input_logprob_indices = None
elif (
logits_metadata.forward_mode.is_extend()
and not logits_metadata.extend_return_logprob
):
# Prefill without input logprobs.
last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1
if logits_metadata.padded_static_len < 0:
last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1
else:
# If padding_static length is 5 and extended_seq_lens is [2, 3],
# then our batch looks like [t00, t01, p, p, p, t10, t11, t12, p, p]
# and this retrieves t01 and t12, which are the valid last tokens
idx = torch.arange(
len(logits_metadata.extend_seq_lens),
device=logits_metadata.extend_seq_lens.device,
)
last_index = (
idx * logits_metadata.padded_static_len
+ logits_metadata.extend_seq_lens
- 1
)
pruned_states = hidden_states[last_index]
sample_indices = None
input_logprob_indices = None
else:
# Slice the requested tokens to compute logprob
# Input logprobs are required.
# Find 3 different indices.
# 1. pruned_states: hidden states that we want logprobs from.
# 2. sample_indices: Indices that have sampled tokens.
# 3. input_logprob_indices: Indices that have input logprob tokens.
sample_index_pt = -1
sample_indices = []
pt, pruned_states, pruned_input_ids = 0, [], []
for start_len, extend_len in zip(
input_logprob_indices_pt = 0
input_logprob_indices = []
pt, pruned_states = 0, []
for extend_logprob_start_len, extend_len in zip(
logits_metadata.extend_logprob_start_lens_cpu,
logits_metadata.extend_seq_lens_cpu,
):
# It can happen in chunked prefill. We still need to sample 1 token,
# But we don't want to include it in input logprob.
if extend_len == extend_logprob_start_len:
start_len = extend_logprob_start_len - 1
else:
start_len = extend_logprob_start_len
# We always need at least 1 token to sample because that's required
# by a caller.
assert extend_len > start_len
pruned_states.append(hidden_states[pt + start_len : pt + extend_len])
pt += extend_len
sample_index_pt += extend_len - start_len
sample_indices.append(sample_index_pt)
pruned_input_ids.append(input_ids[pt + start_len : pt + extend_len])
pt += extend_len
input_logprob_indices.extend(
[
input_logprob_indices_pt + i
for i in range(extend_len - extend_logprob_start_len)
]
)
input_logprob_indices_pt += extend_len - start_len
pruned_states = torch.cat(pruned_states)
sample_indices = torch.tensor(
sample_indices, device=pruned_states.device, dtype=torch.int64
)
input_logprob_indices = torch.tensor(
input_logprob_indices, device=pruned_states.device, dtype=torch.int64
)
# Compute logits for both input and sampled tokens.
logits = self._get_logits(pruned_states, lm_head, logits_metadata)
@@ -172,28 +302,51 @@ class LogitsProcessor(nn.Module):
logits[sample_indices] if sample_indices is not None else logits
)
if (
not logits_metadata.extend_return_logprob
or logits_metadata.capture_hidden_mode.need_capture()
):
if self.debug_tensor_dump_output_folder:
assert (
not self.do_tensor_parallel_all_gather or get_attention_dp_size() == 1
), "dp attention + sharded lm_head doesn't support full logits"
full_logits = self._get_logits(hidden_states, lm_head, logits_metadata)
dump_to_file(self.debug_tensor_dump_output_folder, "logits", full_logits)
hidden_states_to_store: Optional[torch.Tensor] = None
if logits_metadata.capture_hidden_mode.need_capture():
if logits_metadata.capture_hidden_mode.is_full():
hidden_states_to_store = hidden_states
elif logits_metadata.capture_hidden_mode.is_last():
# Get the last token hidden states. If sample_indices is None,
# pruned states only contain the last tokens already.
hidden_states_to_store = (
pruned_states[sample_indices] if sample_indices else pruned_states
)
else:
assert False, "Should never reach"
if not logits_metadata.extend_return_logprob:
# Decode mode or extend mode without return_logprob.
return LogitsProcessorOutput(
next_token_logits=sampled_logits,
hidden_states=(
hidden_states
if logits_metadata.capture_hidden_mode.is_full()
else (
pruned_states
if logits_metadata.capture_hidden_mode.is_last()
else None
)
),
hidden_states=hidden_states_to_store,
)
else:
input_logprobs = logits
input_logprobs = logits[input_logprob_indices]
del hidden_states, logits
# Normalize the logprob w/o temperature, top-p
pruned_lens = torch.tensor(
logits_metadata.extend_logprob_pruned_lens_cpu,
device=input_logprobs.device,
)
if logits_metadata.temp_scaled_logprobs:
logits_metadata.temperature = torch.repeat_interleave(
logits_metadata.temperature.view(-1),
pruned_lens,
).view(-1, 1)
if logits_metadata.top_p_normalized_logprobs:
logits_metadata.top_p = torch.repeat_interleave(
logits_metadata.top_p,
pruned_lens,
)
input_logprobs = self.compute_temp_top_p_normalized_logprobs(
input_logprobs, logits_metadata
)
@@ -207,14 +360,18 @@ class LogitsProcessor(nn.Module):
else:
input_top_logprobs_val = input_top_logprobs_idx = None
# Get the logprob of given token id
if logits_metadata.extend_token_ids_logprob:
(
input_token_ids_logprobs_val,
input_token_ids_logprobs_idx,
) = self.get_token_ids_logprobs(input_logprobs, logits_metadata)
else:
input_token_ids_logprobs_val = input_token_ids_logprobs_idx = None
input_token_logprobs = input_logprobs[
torch.arange(input_logprobs.shape[0], device=input_logprobs.device),
torch.cat(
[
torch.cat(pruned_input_ids)[1:],
torch.tensor([0], device=input_logprobs.device),
]
),
logits_metadata.extend_input_logprob_token_ids_gpu,
]
return LogitsProcessorOutput(
@@ -222,6 +379,9 @@ class LogitsProcessor(nn.Module):
input_token_logprobs=input_token_logprobs,
input_top_logprobs_val=input_top_logprobs_val,
input_top_logprobs_idx=input_top_logprobs_idx,
hidden_states=hidden_states_to_store,
input_token_ids_logprobs_val=input_token_ids_logprobs_val,
input_token_ids_logprobs_idx=input_token_ids_logprobs_idx,
)
def _get_logits(
@@ -231,10 +391,24 @@ class LogitsProcessor(nn.Module):
logits_metadata: LogitsMetadata,
embedding_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Get logits from hidden_states."""
"""Get logits from hidden_states.
If sampled_logits_only is True, it means hidden_states only contain the
last position (e.g., extend without input logprobs). The caller should
guarantee the given hidden_states follow this constraint.
"""
if self.do_tensor_parallel_all_gather_dp_attn:
logits_metadata.compute_dp_attention_metadata(hidden_states)
hidden_states, local_hidden_states = (
logits_metadata.gathered_buffer,
hidden_states.clone(),
)
dp_gather(hidden_states, local_hidden_states, logits_metadata, "embedding")
if hasattr(lm_head, "weight"):
logits = torch.matmul(hidden_states, lm_head.weight.T)
logits = torch.matmul(
hidden_states.to(lm_head.weight.dtype), lm_head.weight.T
)
else:
# GGUF models
logits = lm_head.linear_method.apply(lm_head, hidden_states, embedding_bias)
@@ -245,6 +419,17 @@ class LogitsProcessor(nn.Module):
if self.do_tensor_parallel_all_gather:
logits = tensor_model_parallel_all_gather(logits)
if self.do_tensor_parallel_all_gather_dp_attn:
logits, global_logits = (
torch.empty(
(local_hidden_states.shape[0], logits.shape[1]),
device=logits.device,
dtype=logits.dtype,
),
logits,
)
dp_scatter(logits, global_logits, logits_metadata)
logits = logits[:, : self.config.vocab_size].float()
if self.final_logit_softcapping:
@@ -272,21 +457,66 @@ class LogitsProcessor(nn.Module):
continue
input_top_logprobs_val.append(
[values[pt + j][:k] for j in range(pruned_len - 1)]
[values[pt + j][:k] for j in range(pruned_len)]
)
input_top_logprobs_idx.append(
[indices[pt + j][:k] for j in range(pruned_len - 1)]
[indices[pt + j][:k] for j in range(pruned_len)]
)
pt += pruned_len
return input_top_logprobs_val, input_top_logprobs_idx
@staticmethod
def get_token_ids_logprobs(
all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata
):
input_token_ids_logprobs_val, input_token_ids_logprobs_idx = [], []
pt = 0
for token_ids, pruned_len in zip(
logits_metadata.token_ids_logprobs,
logits_metadata.extend_logprob_pruned_lens_cpu,
):
if pruned_len <= 0:
input_token_ids_logprobs_val.append([])
input_token_ids_logprobs_idx.append([])
continue
input_token_ids_logprobs_val.append(
[all_logprobs[pt + j, token_ids].tolist() for j in range(pruned_len)]
)
input_token_ids_logprobs_idx.append([token_ids for _ in range(pruned_len)])
pt += pruned_len
return input_token_ids_logprobs_val, input_token_ids_logprobs_idx
@staticmethod
def compute_temp_top_p_normalized_logprobs(
last_logits: torch.Tensor, logits_metadata: LogitsMetadata
) -> torch.Tensor:
# TODO: Implement the temp and top-p normalization
return torch.nn.functional.log_softmax(last_logits, dim=-1)
"""
compute logprobs for the output token from the given logits.
Returns:
torch.Tensor: logprobs from logits
"""
# Scale logits if temperature scaling is enabled
if logits_metadata.temp_scaled_logprobs:
last_logits = last_logits / logits_metadata.temperature
# Normalize logprobs if top_p normalization is enabled
# NOTE: only normalize logprobs when top_p is set and not equal to 1.0
if (
logits_metadata.top_p_normalized_logprobs
and (logits_metadata.top_p != 1.0).any()
):
from sglang.srt.layers.sampler import top_p_normalize_probs_torch
probs = torch.softmax(last_logits, dim=-1)
del last_logits
probs = top_p_normalize_probs_torch(probs, logits_metadata.top_p)
return torch.log(probs)
else:
return torch.nn.functional.log_softmax(last_logits, dim=-1)
@triton.jit

View File

@@ -144,6 +144,73 @@ def silu_and_mul_triton_kernel(
tl.store(down_input_ptr + offset, silu_mul_output, mask=mask)
@triton.jit
def tanh(x):
return 2 * tl.sigmoid(2 * x) - 1
@triton.jit
def gelu_and_mul_triton_kernel(
gateup_output,
down_input,
hidden_size,
reorder_topk_ids,
scales,
start_expert_id,
end_expert_id,
BLOCK_SIZE: tl.constexpr,
):
InDtype = gateup_output.dtype.element_ty
OutDtype = down_input.dtype.element_ty
half_hidden_size = hidden_size // 2
pid = tl.program_id(0)
expert_id = tl.load(reorder_topk_ids + pid)
if expert_id >= start_expert_id and expert_id <= end_expert_id:
gateup_output_ptr = gateup_output + pid * hidden_size
gate_output_ptr = gateup_output_ptr
up_output_ptr = gateup_output_ptr + half_hidden_size
down_input_ptr = down_input + pid * half_hidden_size
if scales is not None:
scale = tl.load(scales + expert_id - start_expert_id)
scale = (1 / scale).to(InDtype)
else:
scale = 1
for start_offset in tl.range(0, half_hidden_size, BLOCK_SIZE):
offset = start_offset + tl.arange(0, BLOCK_SIZE)
mask = offset < half_hidden_size
gate_output = tl.load(gate_output_ptr + offset, mask=mask).to(tl.float32)
up_output = tl.load(up_output_ptr + offset, mask=mask)
# gelu & mul & quantize
# https://pytorch.org/docs/stable/generated/torch.nn.GELU.html
# sqrt(2/pi)
kAlpha = 0.7978845608028654
gate_output = (
0.5
* gate_output
* (
1
+ tanh(
kAlpha
* (
gate_output
+ 0.044715 * gate_output * gate_output * gate_output
)
)
)
)
gate_output = gate_output.to(InDtype)
gelu_mul_output = gate_output * up_output * scale
gelu_mul_output = gelu_mul_output.to(OutDtype)
tl.store(down_input_ptr + offset, gelu_mul_output, mask=mask)
@triton.jit
def post_reorder_triton_kernel(
down_output_ptr,

View File

@@ -11,6 +11,7 @@ from sglang.srt.distributed import (
get_tensor_model_parallel_world_size,
)
from sglang.srt.layers.moe.ep_moe.kernels import (
gelu_and_mul_triton_kernel,
grouped_gemm_triton,
post_reorder_triton_kernel,
pre_reorder_triton_kernel,
@@ -296,6 +297,17 @@ class EPMoE(torch.nn.Module):
self.end_expert_id,
BLOCK_SIZE=512,
)
elif self.activation == "gelu":
gelu_and_mul_triton_kernel[(gateup_output.shape[0],)](
gateup_output,
down_input,
gateup_output.shape[1],
reorder_topk_ids,
self.w2_input_scale,
self.start_expert_id,
self.end_expert_id,
BLOCK_SIZE=512,
)
else:
raise ValueError(f"Unsupported activation: {self.activation=}")

View File

@@ -24,6 +24,8 @@ def fused_moe_forward_native(
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
inplace: bool = True,
no_combine: bool = False,
) -> torch.Tensor:
topk_weights, topk_ids = select_experts(
hidden_states=x,

View File

@@ -23,7 +23,7 @@ from sglang.srt.utils import (
is_hip,
)
is_hip_flag = is_hip()
is_hip_ = is_hip()
logger = logging.getLogger(__name__)
@@ -487,6 +487,7 @@ def invoke_fused_moe_kernel(
use_int8_w8a8: bool,
use_int8_w8a16: bool,
block_shape: Optional[List[int]] = None,
no_combine: bool = False,
) -> None:
assert topk_weights.stride(1) == 1
assert sorted_token_ids.stride(0) == 1
@@ -646,7 +647,7 @@ def get_default_config(
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 2 if is_hip_flag else 4,
"num_stages": 2 if is_hip_ else 4,
}
if M <= E:
config = {
@@ -655,7 +656,7 @@ def get_default_config(
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2 if is_hip_flag else 4,
"num_stages": 2 if is_hip_ else 4,
}
else:
# Block-wise quant: BLOCK_SIZE_K must be divisable by block_shape[1]
@@ -665,7 +666,7 @@ def get_default_config(
"BLOCK_SIZE_K": block_shape[1],
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 2 if is_hip_flag else 3,
"num_stages": 2 if is_hip_ else 3,
}
else:
config = {
@@ -814,6 +815,7 @@ def outplace_fused_experts(
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
no_combine: bool = False,
) -> torch.Tensor:
return fused_experts_impl(
hidden_states,
@@ -831,6 +833,7 @@ def outplace_fused_experts(
a1_scale,
a2_scale,
block_shape,
no_combine=no_combine,
)
@@ -849,6 +852,7 @@ def outplace_fused_experts_fake(
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
no_combine: bool = False,
) -> torch.Tensor:
return torch.empty_like(hidden_states)
@@ -877,8 +881,10 @@ def fused_experts(
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
no_combine: bool = False,
):
if inplace:
assert not no_combine, "no combine + inplace makes no sense"
torch.ops.sglang.inplace_fused_experts(
hidden_states,
w1,
@@ -912,6 +918,7 @@ def fused_experts(
a1_scale,
a2_scale,
block_shape,
no_combine=no_combine,
)
@@ -931,6 +938,7 @@ def fused_experts_impl(
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
no_combine: bool = False,
):
padded_size = padding_size
if not use_fp8_w8a8 or not use_int8_w8a8 or block_shape is not None:
@@ -987,7 +995,14 @@ def fused_experts_impl(
compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16
if inplace:
if no_combine:
assert not inplace
out_hidden_states = torch.empty(
(num_tokens, topk_ids.shape[1], w2.shape[1]),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
elif inplace:
out_hidden_states = hidden_states
else:
out_hidden_states = torch.empty_like(hidden_states)
@@ -1057,7 +1072,11 @@ def fused_experts_impl(
invoke_fused_moe_kernel(
intermediate_cache2,
w2,
intermediate_cache3,
(
intermediate_cache3
if not no_combine and topk_ids.shape[1] != 1
else out_hidden_states[begin_chunk_idx:end_chunk_idx]
),
a2_scale,
w2_scale,
curr_topk_weights,
@@ -1075,16 +1094,16 @@ def fused_experts_impl(
block_shape=block_shape,
)
if is_hip_flag:
if no_combine:
pass
elif is_hip_:
ops.moe_sum(
intermediate_cache3.view(*intermediate_cache3.shape),
out_hidden_states[begin_chunk_idx:end_chunk_idx],
)
else:
if topk_ids.shape[1] == 1:
out_hidden_states[begin_chunk_idx:end_chunk_idx].copy_(
intermediate_cache3[:, 0]
)
pass # we write directly into out_hidden_states
elif topk_ids.shape[1] == 2:
torch.add(
intermediate_cache3[:, 0],
@@ -1122,6 +1141,7 @@ def fused_moe(
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
no_combine: bool = False,
) -> torch.Tensor:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
@@ -1191,4 +1211,5 @@ def fused_moe(
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_shape,
no_combine=no_combine,
)

View File

@@ -125,6 +125,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
inplace: bool = True,
no_combine: bool = False,
) -> torch.Tensor:
return self.forward(
x=x,
@@ -138,6 +140,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
activation=activation,
inplace=inplace,
no_combine=no_combine,
)
def forward_cuda(
@@ -153,6 +157,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
inplace: bool = True,
no_combine: bool = False,
) -> torch.Tensor:
topk_weights, topk_ids = select_experts(
hidden_states=x,
@@ -171,6 +177,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
from aiter.fused_moe import fused_experts_ck
assert activation == "silu", f"{activation=} is not supported."
assert not no_combine, "unsupported"
return fused_experts_ck(
hidden_states=x,
@@ -186,8 +193,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
inplace=inplace and not no_combine,
activation=activation,
no_combine=no_combine,
)
def forward_cpu(
@@ -202,6 +210,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
inplace: bool = True,
) -> torch.Tensor:
return moe_forward_native(
layer,
@@ -241,6 +250,7 @@ class FusedMoE(torch.nn.Module):
reduce_results: Whether to all all_reduce on the output of the layer
renomalize: Whether to renormalize the logits in the fused_moe kernel
quant_config: Quantization configure.
inplace: suggestion to compute inplace (modify input activation).
"""
def __init__(
@@ -262,6 +272,8 @@ class FusedMoE(torch.nn.Module):
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
use_presharded_weights: bool = False,
inplace: bool = True,
no_combine: bool = False,
):
super().__init__()
@@ -285,6 +297,9 @@ class FusedMoE(torch.nn.Module):
self.custom_routing_function = custom_routing_function
self.correction_bias = correction_bias
self.activation = activation
self.use_presharded_weights = use_presharded_weights
self.inplace = inplace
self.no_combine = no_combine
if quant_config is None:
self.quant_method: Optional[QuantizeMethodBase] = (
@@ -304,7 +319,6 @@ class FusedMoE(torch.nn.Module):
params_dtype=params_dtype,
weight_loader=self.weight_loader,
)
self.use_presharded_weights = use_presharded_weights
def _load_per_tensor_weight_scale(
self,
@@ -598,6 +612,8 @@ class FusedMoE(torch.nn.Module):
custom_routing_function=self.custom_routing_function,
correction_bias=self.correction_bias,
activation=self.activation,
inplace=self.inplace,
no_combine=self.no_combine,
)
if self.reduce_results and self.tp_size > 1:

View File

@@ -771,6 +771,8 @@ class Fp8MoEMethod:
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
inplace: bool = True,
no_combine: bool = False,
) -> torch.Tensor:
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
from sglang.srt.layers.moe.topk import select_experts
@@ -793,6 +795,7 @@ class Fp8MoEMethod:
from aiter.fused_moe import fused_experts_ck
assert activation == "silu", f"{activation=} is not supported."
assert not no_combine, f"{no_combine=} is not supported."
return fused_experts_ck(
x,
@@ -823,7 +826,7 @@ class Fp8MoEMethod:
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
inplace=inplace and not no_combine,
activation=activation,
use_fp8_w8a8=True,
w1_scale=(
@@ -839,6 +842,7 @@ class Fp8MoEMethod:
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
block_shape=self.quant_config.weight_block_size,
no_combine=no_combine,
)

View File

@@ -707,7 +707,6 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
cos = freqs.cos() * self.mscale
sin = freqs.sin() * self.mscale
cache = torch.cat((cos, sin), dim=-1)
print("Cache shape", cache.shape)
return cache
def forward(

View File

@@ -1,5 +1,5 @@
import logging
from typing import List
from typing import List, Optional
import torch
import torch.distributed as dist
@@ -41,7 +41,21 @@ class Sampler(nn.Module):
sampling_info: SamplingBatchInfo,
return_logprob: bool,
top_logprobs_nums: List[int],
token_ids_logprobs: List[List[int]],
batch_next_token_ids: Optional[torch.Tensor] = None,
):
"""Run a sampler & compute logprobs and update logits_output accordingly.
Args:
logits_output: The logits from the model forward
sampling_info: Metadata for sampling
return_logprob: If set, store the output logprob information to
logits_output
top_logprobs_nums: Number of top lobprobs per sequence in a batch
batch_next_token_ids: next token IDs. If set, skip sampling and only
compute output logprobs It is used for speculative decoding which
performs sampling in draft workers.
"""
logits = logits_output.next_token_logits
# Apply the custom logit processors if registered in the sampling info.
@@ -58,13 +72,15 @@ class Sampler(nn.Module):
if sampling_info.is_all_greedy:
# Use torch.argmax if all requests use greedy sampling
batch_next_token_ids = torch.argmax(logits, -1)
if batch_next_token_ids is None:
batch_next_token_ids = torch.argmax(logits, -1)
if return_logprob:
logprobs = torch.nn.functional.log_softmax(logits, dim=-1)
else:
# Post process logits
logits.div_(sampling_info.temperatures)
probs = torch.softmax(logits, dim=-1)
logits[:] = torch.softmax(logits, dim=-1)
probs = logits
del logits
if global_server_args_dict["sampling_backend"] == "flashinfer":
@@ -78,38 +94,43 @@ class Sampler(nn.Module):
top_p_normalize_probs_torch(probs, sampling_info.top_ps)
).clamp(min=torch.finfo(probs.dtype).min)
max_top_k_round, batch_size = 32, probs.shape[0]
uniform_samples = torch.rand(
(max_top_k_round, batch_size), device=probs.device
)
if sampling_info.need_min_p_sampling:
probs = top_k_renorm_prob(probs, sampling_info.top_ks)
probs = top_p_renorm_prob(probs, sampling_info.top_ps)
batch_next_token_ids = min_p_sampling_from_probs(
probs, uniform_samples, sampling_info.min_ps
)
else:
batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
probs,
uniform_samples,
sampling_info.top_ks,
sampling_info.top_ps,
filter_apply_order="joint",
if batch_next_token_ids is None:
max_top_k_round, batch_size = 32, probs.shape[0]
uniform_samples = torch.rand(
(max_top_k_round, batch_size), device=probs.device
)
if sampling_info.need_min_p_sampling:
probs = top_k_renorm_prob(probs, sampling_info.top_ks)
probs = top_p_renorm_prob(probs, sampling_info.top_ps)
batch_next_token_ids = min_p_sampling_from_probs(
probs, uniform_samples, sampling_info.min_ps
)
else:
batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
probs,
uniform_samples,
sampling_info.top_ks,
sampling_info.top_ps,
filter_apply_order="joint",
)
if self.use_nan_detection and not torch.all(success):
logger.warning("Detected errors during sampling!")
batch_next_token_ids = torch.zeros_like(batch_next_token_ids)
if self.use_nan_detection and not torch.all(success):
logger.warning("Detected errors during sampling!")
batch_next_token_ids = torch.zeros_like(
batch_next_token_ids
)
elif global_server_args_dict["sampling_backend"] == "pytorch":
# A slower fallback implementation with torch native operations.
batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
probs,
sampling_info.top_ks,
sampling_info.top_ps,
sampling_info.min_ps,
sampling_info.need_min_p_sampling,
)
if batch_next_token_ids is None:
# A slower fallback implementation with torch native operations.
batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
probs,
sampling_info.top_ks,
sampling_info.top_ps,
sampling_info.min_ps,
sampling_info.need_min_p_sampling,
)
if return_logprob:
# clamp to avoid -inf
logprobs = torch.log(
@@ -128,6 +149,12 @@ class Sampler(nn.Module):
logits_output.next_token_top_logprobs_idx,
) = get_top_logprobs(logprobs, top_logprobs_nums)
if any(x is not None for x in token_ids_logprobs):
(
logits_output.next_token_token_ids_logprobs_val,
logits_output.next_token_token_ids_logprobs_idx,
) = get_token_ids_logprobs(logprobs, token_ids_logprobs)
logits_output.next_token_logprobs = logprobs[
torch.arange(len(batch_next_token_ids), device=sampling_info.device),
batch_next_token_ids,
@@ -223,6 +250,10 @@ def top_p_normalize_probs_torch(
def get_top_logprobs(logprobs: torch.Tensor, top_logprobs_nums: List[int]):
assert len(top_logprobs_nums) == logprobs.shape[0], (
len(top_logprobs_nums),
logprobs.shape[0],
)
max_k = max(top_logprobs_nums)
ret = logprobs.topk(max_k, dim=1)
values = ret.values.tolist()
@@ -234,3 +265,17 @@ def get_top_logprobs(logprobs: torch.Tensor, top_logprobs_nums: List[int]):
output_top_logprobs_val.append(values[i][:k])
output_top_logprobs_idx.append(indices[i][:k])
return output_top_logprobs_val, output_top_logprobs_idx
def get_token_ids_logprobs(logprobs: torch.Tensor, token_ids_logprobs: List[List[int]]):
output_token_ids_logprobs_val = []
output_token_ids_logprobs_idx = []
for i, token_ids in enumerate(token_ids_logprobs):
if token_ids is not None:
output_token_ids_logprobs_val.append(logprobs[i, token_ids].tolist())
output_token_ids_logprobs_idx.append(token_ids)
else:
output_token_ids_logprobs_val.append([])
output_token_ids_logprobs_idx.append([])
return output_token_ids_logprobs_val, output_token_ids_logprobs_idx

View File

@@ -457,7 +457,7 @@ class VocabParallelEmbedding(torch.nn.Module):
assert loaded_weight.shape[output_dim] == (
self.org_vocab_size
// (self.tp_size if self.use_presharded_weights else 1)
)
), f"{self.org_vocab_size=} {self.use_presharded_weights=} {loaded_weight.shape[output_dim]=}"
# Copy the data.
if not self.use_presharded_weights:

View File

@@ -28,6 +28,7 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--url", type=str, default="http://localhost:30000")
parser.add_argument("--log-requests", action="store_true")
parser.add_argument("--log-requests-level", type=int, default=2)
parser.add_argument(
"--dump-requests-folder", type=str, default="/tmp/sglang_request_dump"
)
@@ -38,7 +39,7 @@ if __name__ == "__main__":
args.url + "/configure_logging",
json={
"log_requests": args.log_requests,
"log_requests_level": 1, # Log full requests
"log_requests_level": args.log_requests_level, # Log full requests
"dump_requests_folder": args.dump_requests_folder,
"dump_requests_threshold": args.dump_requests_threshold,
},

View File

@@ -198,6 +198,8 @@ class DataParallelController:
self.max_total_num_tokens = scheduler_info[0]["max_total_num_tokens"]
self.max_req_input_len = scheduler_info[0]["max_req_input_len"]
print(f"{scheduler_info=}")
def round_robin_scheduler(self, req):
self.workers[self.round_robin_counter].send_pyobj(req)
self.round_robin_counter = (self.round_robin_counter + 1) % len(self.workers)
@@ -220,6 +222,7 @@ class DataParallelController:
TokenizedEmbeddingReqInput,
),
):
logger.info("dispatching")
self.dispatching(recv_req)
else:
# Send other control messages to first worker of tp group

View File

@@ -14,6 +14,7 @@
"""DetokenizerManager is a process that detokenizes the token ids."""
import dataclasses
import json
import logging
import os
import signal
@@ -27,11 +28,16 @@ import zmq
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.managers.io_struct import (
BatchEmbeddingOut,
BatchMultimodalDecodeReq,
BatchStrOut,
BatchTokenIDOut,
)
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import configure_logger, get_zmq_socket
from sglang.srt.utils import (
configure_logger,
get_zmq_socket,
kill_itself_when_parent_died,
)
from sglang.utils import (
TypeBasedDispatcher,
find_printable_text,
@@ -86,14 +92,23 @@ class DetokenizerManager:
)
self.decode_status = LimitedCapacityDict(capacity=DETOKENIZER_MAX_STATES)
self.is_dummy = server_args.load_format == "dummy"
self._request_dispatcher = TypeBasedDispatcher(
[
(BatchEmbeddingOut, self.handle_batch_embedding_out),
(BatchTokenIDOut, self.handle_batch_token_id_out),
(BatchMultimodalDecodeReq, self.handle_multimodal_decode_req),
]
)
def event_loop(self):
"""The event loop that handles requests"""
while True:
recv_obj = self.recv_from_scheduler.recv_pyobj()
output = self._request_dispatcher(recv_obj)
self.send_to_tokenizer.send_pyobj(output)
def trim_matched_stop(
self, output: Union[str, List[int]], finished_reason: Dict, no_stop_trim: bool
):
@@ -117,14 +132,6 @@ class DetokenizerManager:
return output[:-1]
return output
def event_loop(self):
"""The event loop that handles requests"""
while True:
recv_obj = self.recv_from_scheduler.recv_pyobj()
output = self._request_dispatcher(recv_obj)
self.send_to_tokenizer.send_pyobj(output)
def handle_batch_embedding_out(self, recv_obj: BatchEmbeddingOut):
# If it is embedding model, no detokenization is needed.
return recv_obj
@@ -173,7 +180,6 @@ class DetokenizerManager:
# Incremental decoding
output_strs = []
finished_reqs = []
for i in range(bs):
try:
s = self.decode_status[recv_obj.rids[i]]
@@ -196,8 +202,6 @@ class DetokenizerManager:
new_text = ""
else:
new_text = find_printable_text(new_text)
else:
finished_reqs.append(recv_obj.rids[i])
output_strs.append(
self.trim_matched_stop(
@@ -207,7 +211,7 @@ class DetokenizerManager:
)
)
out = BatchStrOut(
return BatchStrOut(
rids=recv_obj.rids,
finished_reasons=recv_obj.finished_reasons,
output_strs=output_strs,
@@ -223,14 +227,15 @@ class DetokenizerManager:
input_top_logprobs_idx=recv_obj.input_top_logprobs_idx,
output_top_logprobs_val=recv_obj.output_top_logprobs_val,
output_top_logprobs_idx=recv_obj.output_top_logprobs_idx,
input_token_ids_logprobs_val=recv_obj.input_token_ids_logprobs_val,
input_token_ids_logprobs_idx=recv_obj.input_token_ids_logprobs_idx,
output_token_ids_logprobs_val=recv_obj.output_token_ids_logprobs_val,
output_token_ids_logprobs_idx=recv_obj.output_token_ids_logprobs_idx,
output_hidden_states=recv_obj.output_hidden_states,
)
# remove decodestatus for completed requests
for rid in finished_reqs:
self.decode_status.pop(rid)
return out
def handle_multimodal_decode_req(self, recv_obj: BatchMultimodalDecodeReq):
raise NotImplementedError()
class LimitedCapacityDict(OrderedDict):
@@ -250,6 +255,7 @@ def run_detokenizer_process(
server_args: ServerArgs,
port_args: PortArgs,
):
kill_itself_when_parent_died()
setproctitle.setproctitle("sglang::detokenizer")
configure_logger(server_args)
parent_process = psutil.Process().parent()

View File

@@ -16,10 +16,11 @@ The definition of objects transfered between different
processes (TokenizerManager, DetokenizerManager, Controller).
"""
import copy
import uuid
from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Union
from sglang.srt.managers.schedule_batch import BaseFinishReason
from sglang.srt.sampling.sampling_params import SamplingParams
@@ -55,6 +56,8 @@ class GenerateReqInput:
logprob_start_len: Optional[Union[List[int], int]] = None
# If return logprobs, the number of top logprobs to return at each position.
top_logprobs_num: Optional[Union[List[int], int]] = None
# If return logprobs, the token ids to return logprob for.
token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None
# Whether to detokenize tokens in text in the returned logprobs.
return_text_in_logprobs: bool = False
# Whether to stream output.
@@ -146,6 +149,8 @@ class GenerateReqInput:
self.logprob_start_len = -1
if self.top_logprobs_num is None:
self.top_logprobs_num = 0
if not self.token_ids_logprob: # covers both None and []
self.token_ids_logprob = None
else:
if self.parallel_sample_num == 1:
num = self.batch_size
@@ -191,6 +196,17 @@ class GenerateReqInput:
else:
assert self.parallel_sample_num == 1
if not self.token_ids_logprob: # covers both None and []
self.token_ids_logprob = [None] * num
elif not isinstance(self.token_ids_logprob, list):
self.token_ids_logprob = [[self.token_ids_logprob] for _ in range(num)]
elif not isinstance(self.token_ids_logprob[0], list):
self.token_ids_logprob = [
copy.deepcopy(self.token_ids_logprob) for _ in range(num)
]
else:
assert self.parallel_sample_num == 1
if self.custom_logit_processor is None:
self.custom_logit_processor = [None] * num
elif not isinstance(self.custom_logit_processor, list):
@@ -198,6 +214,12 @@ class GenerateReqInput:
else:
assert self.parallel_sample_num == 1
# Other checks
if self.session_params is not None:
assert isinstance(self.session_params, dict) or isinstance(
self.session_params[0], dict
)
def regenerate_rid(self):
self.rid = uuid.uuid4().hex
return self.rid
@@ -212,6 +234,7 @@ class GenerateReqInput:
return_logprob=self.return_logprob[i],
logprob_start_len=self.logprob_start_len[i],
top_logprobs_num=self.top_logprobs_num[i],
token_ids_logprob=self.token_ids_logprob[i],
return_text_in_logprobs=self.return_text_in_logprobs,
stream=self.stream,
log_metrics=self.log_metrics,
@@ -244,6 +267,8 @@ class TokenizedGenerateReqInput:
logprob_start_len: int
# If return logprobs, the number of top logprobs to return at each position.
top_logprobs_num: int
# If return logprobs, the token id to return logprob for
token_ids_logprob: List[int]
# Whether to stream output
stream: bool
@@ -378,10 +403,21 @@ class BatchTokenIDOut:
input_top_logprobs_idx: List[List]
output_top_logprobs_val: List[List]
output_top_logprobs_idx: List[List]
input_token_ids_logprobs_val: List[List]
input_token_ids_logprobs_idx: List[List]
output_token_ids_logprobs_val: List[List]
output_token_ids_logprobs_idx: List[List]
# Hidden states
output_hidden_states: List[List[float]]
@dataclass
class BatchMultimodalDecodeReq:
# The request id
rids: List[str]
@dataclass
class BatchStrOut:
# The request id
@@ -406,10 +442,21 @@ class BatchStrOut:
input_top_logprobs_idx: List[List]
output_top_logprobs_val: List[List]
output_top_logprobs_idx: List[List]
input_token_ids_logprobs_val: List[List]
input_token_ids_logprobs_idx: List[List]
output_token_ids_logprobs_val: List[List]
output_token_ids_logprobs_idx: List[List]
# Hidden states
output_hidden_states: List[List[float]]
@dataclass
class BatchMultimodalOut:
# The request id
rids: List[str]
@dataclass
class BatchEmbeddingOut:
# The request id
@@ -439,6 +486,8 @@ class UpdateWeightFromDiskReqInput:
class UpdateWeightFromDiskReqOutput:
success: bool
message: str
# Number of paused requests during weight sync.
num_paused_requests: Optional[int] = 0
@dataclass
@@ -526,11 +575,57 @@ class AbortReq:
rid: str
class ProfileReq(Enum):
@dataclass
class GetInternalStateReq:
pass
@dataclass
class GetInternalStateReqOutput:
internal_state: Dict[Any, Any]
@dataclass
class SetInternalStateReq:
server_args: Dict[str, Any]
@dataclass
class SetInternalStateReqOutput:
updated: bool
server_args: Dict[str, Any]
@dataclass
class ProfileReqInput:
# The output directory
output_dir: Optional[str] = None
# If set, it profile as many as this number of steps.
# If it is set, profiling is automatically stopped after this step, and
# the caller doesn't need to run stop_profile.
num_steps: Optional[int] = None
activities: Optional[List[str]] = None
class ProfileReqType(Enum):
START_PROFILE = 1
STOP_PROFILE = 2
@dataclass
class ProfileReq:
type: ProfileReqType
output_dir: Optional[str] = None
num_steps: Optional[int] = None
activities: Optional[List[str]] = None
@dataclass
class ProfileReqOutput:
success: bool
message: str
@dataclass
class ConfigureLoggingReq:
log_requests: Optional[bool] = None
@@ -556,6 +651,11 @@ class OpenSessionReqOutput:
success: bool
@dataclass
class HealthCheckOutput:
pass
@dataclass
class Function:
description: Optional[str] = None

View File

@@ -29,6 +29,7 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
It contains low-level tensor data. Most of the data consists of GPU tensors.
"""
import copy
import dataclasses
import logging
from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union
@@ -50,7 +51,10 @@ from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs
if TYPE_CHECKING:
from sglang.srt.speculative.spec_info import SpecInfo, SpeculativeAlgorithm
from sglang.srt.server_args import ServerArgs
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
@@ -65,6 +69,8 @@ global_server_args_dict = {
"enable_dp_attention": ServerArgs.enable_dp_attention,
"enable_ep_moe": ServerArgs.enable_ep_moe,
"device": ServerArgs.device,
"speculative_accept_threshold_single": ServerArgs.speculative_accept_threshold_single,
"speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc,
"enable_flashinfer_mla": ServerArgs.enable_flashinfer_mla,
"disable_radix_cache": ServerArgs.disable_radix_cache,
"flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
@@ -230,6 +236,7 @@ class Req:
sampling_params: SamplingParams,
return_logprob: bool = False,
top_logprobs_num: int = 0,
token_ids_logprob: List[int] = None,
stream: bool = False,
origin_input_ids_unpadded: Optional[Tuple[int]] = None,
lora_path: Optional[str] = None,
@@ -256,17 +263,24 @@ class Req:
self.input_embeds = input_embeds
# Sampling info
if isinstance(sampling_params.custom_params, dict):
sampling_params = copy.copy(sampling_params)
sampling_params.custom_params = sampling_params.custom_params | {
"__req__": self
}
self.sampling_params = sampling_params
self.custom_logit_processor = custom_logit_processor
self.return_hidden_states = return_hidden_states
# Memory pool info
self.req_pool_idx = None
self.req_pool_idx: Optional[int] = None
# Check finish
self.tokenizer = None
self.finished_reason = None
# If we want to abort the request in the middle of the event loop, set this to true
# Note: We should never set finished_reason in the middle, the req will get filtered and never respond
self.to_abort = False
self.stream = stream
self.eos_token_ids = eos_token_ids
@@ -289,38 +303,56 @@ class Req:
self.image_inputs: Optional[ImageInputs] = None
# Prefix info
# The indices to kv cache for the shared prefix.
self.prefix_indices = []
# Tokens to run prefill. input_tokens - shared_prefix_tokens.
# Updated if chunked.
# Number of tokens to run prefill.
self.extend_input_len = 0
# The relative logprob_start_len in an extend batch
self.extend_logprob_start_len = 0
self.last_node = None
# Chunked prefill
self.is_being_chunked = 0
# Whether or not if it is chunked. It increments whenever
# it is chunked, and decrement whenever chunked request is
# processed.
self.is_chunked = 0
# For retraction
self.is_retracted = False
# Logprobs (arguments)
self.return_logprob = return_logprob
# Start index to compute logprob from.
self.logprob_start_len = 0
self.top_logprobs_num = top_logprobs_num
self.token_ids_logprob = token_ids_logprob
# Logprobs (return values)
self.input_token_logprobs_val: Optional[List[float]] = None
self.input_token_logprobs_idx: Optional[List[int]] = None
self.input_top_logprobs_val: Optional[List[float]] = None
self.input_top_logprobs_idx: Optional[List[int]] = None
self.input_token_ids_logprobs_val: Optional[List[float]] = None
self.input_token_ids_logprobs_idx: Optional[List[int]] = None
# Temporary holder to store input_token_logprobs.
self.input_token_logprobs: Optional[List[Tuple[int]]] = None
self.temp_input_top_logprobs_val: Optional[List[torch.Tensor]] = None
self.temp_input_top_logprobs_idx: Optional[List[int]] = None
self.temp_input_token_ids_logprobs_val: Optional[List[float]] = None
self.temp_input_token_ids_logprobs_idx: Optional[List[int]] = None
if return_logprob:
self.output_token_logprobs_val = []
self.output_token_logprobs_idx = []
self.output_top_logprobs_val = []
self.output_top_logprobs_idx = []
self.output_token_ids_logprobs_val = []
self.output_token_ids_logprobs_idx = []
else:
self.output_token_logprobs_val = self.output_token_logprobs_idx = (
self.output_top_logprobs_val
) = self.output_top_logprobs_idx = None
) = self.output_top_logprobs_idx = self.output_token_ids_logprobs_val = (
self.output_token_ids_logprobs_idx
) = None
self.hidden_states = []
# Logprobs (internal values)
@@ -345,6 +377,13 @@ class Req:
self.spec_verify_ct = 0
self.lora_path = lora_path
# This carries the error message for `.to_abort` and will be attached to the finished_reason at the end of the event loop
self.to_abort_message: str = "Unknown error"
@property
def seqlen(self):
return len(self.origin_input_ids) + len(self.output_ids)
def extend_image_inputs(self, image_inputs):
if self.image_inputs is None:
self.image_inputs = image_inputs
@@ -422,7 +461,9 @@ class Req:
return
if self.to_abort:
self.finished_reason = FINISH_ABORT()
self.finished_reason = FINISH_ABORT(
message=self.to_abort_message,
)
return
if len(self.output_ids) >= self.sampling_params.max_new_tokens:
@@ -517,6 +558,8 @@ class Req:
self.output_token_logprobs_idx = self.output_token_logprobs_idx[:k]
self.output_top_logprobs_val = self.output_top_logprobs_val[:k]
self.output_top_logprobs_idx = self.output_top_logprobs_idx[:k]
self.output_token_ids_logprobs_val = self.output_token_ids_logprobs_val[:k]
self.output_token_ids_logprobs_idx = self.output_token_ids_logprobs_idx[:k]
self.logprob_start_len = prompt_tokens + k
self.last_update_decode_tokens = len(self.output_ids) - k
@@ -527,16 +570,19 @@ class Req:
self.last_node = None
self.extend_input_len = 0
self.is_retracted = True
self.input_token_logprobs = None
self.temp_input_top_logprobs_val = None
self.temp_input_top_logprobs_idx = None
self.extend_logprob_start_len = 0
self.is_chunked = 0
self.req_pool_idx = None
# For incremental logprobs
# TODO: Fix the `logprob_start_len`
self.last_update_decode_tokens = 0
self.logprob_start_len = 10**9
def __repr__(self):
return (
f"rid(n={self.rid}, "
f"input_ids={self.origin_input_ids}, output_ids={self.output_ids}"
f"Req(rid={self.rid}, "
f"input_ids={self.origin_input_ids}, output_ids={self.output_ids})"
)
@@ -576,11 +622,13 @@ class ScheduleBatch:
# For DP attention
global_num_tokens: Optional[List[int]] = None
global_num_tokens_for_logprob: Optional[List[int]] = None
can_run_dp_cuda_graph: bool = False
# For processing logprobs
return_logprob: bool = False
top_logprobs_nums: Optional[List[int]] = None
token_ids_logprobs: Optional[List[List[int]]] = None
# For extend and mixed chunekd prefill
prefix_lens: List[int] = None
@@ -588,6 +636,8 @@ class ScheduleBatch:
extend_num_tokens: int = None
decoding_reqs: List[Req] = None
extend_logprob_start_lens: List[int] = None
# It comes empty list if logprob is not required.
extend_input_logprob_token_ids: Optional[torch.Tensor] = None
# For encoder-decoder
encoder_cached: Optional[List[bool]] = None
@@ -606,7 +656,7 @@ class ScheduleBatch:
# Speculative decoding
spec_algorithm: SpeculativeAlgorithm = None
spec_info: Optional[SpecInfo] = None
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None
# Enable custom logit processor
enable_custom_logit_processor: bool = False
@@ -653,8 +703,10 @@ class ScheduleBatch:
req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
if req_pool_indices is None:
raise RuntimeError(
"Out of memory. "
"Please set a smaller number for `--max-running-requests`."
"alloc_req_slots runs out of memory. "
"Please set a smaller number for `--max-running-requests`. "
f"{self.req_to_token_pool.available_size()=}, "
f"{num_reqs=}, "
)
return req_pool_indices
@@ -765,6 +817,7 @@ class ScheduleBatch:
out_cache_loc = self.alloc_token_slots(extend_num_tokens)
input_embeds = []
extend_input_logprob_token_ids = []
pt = 0
for i, req in enumerate(reqs):
@@ -783,22 +836,64 @@ class ScheduleBatch:
# If req.input_embeds is already a list, append its content directly
input_embeds.extend(req.input_embeds) # Use extend to avoid nesting
if req.return_logprob:
# Compute the relative logprob_start_len in an extend batch
if req.logprob_start_len >= pre_len:
extend_logprob_start_len = min(
req.logprob_start_len - pre_len, req.extend_input_len - 1
)
else:
raise RuntimeError(
f"This should never happen. {req.logprob_start_len=}, {pre_len=}"
)
req.extend_logprob_start_len = extend_logprob_start_len
req.cached_tokens += pre_len - req.already_computed
req.already_computed = seq_len
req.is_retracted = False
pre_lens.append(pre_len)
# Compute the relative logprob_start_len in an extend batch
if req.logprob_start_len >= pre_len:
req.extend_logprob_start_len = min(
req.logprob_start_len - pre_len,
req.extend_input_len,
req.seqlen - 1,
)
else:
req.extend_logprob_start_len = 0
if self.return_logprob:
# Find input logprob token ids.
# First, find a global index within origin_input_ids and slide it by 1
# to compute input logprobs. It is because you need the next token
# to compute input logprobs. E.g., (chunk size 2)
#
# input_logprobs = [1, 2, 3, 4]
# fill_ids = [1, 2]
# extend_input_logprob_token_id = [2, 3]
#
# Note that it can also overflow. In this case, we pad it with 0.
# input_logprobs = [1, 2, 3, 4]
# fill_ids = [3, 4]
# extend_input_logprob_token_id = [4, 0]
global_start_idx, global_end_idx = (
len(req.prefix_indices),
len(req.fill_ids),
)
# Apply logprob_start_len
if global_start_idx < req.logprob_start_len:
global_start_idx = req.logprob_start_len
logprob_token_ids = req.origin_input_ids[
global_start_idx + 1 : global_end_idx + 1
]
extend_input_logprob_token_ids.extend(logprob_token_ids)
# We will need req.extend_input_len - req.extend_logprob_start_len number of
# tokens, and logprob_token_ids is for input logprob, so pad the rest of them by 0.
extend_input_logprob_token_ids.extend(
[0]
* (
req.extend_input_len
- req.extend_logprob_start_len
- len(logprob_token_ids)
)
)
if self.return_logprob:
extend_input_logprob_token_ids = torch.tensor(
extend_input_logprob_token_ids
)
else:
extend_input_logprob_token_ids = None
# Set fields
self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
@@ -821,10 +916,12 @@ class ScheduleBatch:
self.seq_lens_sum = sum(seq_lens)
if self.return_logprob:
self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
self.token_ids_logprobs = [r.token_ids_logprob for r in reqs]
self.extend_num_tokens = extend_num_tokens
self.prefix_lens = [len(r.prefix_indices) for r in reqs]
self.extend_lens = [r.extend_input_len for r in reqs]
self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
self.extend_input_logprob_token_ids = extend_input_logprob_token_ids
# Write to req_to_token_pool
pre_lens = torch.tensor(pre_lens, dtype=torch.int32).to(
@@ -860,7 +957,6 @@ class ScheduleBatch:
self.sampling_info = SamplingBatchInfo.from_schedule_batch(
self,
self.model_config.vocab_size,
enable_overlap_schedule=self.enable_overlap,
)
def mix_with_running(self, running_batch: "ScheduleBatch"):
@@ -905,25 +1001,43 @@ class ScheduleBatch:
return False
def retract_decode(self):
def retract_decode(self, server_args: ServerArgs):
"""Retract the decoding requests when there is not enough memory."""
sorted_indices = [i for i in range(len(self.reqs))]
# TODO(lsyin): improve retraction policy for radix cache
sorted_indices.sort(
key=lambda i: (
len(self.reqs[i].output_ids),
-len(self.reqs[i].origin_input_ids),
),
reverse=True,
)
# For spec decoding, filter_batch API can only filter
# requests from the back, so we can only retract from the back.
# TODO(sang): Clean up finish path and support better retract
# policy.
if not server_args.speculative_algorithm:
sorted_indices.sort(
key=lambda i: (
len(self.reqs[i].output_ids),
-len(self.reqs[i].origin_input_ids),
),
reverse=True,
)
def get_required_tokens(num_reqs: int):
headroom_for_spec_decode = 0
if server_args.speculative_algorithm:
headroom_for_spec_decode += (
num_reqs
* server_args.speculative_eagle_topk
* server_args.speculative_num_steps
+ num_reqs * server_args.speculative_num_draft_tokens
)
return (
num_reqs * global_config.retract_decode_steps + headroom_for_spec_decode
)
retracted_reqs = []
seq_lens_cpu = self.seq_lens.cpu().numpy()
first_iter = True
while (
self.token_to_kv_pool.available_size()
< len(sorted_indices) * global_config.retract_decode_steps
< get_required_tokens(len(sorted_indices))
or first_iter
):
if len(sorted_indices) == 1:
@@ -1048,17 +1162,40 @@ class ScheduleBatch:
self.sampling_info = SamplingBatchInfo.from_schedule_batch(
self,
self.model_config.vocab_size,
enable_overlap_schedule=self.enable_overlap,
)
def prepare_for_decode(self):
self.forward_mode = ForwardMode.DECODE
if self.spec_algorithm.is_eagle():
# if spec decoding is used, the decode batch is prepared inside
# `forward_batch_speculative_generation` after running draft models.
return
if self.sampling_info.penalizer_orchestrator.is_required:
if self.enable_overlap:
# TODO: this can be slow, optimize this.
delayed_output_ids = torch.tensor(
[
(
req.output_ids[-1]
if len(req.output_ids)
else req.origin_input_ids[-1]
)
for req in self.reqs
],
dtype=torch.int64,
device=self.device,
)
self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
delayed_output_ids
)
else:
self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
self.output_ids.to(torch.int64)
)
self.input_ids = self.output_ids
self.output_ids = None
self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(self.input_ids)
# Alloc mem
bs = len(self.reqs)
@@ -1086,14 +1223,15 @@ class ScheduleBatch:
def filter_batch(
self,
being_chunked_req: Optional[Req] = None,
chunked_req_to_exclude: Optional[Req] = None,
keep_indices: Optional[List[int]] = None,
):
if keep_indices is None:
keep_indices = [
i
for i in range(len(self.reqs))
if not self.reqs[i].finished() and self.reqs[i] is not being_chunked_req
if not self.reqs[i].finished()
and self.reqs[i] is not chunked_req_to_exclude
]
if keep_indices is None or len(keep_indices) == 0:
@@ -1105,31 +1243,34 @@ class ScheduleBatch:
# No need to filter
return
keep_indices_device = torch.tensor(keep_indices, dtype=torch.int64).to(
self.device, non_blocking=True
)
if self.model_config.is_encoder_decoder:
self.encoder_lens = self.encoder_lens[keep_indices]
self.encoder_lens = self.encoder_lens[keep_indices_device]
self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices]
self.reqs = [self.reqs[i] for i in keep_indices]
new_indices = torch.tensor(keep_indices, dtype=torch.int64).to(
self.device, non_blocking=True
)
self.req_pool_indices = self.req_pool_indices[new_indices]
self.seq_lens = self.seq_lens[new_indices]
self.req_pool_indices = self.req_pool_indices[keep_indices_device]
self.seq_lens = self.seq_lens[keep_indices_device]
self.out_cache_loc = None
self.seq_lens_sum = self.seq_lens.sum().item()
self.output_ids = self.output_ids[new_indices]
self.output_ids = self.output_ids[keep_indices_device]
self.return_logprob = any(req.return_logprob for req in self.reqs)
if self.return_logprob:
self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in keep_indices]
self.token_ids_logprobs = [self.token_ids_logprobs[i] for i in keep_indices]
else:
self.top_logprobs_nums = None
self.token_ids_logprobs = None
self.has_stream = any(req.stream for req in self.reqs)
self.has_grammar = any(req.grammar for req in self.reqs)
self.sampling_info.filter_batch(keep_indices, new_indices)
self.sampling_info.filter_batch(keep_indices, keep_indices_device)
if self.spec_info:
self.spec_info.filter_batch(new_indices)
self.spec_info.filter_batch(keep_indices_device)
def merge_batch(self, other: "ScheduleBatch"):
# Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
@@ -1152,10 +1293,13 @@ class ScheduleBatch:
self.output_ids = torch.concat([self.output_ids, other.output_ids])
if self.return_logprob and other.return_logprob:
self.top_logprobs_nums.extend(other.top_logprobs_nums)
self.token_ids_logprobs.extend(other.token_ids_logprobs)
elif self.return_logprob:
self.top_logprobs_nums.extend([0] * len(other.reqs))
self.token_ids_logprobs.extend([None] * len(other.reqs))
elif other.return_logprob:
self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums
self.token_ids_logprobs = [None] * len(self.reqs) + other.token_ids_logprobs
self.reqs.extend(other.reqs)
self.return_logprob |= other.return_logprob
@@ -1192,7 +1336,9 @@ class ScheduleBatch:
seq_lens_sum=self.seq_lens_sum,
return_logprob=self.return_logprob,
top_logprobs_nums=self.top_logprobs_nums,
token_ids_logprobs=self.token_ids_logprobs,
global_num_tokens=self.global_num_tokens,
global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
extend_num_tokens=self.extend_num_tokens,
extend_seq_lens=extend_seq_lens,
@@ -1219,6 +1365,7 @@ class ScheduleBatch:
else CaptureHiddenMode.NULL
)
),
extend_input_logprob_token_ids=self.extend_input_logprob_token_ids,
)
def copy(self):
@@ -1262,9 +1409,11 @@ class ModelWorkerBatch:
# For logprob
return_logprob: bool
top_logprobs_nums: Optional[List[int]]
token_ids_logprobs: Optional[List[List[int]]]
# For DP attention
global_num_tokens: Optional[List[int]]
global_num_tokens_for_logprob: Optional[List[int]]
can_run_dp_cuda_graph: bool
# For extend
@@ -1272,6 +1421,7 @@ class ModelWorkerBatch:
extend_seq_lens: Optional[List[int]]
extend_prefix_lens: Optional[List[int]]
extend_logprob_start_lens: Optional[List[int]]
extend_input_logprob_token_ids: Optional[torch.Tensor]
# For multimodal
image_inputs: Optional[List[ImageInputs]]
@@ -1293,7 +1443,8 @@ class ModelWorkerBatch:
# Speculative decoding
spec_algorithm: SpeculativeAlgorithm = None
spec_info: Optional[SpecInfo] = None
spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
# If set, the output of the batch contains the hidden states of the run.
capture_hidden_mode: CaptureHiddenMode = None

View File

@@ -272,7 +272,7 @@ class PrefillAdder:
self.req_states = None
self.can_run_list = []
self.new_being_chunked_req = None
self.new_chunked_req = None
self.log_hit_tokens = 0
self.log_input_tokens = 0
@@ -327,7 +327,7 @@ class PrefillAdder:
self.log_hit_tokens += prefix_len
self.log_input_tokens += extend_input_len
def add_being_chunked_req(self, req: Req):
def add_chunked_req(self, req: Req):
truncated = req.extend_input_len > self.rem_chunk_tokens
req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
@@ -354,7 +354,7 @@ class PrefillAdder:
finally:
self.tree_cache.dec_lock_ref(last_node)
def add_one_req_ignore_eos(self, req: Req):
def add_one_req_ignore_eos(self, req: Req, has_chunked_req: bool):
def add_req_state(r, insert_sort=False):
new_token_ratio = (
1.0 if r.sampling_params.ignore_eos else self.new_token_ratio
@@ -403,6 +403,7 @@ class PrefillAdder:
self.rem_chunk_tokens is None
or req.extend_input_len <= self.rem_chunk_tokens
):
# Non-chunked prefill
self.can_run_list.append(req)
self._prefill_one_req(
0,
@@ -418,14 +419,14 @@ class PrefillAdder:
req.extend_input_len = trunc_len
req.fill_ids = req.fill_ids[:trunc_len]
self.can_run_list.append(req)
self.new_being_chunked_req = req
self.new_chunked_req = req
self._prefill_one_req(0, trunc_len, 0)
return self.budget_state()
def add_one_req(self, req: Req):
def add_one_req(self, req: Req, has_chunked_req: bool):
if req.sampling_params.ignore_eos and self.tree_cache.disable:
return self.add_one_req_ignore_eos(req)
return self.add_one_req_ignore_eos(req, has_chunked_req)
total_tokens = req.extend_input_len + min(
req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION
@@ -443,14 +444,7 @@ class PrefillAdder:
if total_tokens > self.rem_total_tokens:
return AddReqResult.NO_TOKEN
if (
self.rem_chunk_tokens is None
or input_tokens <= self.rem_chunk_tokens
or (
req.return_logprob
and req.logprob_start_len != len(req.origin_input_ids) - 1
)
):
if self.rem_chunk_tokens is None or input_tokens <= self.rem_chunk_tokens:
# Non-chunked prefill
self.can_run_list.append(req)
self.tree_cache.inc_lock_ref(req.last_node)
@@ -470,8 +464,9 @@ class PrefillAdder:
req.extend_input_len = trunc_len
req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len]
self.can_run_list.append(req)
self.new_being_chunked_req = req
self.new_chunked_req = req
self.tree_cache.inc_lock_ref(req.last_node)
self._prefill_one_req(prefix_len, trunc_len, 0)

File diff suppressed because it is too large Load Diff

View File

@@ -35,12 +35,12 @@ class SessionReqNode:
for req_node in self.childs:
req_node.clear(req_dict)
if self.req.finished_reason == None:
if self.req.finished_reason is None:
self.req.to_abort = True
del req_dict[self.req.rid]
def abort(self):
if self.req.finished_reason == None:
if self.req.finished_reason is None:
self.req.to_abort = True
def __str__(self):
@@ -132,6 +132,10 @@ class Session:
lora_path=req.lora_path,
session_id=self.session_id,
custom_logit_processor=req.custom_logit_processor,
stream=req.stream,
return_logprob=req.return_logprob,
top_logprobs_num=req.top_logprobs_num,
token_ids_logprob=req.token_ids_logprob,
)
if last_req is not None:
new_req.image_inputs = last_req.image_inputs

View File

@@ -16,6 +16,7 @@
import asyncio
import copy
import dataclasses
import json
import logging
import os
import pickle
@@ -24,9 +25,21 @@ import sys
import threading
import time
import uuid
from collections import deque
from datetime import datetime
from http import HTTPStatus
from typing import Any, Awaitable, Dict, Generic, List, Optional, Tuple, TypeVar, Union
from typing import (
Any,
Awaitable,
Deque,
Dict,
Generic,
List,
Optional,
Tuple,
TypeVar,
Union,
)
import fastapi
import uvloop
@@ -44,6 +57,7 @@ from sglang.srt.managers.image_processor import (
from sglang.srt.managers.io_struct import (
AbortReq,
BatchEmbeddingOut,
BatchMultimodalOut,
BatchStrOut,
BatchTokenIDOut,
CloseSessionReqInput,
@@ -51,18 +65,25 @@ from sglang.srt.managers.io_struct import (
EmbeddingReqInput,
FlushCacheReq,
GenerateReqInput,
GetInternalStateReq,
GetInternalStateReqOutput,
GetWeightsByNameReqInput,
GetWeightsByNameReqOutput,
HealthCheckOutput,
InitWeightsUpdateGroupReqInput,
InitWeightsUpdateGroupReqOutput,
OpenSessionReqInput,
OpenSessionReqOutput,
ProfileReq,
ProfileReqOutput,
ProfileReqType,
ReleaseMemoryOccupationReqInput,
ReleaseMemoryOccupationReqOutput,
ResumeMemoryOccupationReqInput,
ResumeMemoryOccupationReqOutput,
SessionParams,
SetInternalStateReq,
SetInternalStateReqOutput,
TokenizedEmbeddingReqInput,
TokenizedGenerateReqInput,
UpdateWeightFromDiskReqInput,
@@ -98,7 +119,10 @@ class ReqState:
# For metrics
created_time: float
first_token_time: Optional[float] = None
finished_time: float = 0.0
first_token_time: float = 0.0
last_time: float = 0.0
last_completion_tokens: int = 1
# For streaming output
last_output_offset: int = 0
@@ -113,11 +137,10 @@ class TokenizerManager:
port_args: PortArgs,
):
# Parse args
self.server_args = server_args
self.enable_metrics = server_args.enable_metrics
self.log_requests = server_args.log_requests
self.log_requests_level = 0
self.log_requests_level = server_args.log_requests_level
# Init inter-process communication
context = zmq.asyncio.Context(2)
@@ -143,6 +166,7 @@ class TokenizerManager:
)
self.is_generation = self.model_config.is_generation
self.is_image_gen = self.model_config.is_image_gen
self.context_len = self.model_config.context_len
self.image_token_id = self.model_config.image_token_id
@@ -178,9 +202,12 @@ class TokenizerManager:
# Store states
self.no_create_loop = False
self.rid_to_state: Dict[str, ReqState] = {}
self.gracefully_exit = False
self.last_receive_tstamp = 0
self.dump_requests_folder = "" # By default do not dump
self.dump_requests_threshold = 1000
self.dump_request_list: List[Tuple] = []
self.log_request_metadata = self.get_log_request_metadata()
# The event to notify the weight sync is finished.
self.model_update_lock = RWLock()
@@ -192,8 +219,19 @@ class TokenizerManager:
# For session info
self.session_futures = {} # session_id -> asyncio event
# Others
self.gracefully_exit = False
# Set after scheduler is initialized
self.max_req_input_len = None
# Metrics
if self.enable_metrics:
self.metrics_collector = TokenizerMetricsCollector(
labels={
"model_name": self.server_args.served_model_name,
# TODO: Add lora name/path in the future,
},
)
# Communicators
self.init_weights_update_group_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
@@ -212,22 +250,26 @@ class TokenizerManager:
self.resume_memory_occupation_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
# Set after scheduler is initialized
self.max_req_input_len = None
# Metrics
if self.enable_metrics:
self.metrics_collector = TokenizerMetricsCollector(
labels={
"model_name": self.server_args.served_model_name,
# TODO: Add lora name/path in the future,
},
)
self.start_profile_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
self.health_check_communitcator = _Communicator(self.send_to_scheduler, 1)
self.get_internal_state_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
self.set_internal_state_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
self._result_dispatcher = TypeBasedDispatcher(
[
(
(BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut),
(
BatchStrOut,
BatchEmbeddingOut,
BatchTokenIDOut,
BatchMultimodalOut,
),
self._handle_batch_output,
),
(OpenSessionReqOutput, self._handle_open_session_req_output),
@@ -259,6 +301,19 @@ class TokenizerManager:
ResumeMemoryOccupationReqOutput,
self.resume_memory_occupation_communicator.handle_recv,
),
(
ProfileReqOutput,
self.start_profile_communicator.handle_recv,
),
(
GetInternalStateReqOutput,
self.get_internal_state_communicator.handle_recv,
),
(
SetInternalStateReqOutput,
self.set_internal_state_communicator.handle_recv,
),
(HealthCheckOutput, lambda x: None),
]
)
@@ -280,9 +335,9 @@ class TokenizerManager:
obj.normalize_batch_and_arguments()
if self.log_requests:
max_length = 2048 if self.log_requests_level == 0 else 1 << 30
max_length, skip_names, _ = self.log_request_metadata
logger.info(
f"Receive: obj={dataclass_to_string_truncated(obj, max_length)}"
f"Receive: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}"
)
async with self.model_update_lock.reader_lock:
@@ -336,6 +391,7 @@ class TokenizerManager:
return_logprob = obj.return_logprob
logprob_start_len = obj.logprob_start_len
top_logprobs_num = obj.top_logprobs_num
token_ids_logprob = obj.token_ids_logprob
session_params = (
SessionParams(**obj.session_params) if obj.session_params else None
)
@@ -378,6 +434,7 @@ class TokenizerManager:
return_logprob,
logprob_start_len,
top_logprobs_num,
token_ids_logprob,
obj.stream,
lora_path=obj.lora_path,
input_embeds=input_embeds,
@@ -401,8 +458,7 @@ class TokenizerManager:
tokenized_obj: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput],
created_time: Optional[float] = None,
):
event = asyncio.Event()
state = ReqState([], False, event, obj, created_time=created_time)
state = ReqState([], False, asyncio.Event(), obj, created_time=created_time)
self.rid_to_state[obj.rid] = state
self.send_to_scheduler.send_pyobj(tokenized_obj)
@@ -420,7 +476,10 @@ class TokenizerManager:
except asyncio.TimeoutError:
if request is not None and await request.is_disconnected():
self.abort_request(obj.rid)
raise ValueError(f"Abort request {obj.rid}")
raise ValueError(
"Request is disconnected from the client side. "
f"Abort request {obj.rid}"
)
continue
out = state.out_list[-1]
@@ -428,8 +487,11 @@ class TokenizerManager:
state.out_list = []
if state.finished:
if self.log_requests:
max_length = 2048 if self.log_requests_level == 0 else 1 << 30
msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length)}, out={dataclass_to_string_truncated(out, max_length)}"
max_length, skip_names, out_skip_names = self.log_request_metadata
if self.model_config.is_multimodal_gen:
msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}"
else:
msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}, out={dataclass_to_string_truncated(out, max_length, skip_names=out_skip_names)}"
logger.info(msg)
del self.rid_to_state[obj.rid]
@@ -452,7 +514,10 @@ class TokenizerManager:
else:
if request is not None and await request.is_disconnected():
self.abort_request(obj.rid)
raise ValueError(f"Abort request {obj.rid}")
raise ValueError(
"Request is disconnected from the client side. "
f"Abort request {obj.rid}"
)
async def _handle_batch_request(
self,
@@ -543,12 +608,25 @@ class TokenizerManager:
req = AbortReq(rid)
self.send_to_scheduler.send_pyobj(req)
def start_profile(self):
req = ProfileReq.START_PROFILE
self.send_to_scheduler.send_pyobj(req)
async def start_profile(
self,
output_dir: Optional[str] = None,
num_steps: Optional[int] = None,
activities: Optional[List[str]] = None,
):
req = ProfileReq(
type=ProfileReqType.START_PROFILE,
output_dir=output_dir,
num_steps=num_steps,
activities=activities,
)
result = (await self.start_profile_communicator(req))[0]
if not result.success:
raise RuntimeError(result.message)
return result
def stop_profile(self):
req = ProfileReq.STOP_PROFILE
req = ProfileReq(type=ProfileReqType.STOP_PROFILE)
self.send_to_scheduler.send_pyobj(req)
async def update_weights_from_disk(
@@ -581,7 +659,7 @@ class TokenizerManager:
self.server_args.model_path = obj.model_path
self.server_args.load_format = obj.load_format
self.model_path = obj.model_path
return result.success, result.message
return result.success, result.message, result.num_paused_requests
else: # self.server_args.dp_size > 1
self.model_update_tmp = []
result = await self.model_update_result
@@ -593,7 +671,8 @@ class TokenizerManager:
self.model_path = obj.model_path
all_message = [r.message for r in result]
all_message = " | ".join(all_message)
return all_success, all_message
all_paused_requests = [r.num_paused_requests for r in result]
return all_success, all_message, all_paused_requests
async def init_weights_update_group(
self,
@@ -688,6 +767,54 @@ class TokenizerManager:
):
await self.send_to_scheduler.send_pyobj(obj)
async def get_internal_state(self) -> Dict[Any, Any]:
req = GetInternalStateReq()
res: List[GetInternalStateReqOutput] = (
await self.get_internal_state_communicator(req)
)
return res[0].internal_state
async def set_internal_state(
self, obj: SetInternalStateReq
) -> SetInternalStateReqOutput:
res: List[SetInternalStateReqOutput] = (
await self.set_internal_state_communicator(obj)
)
return res[0]
def get_log_request_metadata(self):
max_length = None
skip_names = None
out_skip_names = None
if self.log_requests:
if self.log_requests_level == 0:
max_length = 1 << 30
skip_names = set(
[
"text",
"input_ids",
"input_embeds",
"image_data",
"audio_data",
"lora_path",
]
)
out_skip_names = set(
[
"text",
"output_ids",
]
)
elif self.log_requests_level == 1:
max_length = 2048
elif self.log_requests_level == 2:
max_length = 1 << 30
else:
raise ValueError(
f"Invalid --log-requests-level: {self.log_requests_level=}"
)
return max_length, skip_names, out_skip_names
def configure_logging(self, obj: ConfigureLoggingReq):
if obj.log_requests is not None:
self.log_requests = obj.log_requests
@@ -698,6 +825,7 @@ class TokenizerManager:
if obj.dump_requests_threshold is not None:
self.dump_requests_threshold = obj.dump_requests_threshold
logging.info(f"Config logging: {obj=}")
self.log_request_metadata = self.get_log_request_metadata()
def create_abort_task(self, obj: GenerateReqInput):
# Abort the request if the client is disconnected.
@@ -762,15 +890,20 @@ class TokenizerManager:
while True:
recv_obj = await self.recv_from_detokenizer.recv_pyobj()
self._result_dispatcher(recv_obj)
self.last_receive_tstamp = time.time()
def _handle_batch_output(
self, recv_obj: Union[BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut]
self,
recv_obj: Union[
BatchStrOut, BatchEmbeddingOut, BatchMultimodalOut, BatchTokenIDOut
],
):
for i, rid in enumerate(recv_obj.rids):
state = self.rid_to_state.get(rid, None)
if state is None:
continue
# Build meta_info and return value
meta_info = {
"id": rid,
"finish_reason": recv_obj.finished_reasons[i],
@@ -781,14 +914,12 @@ class TokenizerManager:
self.convert_logprob_style(
meta_info,
state.obj.top_logprobs_num,
state.obj.token_ids_logprob,
state.obj.return_text_in_logprobs,
recv_obj,
i,
)
if self.server_args.speculative_algorithm:
meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[i]
if not isinstance(recv_obj, BatchEmbeddingOut):
meta_info.update(
{
@@ -806,10 +937,20 @@ class TokenizerManager:
"meta_info": meta_info,
}
elif isinstance(recv_obj, BatchTokenIDOut):
if self.server_args.stream_output and state.obj.stream:
output_token_ids = recv_obj.output_ids[i][
state.last_output_offset :
]
state.last_output_offset = len(recv_obj.output_ids[i])
else:
output_token_ids = recv_obj.output_ids[i]
out_dict = {
"token_ids": recv_obj.output_ids[i],
"output_ids": output_token_ids,
"meta_info": meta_info,
}
elif isinstance(recv_obj, BatchMultimodalOut):
raise NotImplementedError()
else:
assert isinstance(recv_obj, BatchEmbeddingOut)
out_dict = {
@@ -817,10 +958,17 @@ class TokenizerManager:
"meta_info": meta_info,
}
state.out_list.append(out_dict)
state.finished = recv_obj.finished_reasons[i] is not None
if state.finished:
if self.server_args.speculative_algorithm:
meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[i]
state.finished_time = time.time()
meta_info["e2e_latency"] = state.finished_time - state.created_time
state.out_list.append(out_dict)
state.event.set()
# Log metrics and dump
if self.enable_metrics and state.obj.log_metrics:
self.collect_metrics(state, recv_obj, i)
if self.dump_requests_folder and state.finished and state.obj.log_metrics:
@@ -830,6 +978,7 @@ class TokenizerManager:
self,
meta_info: dict,
top_logprobs_num: int,
token_ids_logprob: List[int],
return_text_in_logprobs: bool,
recv_obj: BatchStrOut,
recv_obj_index: int,
@@ -857,6 +1006,20 @@ class TokenizerManager:
return_text_in_logprobs,
)
if token_ids_logprob is not None:
meta_info["input_token_ids_logprobs"] = self.detokenize_top_logprobs_tokens(
recv_obj.input_token_ids_logprobs_val[recv_obj_index],
recv_obj.input_token_ids_logprobs_idx[recv_obj_index],
return_text_in_logprobs,
)
meta_info["output_token_ids_logprobs"] = (
self.detokenize_top_logprobs_tokens(
recv_obj.output_token_ids_logprobs_val[recv_obj_index],
recv_obj.output_token_ids_logprobs_idx[recv_obj_index],
return_text_in_logprobs,
)
)
def detokenize_logprob_tokens(
self,
token_logprobs_val: List[float],
@@ -900,34 +1063,30 @@ class TokenizerManager:
else 0
)
if state.first_token_time is None:
state.first_token_time = time.time()
if state.first_token_time == 0.0:
state.first_token_time = state.last_time = time.time()
state.last_completion_tokens = completion_tokens
self.metrics_collector.observe_time_to_first_token(
state.first_token_time - state.created_time
)
else:
if completion_tokens >= 2:
# Compute time_per_output_token for the streaming case
self.metrics_collector.observe_time_per_output_token(
(time.time() - state.first_token_time) / (completion_tokens - 1)
num_new_tokens = completion_tokens - state.last_completion_tokens
if num_new_tokens:
new_time = time.time()
interval = new_time - state.last_time
self.metrics_collector.observe_inter_token_latency(
interval,
num_new_tokens,
)
state.last_time = new_time
state.last_completion_tokens = completion_tokens
if state.finished:
self.metrics_collector.observe_one_finished_request(
recv_obj.prompt_tokens[i], completion_tokens
recv_obj.prompt_tokens[i],
completion_tokens,
state.finished_time - state.created_time,
)
self.metrics_collector.observe_e2e_request_latency(
time.time() - state.created_time
)
# Compute time_per_output_token for the non-streaming case
if (
hasattr(state.obj, "stream")
and not state.obj.stream
and completion_tokens >= 1
):
self.metrics_collector.observe_time_per_output_token(
(time.time() - state.created_time) / completion_tokens
)
def dump_requests(self, state: ReqState, out_dict: dict):
self.dump_request_list.append(
@@ -996,22 +1155,38 @@ T = TypeVar("T")
class _Communicator(Generic[T]):
"""Note: The communicator now only run up to 1 in-flight request at any time."""
def __init__(self, sender, fan_out: int):
self._sender = sender
self._fan_out = fan_out
self._result_future: Optional[asyncio.Future] = None
self._result_event: Optional[asyncio.Event] = None
self._result_values: Optional[List[T]] = None
self._ready_queue: Deque[asyncio.Future] = deque()
async def __call__(self, obj):
self._sender.send_pyobj(obj)
self._result_future = asyncio.Future()
ready_event = asyncio.Event()
if self._result_event is not None or len(self._ready_queue) > 0:
self._ready_queue.append(ready_event)
await ready_event.wait()
assert self._result_event is None
assert self._result_values is None
if obj:
self._sender.send_pyobj(obj)
self._result_event = asyncio.Event()
self._result_values = []
await self._result_future
await self._result_event.wait()
result_values = self._result_values
self._result_future = self._result_values = None
self._result_event = self._result_values = None
if len(self._ready_queue) > 0:
self._ready_queue.popleft().set()
return result_values
def handle_recv(self, recv_obj: T):
self._result_values.append(recv_obj)
if len(self._result_values) == self._fan_out:
self._result_future.set_result(None)
self._result_event.set()

View File

@@ -15,10 +15,13 @@
import logging
import threading
from typing import Optional
from typing import Optional, Tuple
import torch
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.io_struct import (
GetWeightsByNameReqInput,
InitWeightsUpdateGroupReqInput,
@@ -159,7 +162,7 @@ class TpModelWorker:
model_worker_batch: ModelWorkerBatch,
launch_done: Optional[threading.Event] = None,
skip_sample: bool = False,
):
) -> Tuple[LogitsProcessorOutput, Optional[torch.Tensor]]:
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
logits_output = self.model_runner.forward(forward_batch)
if launch_done:

View File

@@ -175,7 +175,7 @@ class TpModelWorkerClient:
logits_output.next_token_logprobs.tolist()
)
if logits_output.input_token_logprobs is not None:
logits_output.input_token_logprobs = (
logits_output.input_token_logprobs = tuple(
logits_output.input_token_logprobs.tolist()
)
next_token_ids = next_token_ids.tolist()
@@ -188,8 +188,7 @@ class TpModelWorkerClient:
model_worker_batch.sampling_info = self.cur_sampling_info = dataclasses.replace(
sampling_info,
sampling_info_done=threading.Event(),
scaling_penalties=sampling_info.scaling_penalties,
linear_penalties=sampling_info.linear_penalties,
penalizer_orchestrator=None,
)
# A cuda stream sync here to avoid the cuda illegal memory access error.

View File

@@ -2,7 +2,9 @@ from __future__ import annotations
"""Cache for chunked prefill, used when RadixCache is disabled."""
from typing import TYPE_CHECKING, Callable, List, Optional, Tuple
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple
import torch
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
@@ -12,7 +14,7 @@ if TYPE_CHECKING:
class ChunkCacheEntry:
def __init__(self, rid, value):
def __init__(self, rid: str, value: torch.Tensor):
self.rid = rid
self.value = value
@@ -24,6 +26,7 @@ class ChunkCache(BasePrefixCache):
self.disable = True
self.req_to_token_pool = req_to_token_pool
self.token_to_kv_pool = token_to_kv_pool
self.entries: Dict[str, ChunkCacheEntry] = {}
self.reset()
@@ -53,11 +56,8 @@ class ChunkCache(BasePrefixCache):
if req.rid in self.entries:
del self.entries[req.rid]
def cache_unfinished_req(self, req: Req, token_ids: Optional[List[int]] = None):
if token_ids is None:
token_id_len = len(req.fill_ids)
else:
token_id_len = len(token_ids)
def cache_unfinished_req(self, req: Req):
token_id_len = len(req.fill_ids)
kv_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx, :token_id_len
@@ -86,5 +86,8 @@ class ChunkCache(BasePrefixCache):
def evictable_size(self):
return 0
def pretty_print(self):
return ""
def protected_size(self):
return 0

View File

@@ -13,6 +13,7 @@
# ==============================================================================
"""Utilities for Prometheus Metrics Collection."""
import time
from dataclasses import dataclass
from typing import Dict, Union
@@ -35,19 +36,20 @@ class SchedulerMetricsCollector:
from prometheus_client import Gauge
self.labels = labels
self.last_log_time = time.time()
self.num_running_reqs = Gauge(
name="sglang:num_running_reqs",
documentation="The number of running requests.",
labelnames=labels.keys(),
multiprocess_mode="sum",
multiprocess_mode="mostrecent",
)
self.num_used_tokens = Gauge(
name="sglang:num_used_tokens",
documentation="The number of used tokens.",
labelnames=labels.keys(),
multiprocess_mode="sum",
multiprocess_mode="mostrecent",
)
self.token_usage = Gauge(
@@ -61,14 +63,14 @@ class SchedulerMetricsCollector:
name="sglang:gen_throughput",
documentation="The generation throughput (token/s).",
labelnames=labels.keys(),
multiprocess_mode="sum",
multiprocess_mode="mostrecent",
)
self.num_queue_reqs = Gauge(
name="sglang:num_queue_reqs",
documentation="The number of requests in the waiting queue.",
labelnames=labels.keys(),
multiprocess_mode="sum",
multiprocess_mode="mostrecent",
)
self.cache_hit_rate = Gauge(
@@ -97,6 +99,7 @@ class SchedulerMetricsCollector:
self._log_gauge(self.num_queue_reqs, stats.num_queue_reqs)
self._log_gauge(self.cache_hit_rate, stats.cache_hit_rate)
self._log_gauge(self.spec_accept_length, stats.spec_accept_length)
self.last_log_time = time.time()
class TokenizerMetricsCollector:
@@ -130,12 +133,15 @@ class TokenizerMetricsCollector:
labelnames=labels.keys(),
buckets=[
0.1,
0.25,
0.3,
0.5,
0.75,
0.7,
0.9,
1,
2,
5,
4,
6,
8,
10,
20,
40,
@@ -151,24 +157,56 @@ class TokenizerMetricsCollector:
documentation="Histogram of time per output token in seconds.",
labelnames=labels.keys(),
buckets=[
0.002,
0.005,
0.01,
0.010,
0.020,
0.030,
0.040,
0.050,
0.060,
0.070,
0.080,
0.090,
0.100,
0.150,
0.200,
0.300,
0.400,
0.600,
0.800,
1.000,
2.000,
],
)
self.histogram_inter_token_latency_seconds = Histogram(
name="sglang:inter_token_latency_seconds",
documentation="Histogram of inter-token latency in seconds.",
labelnames=labels.keys(),
buckets=[
0.002,
0.004,
0.006,
0.008,
0.010,
0.015,
0.02,
0.020,
0.025,
0.03,
0.04,
0.05,
0.030,
0.035,
0.040,
0.050,
0.075,
0.1,
0.15,
0.2,
0.3,
0.4,
0.5,
0.75,
1.0,
2.5,
0.100,
0.150,
0.200,
0.300,
0.400,
0.500,
0.750,
1.000,
2.000,
],
)
@@ -178,8 +216,9 @@ class TokenizerMetricsCollector:
labelnames=labels.keys(),
buckets=[
0.1,
0.25,
0.5,
0.2,
0.4,
0.8,
1,
2,
5,
@@ -188,28 +227,161 @@ class TokenizerMetricsCollector:
40,
60,
80,
100,
150,
200,
250,
300,
350,
500,
1000,
],
)
self.histogram_prefill_prealloc_duration = Histogram(
name="sglang:prefill_prealloc_duration_seconds",
documentation="Histogram of prefill prealloc duration in seconds.",
labelnames=labels.keys(),
buckets=[
0.1,
0.3,
0.5,
0.7,
0.9,
1,
2,
4,
6,
8,
10,
20,
40,
60,
80,
120,
160,
],
)
self.histogram_prefill_queue_duration = Histogram(
name="sglang:prefill_queue_duration_seconds",
documentation="Histogram of prefill queue duration in seconds.",
labelnames=labels.keys(),
buckets=[
0.1,
0.3,
0.5,
0.7,
0.9,
2,
4,
8,
16,
64,
],
)
self.histogram_prefill_forward_duration = Histogram(
name="sglang:prefill_forward_duration_seconds",
documentation="Histogram of prefill forward duration in seconds.",
labelnames=labels.keys(),
buckets=[
0.1,
0.3,
0.5,
0.7,
0.9,
2,
4,
8,
16,
64,
],
)
self.histogram_prefill_transfer_duration = Histogram(
name="sglang:prefill_transfer_duration_seconds",
documentation="Histogram of prefill transfer duration in seconds.",
labelnames=labels.keys(),
buckets=[
0.050,
0.100,
0.150,
0.200,
0.300,
0.400,
0.500,
1.000,
2.000,
],
)
self.histogram_decode_prealloc_duration = Histogram(
name="sglang:decode_prealloc_duration_seconds",
documentation="Histogram of decode prealloc duration in seconds.",
labelnames=labels.keys(),
buckets=[
0.1,
0.3,
0.5,
0.7,
0.9,
2,
4,
8,
16,
64,
],
)
self.histogram_decode_queue_duration = Histogram(
name="sglang:decode_queue_duration_seconds",
documentation="Histogram of decode queue duration in seconds.",
labelnames=labels.keys(),
buckets=[
0.1,
0.3,
0.5,
0.7,
0.9,
2,
4,
8,
16,
64,
],
)
def _log_histogram(self, histogram, data: Union[int, float]) -> None:
histogram.labels(**self.labels).observe(data)
def _log_counter(self, counter, data: Union[int, float]) -> None:
# Convenience function for logging to counter.
counter.labels(**self.labels).inc(data)
def observe_one_finished_request(self, prompt_tokens: int, generation_tokens: int):
def observe_one_finished_request(
self,
prompt_tokens: int,
generation_tokens: int,
e2e_latency: float,
):
self.prompt_tokens_total.labels(**self.labels).inc(prompt_tokens)
self.generation_tokens_total.labels(**self.labels).inc(generation_tokens)
self.num_requests_total.labels(**self.labels).inc(1)
self._log_histogram(self.histogram_e2e_request_latency, e2e_latency)
if generation_tokens >= 1:
self.histogram_time_per_output_token.labels(**self.labels).observe(
e2e_latency / generation_tokens
)
def observe_time_to_first_token(self, value: Union[float, int]):
self._log_histogram(self.histogram_time_to_first_token, value)
def observe_time_to_first_token(self, value: float):
self.histogram_time_to_first_token.labels(**self.labels).observe(value)
def observe_time_per_output_token(self, value: Union[float, int]):
self._log_histogram(self.histogram_time_per_output_token, value)
def observe_inter_token_latency(self, internval: float, num_new_tokens: int):
adjusted_interval = internval / num_new_tokens
def observe_e2e_request_latency(self, value: Union[float, int]):
self._log_histogram(self.histogram_e2e_request_latency, value)
# A faster version of the Histogram::observe which observes multiple values at the same time.
# reference: https://github.com/prometheus/client_python/blob/v0.21.1/prometheus_client/metrics.py#L639
his = self.histogram_inter_token_latency_seconds.labels(**self.labels)
his._sum.inc(internval)
for i, bound in enumerate(his._upper_bounds):
if adjusted_interval <= bound:
his._buckets[i].inc(num_new_tokens)
break

View File

@@ -109,11 +109,15 @@ def set_torch_compile_config():
def get_batch_sizes_to_capture(model_runner: ModelRunner):
server_args = model_runner.server_args
capture_bs = server_args.cuda_graph_bs
if capture_bs is None:
if server_args.disable_cuda_graph_padding:
capture_bs = list(range(1, 33)) + [64, 128]
if server_args.speculative_algorithm is None:
if server_args.disable_cuda_graph_padding:
capture_bs = list(range(1, 33)) + [64, 96, 128, 160]
else:
capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
else:
capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
capture_bs = list(range(1, 33))
if is_hip_:
capture_bs += [i * 8 for i in range(21, 33)]
@@ -130,6 +134,7 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
)
)
)
capture_bs = [
bs
for bs in capture_bs
@@ -385,9 +390,6 @@ class CudaGraphRunner:
run_once()
torch.cuda.synchronize()
self.model_runner.tp_group.barrier()
torch.cuda.synchronize()
self.model_runner.tp_group.barrier()
@@ -401,12 +403,11 @@ class CudaGraphRunner:
global_graph_memory_pool = graph.pool()
return graph, out
def replay(self, forward_batch: ForwardBatch):
assert forward_batch.out_cache_loc is not None
def recapture_if_needed(self, forward_batch: ForwardBatch):
# If the capture_hidden_mode changes, we need to recapture the graph
hidden_mode_from_spec_info = getattr(
forward_batch.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL
)
# If the capture_hidden_mode changes, we need to recapture the graph
if (
forward_batch.capture_hidden_mode == CaptureHiddenMode.FULL
and self.capture_hidden_mode != CaptureHiddenMode.FULL
@@ -420,6 +421,9 @@ class CudaGraphRunner:
self.capture_hidden_mode = hidden_mode_from_spec_info
self.capture()
def replay(self, forward_batch: ForwardBatch):
self.recapture_if_needed(forward_batch)
raw_bs = forward_batch.batch_size
raw_num_token = raw_bs * self.num_tokens_per_bs

View File

@@ -31,7 +31,7 @@ from __future__ import annotations
from dataclasses import dataclass
from enum import IntEnum, auto
from typing import TYPE_CHECKING, List, Optional
from typing import TYPE_CHECKING, List, Optional, Union
import torch
import triton
@@ -46,7 +46,8 @@ if TYPE_CHECKING:
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.speculative.spec_info import SpecInfo, SpeculativeAlgorithm
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
class ForwardMode(IntEnum):
@@ -112,7 +113,9 @@ class ForwardMode(IntEnum):
class CaptureHiddenMode(IntEnum):
NULL = auto()
# Capture hidden states of all tokens.
FULL = auto()
# Capture a hidden state of the last token.
LAST = auto()
def need_capture(self):
@@ -148,6 +151,7 @@ class ForwardBatch:
# For logprob
return_logprob: bool = False
top_logprobs_nums: Optional[List[int]] = None
token_ids_logprobs: Optional[List[List[int]]] = None
# Position information
positions: torch.Tensor = None
@@ -160,6 +164,7 @@ class ForwardBatch:
extend_prefix_lens_cpu: Optional[List[int]] = None
extend_seq_lens_cpu: Optional[List[int]] = None
extend_logprob_start_lens_cpu: Optional[List[int]] = None
extend_input_logprob_token_ids_gpu: Optional[torch.Tensor] = None
# For multimodal
image_inputs: Optional[List[ImageInputs]] = None
@@ -190,10 +195,13 @@ class ForwardBatch:
can_run_dp_cuda_graph: bool = False
# Speculative decoding
spec_info: SpecInfo = None
spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
spec_algorithm: SpeculativeAlgorithm = None
capture_hidden_mode: CaptureHiddenMode = None
# For padding
padded_static_len: int = -1 # -1 if not padded
# For Qwen2-VL
mrope_positions: torch.Tensor = None
@@ -203,8 +211,13 @@ class ForwardBatch:
batch: ModelWorkerBatch,
model_runner: ModelRunner,
):
device = model_runner.device
extend_input_logprob_token_ids_gpu = None
if batch.extend_input_logprob_token_ids is not None:
extend_input_logprob_token_ids_gpu = (
batch.extend_input_logprob_token_ids.to(device, non_blocking=True)
)
ret = cls(
forward_mode=batch.forward_mode,
batch_size=len(batch.seq_lens),
@@ -220,6 +233,7 @@ class ForwardBatch:
seq_lens_sum=batch.seq_lens_sum,
return_logprob=batch.return_logprob,
top_logprobs_nums=batch.top_logprobs_nums,
token_ids_logprobs=batch.token_ids_logprobs,
global_num_tokens=batch.global_num_tokens,
can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
lora_paths=batch.lora_paths,
@@ -231,6 +245,7 @@ class ForwardBatch:
spec_info=batch.spec_info,
capture_hidden_mode=batch.capture_hidden_mode,
input_embeds=batch.input_embeds,
extend_input_logprob_token_ids_gpu=extend_input_logprob_token_ids_gpu,
)
if ret.global_num_tokens is not None:
@@ -341,6 +356,7 @@ class ForwardBatch:
)
batch.image_inputs[i].mrope_position_delta = mrope_position_delta
mrope_positions_list[i] = mrope_positions
self.mrope_positions = torch.concat(
[torch.tensor(pos, device=device) for pos in mrope_positions_list],
axis=1,
@@ -379,7 +395,7 @@ def compute_position_kernel(
extend_seq_lens,
):
BLOCK_SIZE: tl.constexpr = 512
pid = tl.program_id(0)
pid = tl.program_id(0).to(tl.int64)
prefix_len = tl.load(extend_prefix_lens + pid)
seq_len = tl.load(extend_seq_lens + pid)

View File

@@ -13,9 +13,12 @@
# ==============================================================================
"""ModelRunner runs the forward passes of the models."""
import collections
import datetime
import gc
import json
import logging
import os
import time
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
@@ -58,6 +61,7 @@ from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader import get_model
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.server_args import ServerArgs
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
@@ -73,10 +77,15 @@ from sglang.srt.utils import (
set_cpu_offload_max_bytes,
set_cuda_arch,
)
from sglang.utils import get_exception_traceback
logger = logging.getLogger(__name__)
SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)
UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300
class ModelRunner:
"""ModelRunner runs the forward passes of the models."""
@@ -180,9 +189,13 @@ class ModelRunner:
"enable_dp_attention": server_args.enable_dp_attention,
"enable_ep_moe": server_args.enable_ep_moe,
"device": server_args.device,
"speculative_accept_threshold_single": server_args.speculative_accept_threshold_single,
"speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
"enable_flashinfer_mla": server_args.enable_flashinfer_mla,
"disable_radix_cache": server_args.disable_radix_cache,
"flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
"debug_tensor_dump_output_folder": server_args.debug_tensor_dump_output_folder,
"debug_tensor_dump_inject": server_args.debug_tensor_dump_inject,
}
)
@@ -199,6 +212,18 @@ class ModelRunner:
self.sampler = Sampler()
self.load_model()
# Handle the case where some of models don't finish loading.
try:
dist.monitored_barrier(
group=get_tp_group().cpu_group,
timeout=datetime.timedelta(seconds=UNBALANCED_MODEL_LOADING_TIMEOUT_S),
wait_all_ranks=True,
)
except RuntimeError:
raise ValueError(
f"TP rank {self.tp_rank} could finish the model loading, but there are other ranks that didn't finish loading. It is likely due to unexpected failures (e.g., OOM) or a slow node."
) from None
# Apply torchao quantization
torchao_applied = getattr(self.model, "torchao_applied", False)
# In layered loading, torchao may have been applied
@@ -625,6 +650,9 @@ class ModelRunner:
4096,
)
if SGLANG_CI_SMALL_KV_SIZE:
self.max_total_num_tokens = int(SGLANG_CI_SMALL_KV_SIZE)
if not self.spec_algorithm.is_none():
if self.is_draft_worker:
self.max_total_num_tokens = self.server_args.draft_runner_cache_size
@@ -655,6 +683,7 @@ class ModelRunner:
device=self.device,
enable_memory_saver=self.server_args.enable_memory_saver,
)
if (
self.model_config.attention_arch == AttentionArch.MLA
and not self.server_args.disable_mla
@@ -758,9 +787,13 @@ class ModelRunner:
return
tic = time.time()
logger.info("Capture cuda graph begin. This can take up to several minutes.")
logger.info(
f"Capture cuda graph begin. This can take up to several minutes. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
)
self.cuda_graph_runner = CudaGraphRunner(self)
logger.info(f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s")
logger.info(
f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
)
def apply_torch_tp(self):
logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
@@ -820,11 +853,10 @@ class ModelRunner:
else:
raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")
def sample(
self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
) -> torch.Tensor:
def _preprocess_logits(
self, logits_output: LogitsProcessorOutput, sampling_info: SamplingBatchInfo
):
# Apply logit bias
sampling_info = forward_batch.sampling_info
if sampling_info.sampling_info_done:
# Overlap mode: the function update_regex_vocab_mask was executed
# in process_batch_result of the last batch.
@@ -833,15 +865,77 @@ class ModelRunner:
else:
# Normal mode: Put CPU-heavy tasks here. They will be overlapped with the forward pass.
sampling_info.update_regex_vocab_mask()
sampling_info.update_penalties()
sampling_info.apply_logits_bias(logits_output.next_token_logits)
def update_output_logprobs(
self,
logits_output: LogitsProcessorOutput,
sampling_info: SamplingBatchInfo,
top_logprobs_nums: List[int],
token_ids_logprobs: List[int],
next_token_ids: torch.Tensor,
*,
num_tokens_per_req: List[int],
):
"""Update the logits_output's output logprob based on next_token_ids
Args:
logits_output: The logits output from the model forward
sampling_info: Sampling info for logprob calculation
top_logprobs_nums: Number of logprobs per request.
next_token_ids: Next token ids.
num_tokens_per_req: The number of tokens per request.
Returns:
A list of next_token_ids
"""
self._preprocess_logits(logits_output, sampling_info)
# We should repeat top_logprobs_nums to match num_tokens_per_req.
top_logprobs_nums_repeat_interleaved = []
token_ids_logprobs_repeat_interleaved = []
for num, num_tokens in zip(top_logprobs_nums, num_tokens_per_req):
top_logprobs_nums_repeat_interleaved.extend([num] * num_tokens)
for token_ids, num_tokens in zip(token_ids_logprobs, num_tokens_per_req):
token_ids_logprobs_repeat_interleaved.extend([token_ids] * num_tokens)
self.sampler(
logits_output,
sampling_info,
True,
top_logprobs_nums_repeat_interleaved,
token_ids_logprobs_repeat_interleaved,
batch_next_token_ids=next_token_ids,
)
def sample(
self,
logits_output: LogitsProcessorOutput,
forward_batch: ForwardBatch,
) -> torch.Tensor:
"""Sample and compute logprobs and update logits_output.
Args:
logits_output: The logits output from the model forward
forward_batch: The forward batch that generates logits_output
Returns:
A list of next_token_ids
"""
# For duplex models with multiple output streams.
if isinstance(logits_output, tuple):
return torch.stack(
[self.sample(values, forward_batch) for values in logits_output],
axis=-1,
)
self._preprocess_logits(logits_output, forward_batch.sampling_info)
# Sample the next tokens
next_token_ids = self.sampler(
logits_output,
sampling_info,
forward_batch.sampling_info,
forward_batch.return_logprob,
forward_batch.top_logprobs_nums,
forward_batch.token_ids_logprobs,
)
return next_token_ids

View File

@@ -25,10 +25,10 @@ import filelock
import gguf
import huggingface_hub.constants
import numpy as np
import safetensors.torch
import torch
from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download
from pydantic import BaseModel, ConfigDict, ValidationInfo, model_validator
from safetensors.torch import load_file, safe_open, save_file
from tqdm.auto import tqdm
from sglang.srt.configs.load_config import LoadConfig
@@ -62,7 +62,6 @@ enable_hf_transfer()
class DisabledTqdm(tqdm):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs, disable=True)
@@ -121,7 +120,7 @@ def convert_bin_to_safetensor_file(
)
# check if the tensors are the same
reloaded = load_file(sf_filename)
reloaded = safetensors.torch.load_file(sf_filename)
for k in loaded:
pt_tensor = loaded[k]
sf_tensor = reloaded[k]
@@ -133,7 +132,6 @@ def convert_bin_to_safetensor_file(
def get_quant_config(
model_config: ModelConfig, load_config: LoadConfig
) -> QuantizationConfig:
quant_cls = get_quantization_config(model_config.quantization)
# GGUF doesn't have config file
@@ -402,15 +400,34 @@ def np_cache_weights_iterator(
yield name, torch.from_numpy(param)
def decrypt(fn, key):
raise NotImplementedError()
def safetensors_encrypted_weights_iterator(
hf_weights_files: List[str],
is_all_weights_sharded: bool = False,
decryption_key: Optional[str] = None,
):
raise NotImplementedError()
def safetensors_weights_iterator(
hf_weights_files: List[str],
is_all_weights_sharded: bool = False,
decryption_key: Optional[str] = None,
) -> Generator[Tuple[str, torch.Tensor], None, None]:
"""Iterate over the weights in the model safetensor files.
If is_all_weights_sharded is True, it uses more optimize read by reading an
entire file instead of reading each tensor one by one.
"""
if decryption_key:
yield from safetensors_encrypted_weights_iterator(
hf_weights_files, is_all_weights_sharded, decryption_key
)
return
enable_tqdm = (
not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
)
@@ -420,15 +437,9 @@ def safetensors_weights_iterator(
disable=not enable_tqdm,
bar_format=_BAR_FORMAT,
):
if not is_all_weights_sharded:
with safe_open(st_file, framework="pt") as f:
for name in f.keys(): # noqa: SIM118
param = f.get_tensor(name)
yield name, param
else:
result = load_file(st_file, device="cpu")
for name, param in result.items():
yield name, param
result = safetensors.torch.load_file(st_file, device="cpu")
for name, param in result.items():
yield name, param
def pt_weights_iterator(

View File

@@ -1,13 +1,11 @@
from .orchestrator import BatchedPenalizerOrchestrator
from .penalizers.frequency_penalty import BatchedFrequencyPenalizer
from .penalizers.min_new_tokens import BatchedMinNewTokensPenalizer
from .penalizers.presence_penalty import BatchedPresencePenalizer
from .penalizers.repetition_penalty import BatchedRepetitionPenalizer
from sglang.srt.sampling.penaltylib.frequency_penalty import BatchedFrequencyPenalizer
from sglang.srt.sampling.penaltylib.min_new_tokens import BatchedMinNewTokensPenalizer
from sglang.srt.sampling.penaltylib.orchestrator import BatchedPenalizerOrchestrator
from sglang.srt.sampling.penaltylib.presence_penalty import BatchedPresencePenalizer
__all__ = [
"BatchedFrequencyPenalizer",
"BatchedMinNewTokensPenalizer",
"BatchedPresencePenalizer",
"BatchedRepetitionPenalizer",
"BatchedPenalizerOrchestrator",
]

View File

@@ -0,0 +1,66 @@
import torch
from sglang.srt.sampling.penaltylib.orchestrator import (
BatchedPenalizerOrchestrator,
_BatchedPenalizer,
)
class BatchedFrequencyPenalizer(_BatchedPenalizer):
"""
Frequency penalizer penalizes tokens based on their frequency in the output.
"""
def __init__(self, orchestrator: BatchedPenalizerOrchestrator):
self.orchestrator = orchestrator
self._is_prepared = False
def _is_required(self) -> bool:
return any(
req.sampling_params.frequency_penalty != 0.0
for req in self.orchestrator.reqs()
)
def _prepare(self):
self.cumulated_frequency_penalties = torch.zeros(
(len(self.orchestrator.reqs()), self.orchestrator.vocab_size),
dtype=torch.float32,
device=self.orchestrator.device,
)
self.frequency_penalties = (
torch.tensor(
data=[
req.sampling_params.frequency_penalty
for req in self.orchestrator.reqs()
],
dtype=torch.float32,
device=self.orchestrator.device,
)
).unsqueeze_(1)
def _cumulate_output_tokens(self, output_ids: torch.Tensor):
self.cumulated_frequency_penalties.scatter_add_(
dim=1,
index=output_ids.unsqueeze(1),
src=self.frequency_penalties,
)
def _apply(self, logits: torch.Tensor) -> torch.Tensor:
logits.sub_(self.cumulated_frequency_penalties)
def _filter(self, keep_indices: torch.Tensor):
self.frequency_penalties = self.frequency_penalties[keep_indices]
self.cumulated_frequency_penalties = self.cumulated_frequency_penalties[
keep_indices
]
def _merge(self, their: "BatchedFrequencyPenalizer"):
print(f"{self.frequency_penalties.shape=}, {their.frequency_penalties.shape=}")
self.frequency_penalties = torch.cat(
[self.frequency_penalties, their.frequency_penalties], dim=0
)
self.cumulated_frequency_penalties = torch.cat(
[self.cumulated_frequency_penalties, their.cumulated_frequency_penalties],
dim=0,
)

View File

@@ -1,8 +1,9 @@
from typing import List
import torch
from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs
from sglang.srt.sampling.penaltylib.orchestrator import (
BatchedPenalizerOrchestrator,
_BatchedPenalizer,
)
class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
@@ -10,9 +11,9 @@ class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
Min new tokens penalizer penalizes tokens based on the length of the output.
"""
min_new_tokens: torch.Tensor = None
stop_token_penalties: torch.Tensor = None
len_output_tokens: torch.Tensor = None
def __init__(self, orchestrator: BatchedPenalizerOrchestrator):
self.orchestrator = orchestrator
self._is_prepared = False
def _is_required(self) -> bool:
return any(
@@ -47,7 +48,7 @@ class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
padding_value=self.orchestrator.vocab_size,
)
self.stop_token_penalties = torch.zeros(
size=(self.orchestrator.batch_size(), self.orchestrator.vocab_size + 1),
size=(len(self.orchestrator.reqs()), self.orchestrator.vocab_size + 1),
dtype=torch.float32,
device=self.orchestrator.device,
).scatter_add_(
@@ -64,31 +65,22 @@ class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
]
self.len_output_tokens = torch.zeros(
size=(self.orchestrator.batch_size(), 1),
size=(len(self.orchestrator.reqs()), 1),
dtype=torch.int32,
device=self.orchestrator.device,
)
def _teardown(self):
self.min_new_tokens = None
self.stop_token_penalties = None
self.len_output_tokens = None
def _cumulate_input_tokens(self, input_ids: _TokenIDs):
pass
def _cumulate_output_tokens(self, output_ids: _TokenIDs):
def _cumulate_output_tokens(self, output_ids: torch.Tensor):
self.len_output_tokens += 1
def _apply(self, logits: torch.Tensor) -> torch.Tensor:
def _apply(self, logits: torch.Tensor):
mask = (self.len_output_tokens < self.min_new_tokens).expand_as(logits)
logits[mask] += self.stop_token_penalties[mask]
return logits
def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
self.min_new_tokens = self.min_new_tokens[indices_tensor_to_keep]
self.stop_token_penalties = self.stop_token_penalties[indices_tensor_to_keep]
self.len_output_tokens = self.len_output_tokens[indices_tensor_to_keep]
def _filter(self, keep_indices: torch.Tensor):
self.min_new_tokens = self.min_new_tokens[keep_indices]
self.stop_token_penalties = self.stop_token_penalties[keep_indices]
self.len_output_tokens = self.len_output_tokens[keep_indices]
def _merge(self, their: "BatchedMinNewTokensPenalizer"):
self.min_new_tokens = torch.cat(

View File

@@ -1,35 +1,25 @@
from __future__ import annotations
import abc
import dataclasses
from typing import List, Set, Type, Union
from typing import TYPE_CHECKING, Set, Type
import torch
@dataclasses.dataclass
class _ReqLike:
origin_input_ids: List[int]
@dataclasses.dataclass
class _BatchLike:
reqs: List[_ReqLike]
def batch_size(self):
return len(self.reqs)
if TYPE_CHECKING:
from sglang.srt.managers.schedule_batch import ScheduleBatch
class BatchedPenalizerOrchestrator:
def __init__(
self,
vocab_size: int,
batch: _BatchLike,
device: str,
Penalizers: Set[Type["_BatchedPenalizer"]],
batch: ScheduleBatch,
penalizers: Set[Type["_BatchedPenalizer"]],
):
self.vocab_size = vocab_size
self.batch = batch
self.device = device
self.penalizers = {Penalizer: Penalizer(self) for Penalizer in Penalizers}
self.device = batch.device
self.penalizers = {Penalizer: Penalizer(self) for Penalizer in penalizers}
is_required = False
for penalizer in self.penalizers.values():
@@ -37,31 +27,9 @@ class BatchedPenalizerOrchestrator:
is_required |= pen_is_required
self.is_required = is_required
input_ids = [
torch.tensor(req.origin_input_ids, dtype=torch.int64, device=self.device)
for req in self.reqs()
]
if self.is_required:
self.cumulate_input_tokens(input_ids=input_ids)
def reqs(self):
return self.batch.reqs
def batch_size(self):
return self.batch.batch_size()
def cumulate_input_tokens(self, input_ids: List[torch.Tensor]):
"""
Feed the input tokens to the penalizers.
Args:
input_ids (List[torch.Tensor]): The input tokens.
"""
token_ids = _TokenIDs(orchestrator=self, token_ids=input_ids)
for penalizer in self.penalizers.values():
penalizer.cumulate_input_tokens(input_ids=token_ids)
def cumulate_output_tokens(self, output_ids: torch.Tensor):
"""
Feed the output tokens to the penalizers.
@@ -69,13 +37,8 @@ class BatchedPenalizerOrchestrator:
Args:
output_ids (torch.Tensor): The output tokens.
"""
if not self.is_required:
return
token_ids = _TokenIDs(orchestrator=self, token_ids=output_ids)
for penalizer in self.penalizers.values():
penalizer.cumulate_output_tokens(output_ids=token_ids)
penalizer.cumulate_output_tokens(output_ids=output_ids)
def apply(self, logits: torch.Tensor) -> torch.Tensor:
"""
@@ -88,48 +51,33 @@ class BatchedPenalizerOrchestrator:
Returns:
torch.Tensor: The logits after applying the penalizers.
"""
if not self.is_required:
return
for penalizer in self.penalizers.values():
logits = penalizer.apply(logits)
penalizer.apply(logits)
return logits
def filter(
self,
indices_to_keep: List[int],
indices_tensor_to_keep: torch.Tensor = None,
):
def filter(self, keep_indices: torch.Tensor):
"""
Filter the penalizers based on the indices to keep in the batch.
Args:
indices_to_keep (List[int]): List of indices to keep in the batch.
indices_tensor_to_keep (torch.Tensor = None): Tensor of indices to keep in the batch. If not None, it will be used instead of converting indices_to_keep to a tensor.
keep_indices (torch.Tensor): Tensor of indices to keep in the batch.
"""
if not self.is_required:
return
empty_indices = len(indices_to_keep) == 0
if len(keep_indices) == 0:
self.is_required = False
for penalizer in self.penalizers.values():
penalizer.teardown()
return
is_required = False
for penalizer in self.penalizers.values():
tmp_is_required = penalizer.is_required()
is_required = is_required or tmp_is_required
if not tmp_is_required or empty_indices:
penalizer.teardown()
is_required |= tmp_is_required
if tmp_is_required:
penalizer.filter(keep_indices=keep_indices)
else:
# create tensor index only when it's needed
if indices_tensor_to_keep is None:
indices_tensor_to_keep = torch.tensor(
indices_to_keep, dtype=torch.int32, device=self.device
)
penalizer.filter(
indices_to_keep=indices_to_keep,
indices_tensor_to_keep=indices_tensor_to_keep,
)
penalizer.teardown()
self.is_required = is_required
def merge(self, their: "BatchedPenalizerOrchestrator"):
@@ -146,75 +94,9 @@ class BatchedPenalizerOrchestrator:
if not self.is_required and not their.is_required:
return
self.is_required |= their.is_required
for Penalizer, their_penalizer in their.penalizers.items():
if Penalizer not in self.penalizers:
raise ValueError(f"Penalizer {Penalizer} not found in self.penalizers")
self.penalizers[Penalizer].merge(their_penalizer)
class _TokenIDs:
"""
A class that wraps token IDs to provide additional utility functions to penalizers.
Attributes:
orchestrator (BatchedPenalizerOrchestrator): The orchestrator that this token IDs belong to.
token_ids (Union[torch.Tensor, List[torch.Tensor]]): The token IDs.
cached_counts (torch.Tensor): The cached occurrence count tensor.
"""
def __init__(
self,
orchestrator: BatchedPenalizerOrchestrator,
token_ids: Union[torch.Tensor, List[torch.Tensor]],
):
self.orchestrator = orchestrator
self.token_ids = token_ids
self.cached_counts = None
def occurrence_count(self) -> torch.Tensor:
"""
Returns a tensor of shape (batch_size, vocab_size) where each element is the number of times the corresponding token appears in the batch.
Returns:
torch.Tensor: The occurrence count tensor.
"""
if self.cached_counts is not None:
return self.cached_counts
token_ids = self.token_ids
if isinstance(token_ids, list):
# TODO: optimize this part
padded_token_ids = torch.nn.utils.rnn.pad_sequence(
sequences=token_ids,
batch_first=True,
padding_value=self.orchestrator.vocab_size,
)
self.cached_counts = torch.zeros(
size=(self.orchestrator.batch_size(), self.orchestrator.vocab_size + 1),
dtype=torch.int64,
device=self.orchestrator.device,
).scatter_add_(
dim=1,
index=padded_token_ids,
src=torch.ones_like(padded_token_ids),
)[
:, : self.orchestrator.vocab_size
]
else:
# TODO: optimize this part. We do not need to create this big tensor every time.
# We can directly apply the results on the logits.
self.cached_counts = torch.zeros(
size=(self.orchestrator.batch_size(), self.orchestrator.vocab_size),
device=self.orchestrator.device,
)
self.cached_counts[
torch.arange(len(token_ids), device=self.orchestrator.device), token_ids
] = 1
return self.cached_counts
self.is_required = True
for penalizer, their_penalizer in their.penalizers.items():
self.penalizers[penalizer].merge(their_penalizer)
class _BatchedPenalizer(abc.ABC):
@@ -222,10 +104,6 @@ class _BatchedPenalizer(abc.ABC):
An abstract class for a batched penalizer.
"""
def __init__(self, orchestrator: BatchedPenalizerOrchestrator):
self.orchestrator = orchestrator
self._is_prepared = False
def is_prepared(self) -> bool:
return self._is_prepared
@@ -233,51 +111,40 @@ class _BatchedPenalizer(abc.ABC):
return self._is_required()
def prepare(self):
if not self.is_prepared():
if not self._is_prepared:
self._prepare()
self._is_prepared = True
def prepare_if_required(self):
if self.is_required():
if self._is_required():
self.prepare()
return True
else:
return False
def teardown(self):
if self.is_prepared():
self._teardown()
self._is_prepared = False
self._is_prepared = False
def cumulate_input_tokens(self, input_ids: _TokenIDs):
if not self.is_prepared():
return
self._cumulate_input_tokens(input_ids=input_ids)
def cumulate_output_tokens(self, output_ids: _TokenIDs):
if not self.is_prepared():
def cumulate_output_tokens(self, output_ids: torch.Tensor):
if not self._is_prepared:
return
self._cumulate_output_tokens(output_ids=output_ids)
def apply(self, logits: torch.Tensor) -> torch.Tensor:
if not self.is_prepared():
return logits
return self._apply(logits=logits)
def filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
if not self.is_prepared():
if not self._is_prepared:
return
self._filter(
indices_to_keep=indices_to_keep,
indices_tensor_to_keep=indices_tensor_to_keep,
)
self._apply(logits=logits)
def filter(self, keep_indices: torch.Tensor):
if not self._is_prepared:
return
self._filter(keep_indices=keep_indices)
def merge(self, their: "_BatchedPenalizer"):
if not self.is_prepared() and not their.is_prepared():
if not self._is_prepared and not their._is_prepared:
return
self.prepare()
@@ -300,23 +167,7 @@ class _BatchedPenalizer(abc.ABC):
pass
@abc.abstractmethod
def _teardown(self):
"""
Tear down the penalizer.
Usually, this is where the penalizer frees its tensors.
"""
pass
@abc.abstractmethod
def _cumulate_input_tokens(self, input_ids: _TokenIDs):
"""
Cumulate the input tokens.
Orchestrator will call this function to feed the input tokens to the penalizer.
"""
pass
@abc.abstractmethod
def _cumulate_output_tokens(self, output_ids: _TokenIDs):
def _cumulate_output_tokens(self, output_ids: torch.Tensor):
"""
Cumulate the output tokens.
Orchestrator will call this function to feed the output tokens to the penalizer.
@@ -332,7 +183,7 @@ class _BatchedPenalizer(abc.ABC):
pass
@abc.abstractmethod
def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
def _filter(self, keep_indices: torch.Tensor):
"""
Filter the penalizer (tensors or underlying data) based on the indices to keep in the batch.
"""

View File

@@ -1,75 +0,0 @@
from typing import List
import torch
from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs
class BatchedFrequencyPenalizer(_BatchedPenalizer):
"""
Frequency penalizer penalizes tokens based on their frequency in the output.
"""
frequency_penalties: torch.Tensor = None
cumulated_frequency_penalties: torch.Tensor = None
def _is_required(self) -> bool:
return any(
req.sampling_params.frequency_penalty != 0.0
for req in self.orchestrator.reqs()
)
def _prepare(self):
self.cumulated_frequency_penalties = (
torch.tensor(
data=[0.0 for _ in self.orchestrator.reqs()],
dtype=torch.float32,
device=self.orchestrator.device,
)
.unsqueeze_(1)
.repeat(1, self.orchestrator.vocab_size)
)
self.frequency_penalties = (
torch.tensor(
data=[
req.sampling_params.frequency_penalty
for req in self.orchestrator.reqs()
],
dtype=torch.float32,
device=self.orchestrator.device,
)
.unsqueeze_(1)
.expand_as(self.cumulated_frequency_penalties)
)
def _teardown(self):
self.frequency_penalties = None
self.cumulated_frequency_penalties = None
def _cumulate_input_tokens(self, input_ids: _TokenIDs):
pass
def _cumulate_output_tokens(self, output_ids: _TokenIDs):
self.cumulated_frequency_penalties += (
self.frequency_penalties * output_ids.occurrence_count()
)
def _apply(self, logits: torch.Tensor) -> torch.Tensor:
logits -= self.cumulated_frequency_penalties
return logits
def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
self.frequency_penalties = self.frequency_penalties[indices_tensor_to_keep]
self.cumulated_frequency_penalties = self.cumulated_frequency_penalties[
indices_tensor_to_keep
]
def _merge(self, their: "BatchedFrequencyPenalizer"):
self.frequency_penalties = torch.cat(
[self.frequency_penalties, their.frequency_penalties], dim=0
)
self.cumulated_frequency_penalties = torch.cat(
[self.cumulated_frequency_penalties, their.cumulated_frequency_penalties],
dim=0,
)

View File

@@ -1,74 +0,0 @@
from typing import List
import torch
from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs
class BatchedPresencePenalizer(_BatchedPenalizer):
"""
Presence penalizer penalizes tokens based on their presence in the output.
"""
presence_penalties: torch.Tensor = None
cumulated_presence_penalties: torch.Tensor = None
def _is_required(self) -> bool:
return any(
req.sampling_params.presence_penalty != 0.0
for req in self.orchestrator.reqs()
)
def _prepare(self):
self.cumulated_presence_penalties = (
torch.tensor(
data=[0.0 for _ in self.orchestrator.reqs()],
dtype=torch.float32,
device=self.orchestrator.device,
)
.unsqueeze_(1)
.repeat(1, self.orchestrator.vocab_size)
)
self.presence_penalties = (
torch.tensor(
data=[
req.sampling_params.presence_penalty
for req in self.orchestrator.reqs()
],
dtype=torch.float32,
device=self.orchestrator.device,
)
.unsqueeze_(1)
.expand_as(self.cumulated_presence_penalties)
)
def _teardown(self):
self.presence_penalties = None
self.cumulated_presence_penalties = None
def _cumulate_input_tokens(self, input_ids: _TokenIDs):
pass
def _cumulate_output_tokens(self, output_ids: _TokenIDs):
mask = output_ids.occurrence_count() > 0
self.cumulated_presence_penalties[mask] = self.presence_penalties[mask]
def _apply(self, logits: torch.Tensor) -> torch.Tensor:
logits -= self.cumulated_presence_penalties
return logits
def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
self.presence_penalties = self.presence_penalties[indices_tensor_to_keep]
self.cumulated_presence_penalties = self.cumulated_presence_penalties[
indices_tensor_to_keep
]
def _merge(self, their: "BatchedPresencePenalizer"):
self.presence_penalties = torch.cat(
[self.presence_penalties, their.presence_penalties], dim=0
)
self.cumulated_presence_penalties = torch.cat(
[self.cumulated_presence_penalties, their.cumulated_presence_penalties],
dim=0,
)

View File

@@ -1,85 +0,0 @@
from typing import List
import torch
from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs
from sglang.srt.utils import get_compiler_backend
@torch.compile(dynamic=True, backend=get_compiler_backend())
def apply_scaling_penalties(logits, scaling_penalties):
logits[:] = torch.where(
logits > 0,
logits / scaling_penalties,
logits * scaling_penalties,
)
class BatchedRepetitionPenalizer(_BatchedPenalizer):
"""
Repetition penalizer penalizes tokens based on their repetition in the input and output.
"""
repetition_penalties: torch.Tensor = None
cumulated_repetition_penalties: torch.Tensor = None
def _is_required(self) -> bool:
return any(
req.sampling_params.repetition_penalty != 1.0
for req in self.orchestrator.reqs()
)
def _prepare(self):
self.cumulated_repetition_penalties = (
torch.tensor(
data=[1.0 for _ in self.orchestrator.reqs()],
dtype=torch.float32,
device=self.orchestrator.device,
)
.unsqueeze_(1)
.repeat(1, self.orchestrator.vocab_size)
)
self.repetition_penalties = (
torch.tensor(
data=[
req.sampling_params.repetition_penalty
for req in self.orchestrator.reqs()
],
dtype=torch.float32,
device=self.orchestrator.device,
)
.unsqueeze_(1)
.expand_as(self.cumulated_repetition_penalties)
)
def _teardown(self):
self.repetition_penalties = None
self.cumulated_repetition_penalties = None
def _cumulate_input_tokens(self, input_ids: _TokenIDs):
mask = input_ids.occurrence_count() > 0
self.cumulated_repetition_penalties[mask] = self.repetition_penalties[mask]
def _cumulate_output_tokens(self, output_ids: _TokenIDs):
mask = output_ids.occurrence_count() > 0
self.cumulated_repetition_penalties[mask] = self.repetition_penalties[mask]
def _apply(self, logits: torch.Tensor) -> torch.Tensor:
apply_scaling_penalties(logits, self.cumulated_repetition_penalties)
return logits
def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
self.repetition_penalties = self.repetition_penalties[indices_tensor_to_keep]
self.cumulated_repetition_penalties = self.cumulated_repetition_penalties[
indices_tensor_to_keep
]
def _merge(self, their: "BatchedRepetitionPenalizer"):
self.repetition_penalties = torch.cat(
[self.repetition_penalties, their.repetition_penalties], dim=0
)
self.cumulated_repetition_penalties = torch.cat(
[self.cumulated_repetition_penalties, their.cumulated_repetition_penalties],
dim=0,
)

View File

@@ -0,0 +1,66 @@
import torch
from sglang.srt.sampling.penaltylib.orchestrator import (
BatchedPenalizerOrchestrator,
_BatchedPenalizer,
)
class BatchedPresencePenalizer(_BatchedPenalizer):
"""
Presence penalizer penalizes tokens based on their presence in the output.
"""
def __init__(self, orchestrator: BatchedPenalizerOrchestrator):
self.orchestrator = orchestrator
self._is_prepared = False
def _is_required(self) -> bool:
return any(
req.sampling_params.presence_penalty != 0.0
for req in self.orchestrator.reqs()
)
def _prepare(self):
self.cumulated_presence_penalties = torch.zeros(
(len(self.orchestrator.reqs()), self.orchestrator.vocab_size),
dtype=torch.float32,
device=self.orchestrator.device,
)
self.presence_penalties = (
torch.tensor(
data=[
req.sampling_params.presence_penalty
for req in self.orchestrator.reqs()
],
dtype=torch.float32,
device=self.orchestrator.device,
)
).unsqueeze_(1)
def _cumulate_output_tokens(self, output_ids: torch.Tensor):
self.cumulated_presence_penalties.scatter_(
dim=1,
index=output_ids.unsqueeze(1),
src=self.presence_penalties,
)
def _apply(self, logits: torch.Tensor) -> torch.Tensor:
logits.sub_(self.cumulated_presence_penalties)
def _filter(self, keep_indices: torch.Tensor):
self.presence_penalties = self.presence_penalties[keep_indices]
self.cumulated_presence_penalties = self.cumulated_presence_penalties[
keep_indices
]
def _merge(self, their: "BatchedPresencePenalizer"):
print(f"{self.presence_penalties.shape=}, {their.presence_penalties.shape=}")
self.presence_penalties = torch.cat(
[self.presence_penalties, their.presence_penalties], dim=0
)
self.cumulated_presence_penalties = torch.cat(
[self.cumulated_presence_penalties, their.cumulated_presence_penalties],
dim=0,
)

View File

@@ -9,9 +9,6 @@ import torch
import sglang.srt.sampling.penaltylib as penaltylib
from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
from sglang.srt.sampling.penaltylib.penalizers.repetition_penalty import (
apply_scaling_penalties,
)
logger = logging.getLogger(__name__)
@@ -22,49 +19,45 @@ if TYPE_CHECKING:
@dataclasses.dataclass
class SamplingBatchInfo:
# Batched sampling params
# Basic batched sampling params
temperatures: torch.Tensor
top_ps: torch.Tensor
top_ks: torch.Tensor
min_ps: torch.Tensor
# All requests use greedy sampling
# Whether all requests use greedy sampling
is_all_greedy: bool
# Dispatch in CUDA graph
# Whether any request needs min_p sampling
need_min_p_sampling: bool
# Whether any request has custom logit processor
has_custom_logit_processor: bool
# Bias Tensors
# Masking tensors for grammar-guided structured outputs
vocab_size: int
grammars: Optional[List] = None
sampling_info_done: Optional[threading.Event] = None
logit_bias: torch.Tensor = None
vocab_mask: Optional[torch.Tensor] = None
apply_mask: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None
apply_mask_func: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None
# An event used for overlap schedule
sampling_info_done: Optional[threading.Event] = None
# Penalizer
penalizer_orchestrator: Optional[penaltylib.BatchedPenalizerOrchestrator] = None
linear_penalties: Optional[torch.Tensor] = None
scaling_penalties: Optional[torch.Tensor] = None
linear_penalty: torch.Tensor = None
# Device
device: str = "cuda"
# Custom Parameters
# Whether any request has custom logit processor
has_custom_logit_processor: bool = False
# Custom parameters
custom_params: Optional[List[Optional[Dict[str, Any]]]] = None
# Custom Logit Processor
# Custom logit processor
custom_logit_processor: Optional[
Dict[int, Tuple[CustomLogitProcessor, torch.Tensor]]
] = None
# Device
device: str = "cuda"
@classmethod
def from_schedule_batch(
cls, batch: ScheduleBatch, vocab_size: int, enable_overlap_schedule: bool
):
def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
reqs = batch.reqs
device = batch.device
temperatures = (
@@ -118,106 +111,60 @@ class SamplingBatchInfo:
merged_custom_logit_processor = None
custom_params = None
# Each penalizers will do nothing if they evaluate themselves as not required by looking at
# the sampling_params of the requests (See {_is_required()} of each penalizers). So this
# should not add hefty computation overhead other than simple checks.
#
# While we can choose not to even create the class instances if they are not required, this
# could add additional complexity to the {ScheduleBatch} class, especially we need to
# handle {filter_batch()} and {merge_batch()} cases as well.
penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
vocab_size=vocab_size,
batch=batch,
penalizers={
penaltylib.BatchedFrequencyPenalizer,
penaltylib.BatchedMinNewTokensPenalizer,
penaltylib.BatchedPresencePenalizer,
},
)
ret = cls(
temperatures=temperatures,
top_ps=top_ps,
top_ks=top_ks,
min_ps=min_ps,
need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
is_all_greedy=all(r.sampling_params.top_k <= 1 for r in reqs),
has_custom_logit_processor=has_custom_logit_processor,
need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
vocab_size=vocab_size,
device=device,
penalizer_orchestrator=penalizer_orchestrator,
has_custom_logit_processor=has_custom_logit_processor,
custom_params=custom_params,
custom_logit_processor=merged_custom_logit_processor,
device=device,
)
# TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge.
if enable_overlap_schedule:
# TODO (lianmin): Some penalizers such as frequency and presence depend on model outputs,
# so it is kind of tricky to make it work with overlap scheduler.
# It requires correcly updating the penalty logits before the sampling and syncing the events.
# We will support them later.
penalizers = {
penaltylib.BatchedMinNewTokensPenalizer,
}
if (
any(req.sampling_params.frequency_penalty != 0.0 for req in reqs)
or any(req.sampling_params.presence_penalty != 0.0 for req in reqs)
or any(req.sampling_params.repetition_penalty != 1.0 for req in reqs)
):
logger.warning(
"frequency_penalty, presence_penalty, and repetition_penalty are not supported "
"when using the default overlap scheduler. They will be ignored. "
"Please add `--disable-overlap` when launching the server if you need these features. "
"The speed will be slower in that case."
)
else:
penalizers = {
penaltylib.BatchedFrequencyPenalizer,
penaltylib.BatchedMinNewTokensPenalizer,
penaltylib.BatchedPresencePenalizer,
penaltylib.BatchedRepetitionPenalizer,
}
# Each penalizers will do nothing if they evaluate themselves as not required by looking at
# the sampling_params of the requests (See {_is_required()} of each penalizers). So this
# should not add hefty computation overhead other than simple checks.
#
# While we choose not to even create the class instances if they are not required, this
# could add additional complexity to the {ScheduleBatch} class, especially we need to
# handle {filter_batch()} and {merge_batch()} cases as well.
ret.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
vocab_size=vocab_size,
batch=batch,
device=batch.device,
Penalizers=penalizers,
)
# Handle logit bias but only allocate when needed
ret.logit_bias = None
return ret
def __len__(self):
return len(self.temperatures)
def update_penalties(self):
self.scaling_penalties = None
self.linear_penalties = None
for penalizer in self.penalizer_orchestrator.penalizers.values():
if not penalizer.is_prepared():
continue
if isinstance(penalizer, penaltylib.BatchedRepetitionPenalizer):
self.scaling_penalties = penalizer.cumulated_repetition_penalties
else:
if self.linear_penalties is None:
bs = self.penalizer_orchestrator.batch.batch_size()
self.linear_penalties = torch.zeros(
(bs, self.vocab_size),
dtype=torch.float32,
device=self.device,
)
self.linear_penalties = penalizer.apply(self.linear_penalties)
def update_regex_vocab_mask(self):
if not self.grammars:
self.vocab_mask = None
self.apply_mask = None
self.apply_mask_func = None
return
# find a grammar from the list
# Find a grammar from the list
first_grammar = next(grammar for grammar in self.grammars if grammar)
# maybe we can reuse the existing mask?
# TODO(lianmin): Maybe we can reuse the existing mask?
self.vocab_mask = first_grammar.allocate_vocab_mask(
vocab_size=self.vocab_size,
batch_size=len(self.temperatures),
device=self.device,
)
self.apply_mask = first_grammar.apply_vocab_mask # force to use static method
self.apply_mask_func = (
first_grammar.apply_vocab_mask
) # force to use static method
# Apply the mask
for i, grammar in enumerate(self.grammars):
@@ -227,35 +174,56 @@ class SamplingBatchInfo:
# Move the mask to the device if needed
self.vocab_mask = first_grammar.move_vocab_mask(self.vocab_mask, self.device)
def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor):
self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
def update_penalties(self):
if self.penalizer_orchestrator.is_required:
self.linear_penalty = torch.zeros(
(len(self.temperatures), self.vocab_size),
dtype=torch.float32,
device=self.temperatures.device,
)
self.penalizer_orchestrator.apply(self.linear_penalty)
else:
self.linear_penalty = None
def apply_logits_bias(self, logits: torch.Tensor):
if self.linear_penalty is not None:
# Used in the overlap mode
logits.add_(self.linear_penalty)
if self.penalizer_orchestrator and self.penalizer_orchestrator.is_required:
# Used in the non-overlap mode
self.penalizer_orchestrator.apply(logits)
if self.vocab_mask is not None:
self.apply_mask_func(logits=logits, vocab_mask=self.vocab_mask)
def filter_batch(self, keep_indices: List[int], keep_indices_device: torch.Tensor):
self.penalizer_orchestrator.filter(keep_indices_device)
if self.has_custom_logit_processor:
self._filter_batch_custom_logit_processor(unfinished_indices, new_indices)
self._filter_batch_custom_logit_processor(keep_indices, keep_indices_device)
for item in [
"temperatures",
"top_ps",
"top_ks",
"min_ps",
"logit_bias",
]:
value = getattr(self, item, None)
if value is not None: # logit_bias can be None
setattr(self, item, value[new_indices])
setattr(self, item, value[keep_indices_device])
def _filter_batch_custom_logit_processor(
self, unfinished_indices: List[int], new_indices: torch.Tensor
self, keep_indices: List[int], keep_indices_device: torch.Tensor
):
"""Filter the custom logit processor and custom params"""
self.custom_logit_processor = {
k: (p, mask[new_indices])
k: (p, mask[keep_indices_device])
for k, (p, mask) in self.custom_logit_processor.items()
if any(
mask[new_indices]
if torch.any(
mask[keep_indices_device]
) # ignore the custom logit processor whose mask is all False
}
self.custom_params = [self.custom_params[i] for i in unfinished_indices]
self.custom_params = [self.custom_params[i] for i in keep_indices]
# If the custom logit processor is an empty dict, set the flag to False,
# and set the custom logit processor and custom params to None.
@@ -264,31 +232,6 @@ class SamplingBatchInfo:
self.custom_params = None
self.has_custom_logit_processor = False
@staticmethod
def merge_bias_tensor(
lhs: torch.Tensor,
rhs: torch.Tensor,
bs1: int,
bs2: int,
device: str,
default: int = 0,
):
# bias tensor can be None
if lhs is not None or rhs is not None:
shape, dtype = None, None
if lhs is not None:
shape, dtype = lhs.shape[1:], lhs.dtype
else:
shape, dtype = rhs.shape[1:], rhs.dtype
with torch.dtype(dtype):
if lhs is None:
lhs = torch.empty((bs1, *shape), device=device).fill_(default)
if rhs is None:
rhs = torch.empty((bs2, *shape), device=device).fill_(default)
return torch.cat([lhs, rhs])
return None
@staticmethod
def merge_custom_logit_processor(
lhs: Optional[Dict[int, Tuple[CustomLogitProcessor, torch.Tensor]]],
@@ -332,11 +275,6 @@ class SamplingBatchInfo:
def merge_batch(self, other: "SamplingBatchInfo"):
self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
# Merge the logit bias tensor
self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
self.logit_bias, other.logit_bias, len(self), len(other), self.device
)
# Merge the custom logit processors and custom params lists
if self.has_custom_logit_processor or other.has_custom_logit_processor:
# Merge the custom logit processors
@@ -370,22 +308,5 @@ class SamplingBatchInfo:
other_val = getattr(other, item, None)
setattr(self, item, torch.concat([self_val, other_val]))
self.is_all_greedy = self.is_all_greedy and other.is_all_greedy
self.need_min_p_sampling = self.need_min_p_sampling or other.need_min_p_sampling
def apply_logits_bias(self, logits: torch.Tensor):
# Apply logit_bias
if self.logit_bias is not None:
logits.add_(self.logit_bias)
# min-token, presence, frequency
if self.linear_penalties is not None:
logits.add_(self.linear_penalties)
# repetition
if self.scaling_penalties is not None:
apply_scaling_penalties(logits, self.scaling_penalties)
# Apply regex vocab_mask
if self.vocab_mask is not None:
self.apply_mask(logits=logits, vocab_mask=self.vocab_mask)
self.is_all_greedy |= other.is_all_greedy
self.need_min_p_sampling |= other.need_min_p_sampling

View File

@@ -15,15 +15,21 @@
import argparse
import dataclasses
import json
import logging
import os
import random
import subprocess
import tempfile
import uuid
from pathlib import Path
from typing import List, Optional
import torch
from sglang.srt.hf_transformers_utils import check_gguf_file
from sglang.srt.utils import (
create_checksum,
get_amdgpu_memory_capacity,
get_hpu_memory_capacity,
get_nvgpu_memory_capacity,
@@ -43,12 +49,13 @@ class ServerArgs:
model_path: str
tokenizer_path: Optional[str] = None
tokenizer_mode: str = "auto"
skip_tokenizer_init: bool = False
load_format: str = "auto"
trust_remote_code: bool = True
trust_remote_code: bool = False
dtype: str = "auto"
kv_cache_dtype: str = "auto"
quantization_param_path: nullable_str = None
quantization: Optional[str] = None
quantization_param_path: nullable_str = None
context_length: Optional[int] = None
device: str = "cuda"
served_model_name: Optional[str] = None
@@ -67,7 +74,7 @@ class ServerArgs:
max_total_tokens: Optional[int] = None
chunked_prefill_size: Optional[int] = None
max_prefill_tokens: int = 16384
schedule_policy: str = "lpm"
schedule_policy: str = "fcfs"
schedule_conservativeness: float = 1.0
cpu_offload_gb: int = 0
prefill_only_one_req: bool = False
@@ -88,6 +95,7 @@ class ServerArgs:
log_level: str = "info"
log_level_http: Optional[str] = None
log_requests: bool = False
log_requests_level: int = 0
show_time_cost: bool = False
enable_metrics: bool = False
decode_log_interval: int = 40
@@ -123,11 +131,13 @@ class ServerArgs:
grammar_backend: Optional[str] = "outlines"
# Speculative decoding
speculative_draft_model_path: Optional[str] = None
speculative_algorithm: Optional[str] = None
speculative_draft_model_path: Optional[str] = None
speculative_num_steps: int = 5
speculative_eagle_topk: int = 8
speculative_num_draft_tokens: int = 64
speculative_eagle_topk: int = 4
speculative_num_draft_tokens: int = 8
speculative_accept_threshold_single: float = 1.0
speculative_accept_threshold_acc: float = 1.0
speculative_token_map: Optional[str] = None
# Double Sparsity
@@ -169,6 +179,12 @@ class ServerArgs:
enable_hierarchical_cache: bool = False
enable_flashinfer_mla: bool = False
flashinfer_mla_disable_ragged: bool = False
warmups: Optional[str] = None
# Debug tensor dumps
debug_tensor_dump_output_folder: Optional[str] = None
debug_tensor_dump_input_file: Optional[str] = None
debug_tensor_dump_inject: bool = False
def __post_init__(self):
# Set missing default values
@@ -266,10 +282,10 @@ class ServerArgs:
self.speculative_algorithm == "EAGLE"
or self.speculative_algorithm == "NEXTN"
):
self.disable_overlap_schedule = True
self.prefill_only_one_req = True
self.disable_cuda_graph_padding = True
self.disable_radix_cache = True
self.disable_overlap_schedule = True
self.chunked_prefill_size = -1
logger.info(
f"The radix cache, chunked prefill, and overlap scheduler are disabled because of using {self.speculative_algorithm} speculative decoding."
@@ -377,15 +393,6 @@ class ServerArgs:
choices=["auto", "fp8_e5m2", "fp8_e4m3"],
help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" and "fp8_e4m3" is supported for CUDA 11.8+.',
)
parser.add_argument(
"--quantization-param-path",
type=nullable_str,
default=None,
help="Path to the JSON file containing the KV cache "
"scaling factors. This should generally be supplied, when "
"KV cache dtype is FP8. Otherwise, KV cache scaling factors "
"default to 1.0, which may cause accuracy issues. ",
)
parser.add_argument(
"--quantization",
type=str,
@@ -404,6 +411,15 @@ class ServerArgs:
],
help="The quantization method.",
)
parser.add_argument(
"--quantization-param-path",
type=nullable_str,
default=None,
help="Path to the JSON file containing the KV cache "
"scaling factors. This should generally be supplied, when "
"KV cache dtype is FP8. Otherwise, KV cache scaling factors "
"default to 1.0, which may cause accuracy issues. ",
)
parser.add_argument(
"--context-length",
type=int,
@@ -578,7 +594,14 @@ class ServerArgs:
parser.add_argument(
"--log-requests",
action="store_true",
help="Log the inputs and outputs of all requests.",
help="Log metadata, inputs, outputs of all requests. The verbosity is decided by --log-requests-level",
)
parser.add_argument(
"--log-requests-level",
type=int,
default=0,
help="0: Log metadata. 1. Log metadata and partial input/output. 2. Log every input/output.",
choices=[0, 1, 2],
)
parser.add_argument(
"--show-time-cost",
@@ -742,16 +765,28 @@ class ServerArgs:
parser.add_argument(
"--speculative-eagle-topk",
type=int,
help="The number of token sampled from draft model in eagle2 each step.",
help="The number of tokens sampled from the draft model in eagle2 each step.",
choices=[1, 2, 4, 8],
default=ServerArgs.speculative_eagle_topk,
)
parser.add_argument(
"--speculative-num-draft-tokens",
type=int,
help="The number of token sampled from draft model in Speculative Decoding.",
help="The number of tokens sampled from the draft model in Speculative Decoding.",
default=ServerArgs.speculative_num_draft_tokens,
)
parser.add_argument(
"--speculative-accept-threshold-single",
type=float,
help="Accept a draft token if its probability in the target model is greater than this threshold.",
default=ServerArgs.speculative_accept_threshold_single,
)
parser.add_argument(
"--speculative-accept-threshold-acc",
type=float,
help="The accept probability of a draft token is raised from its target probability p to min(1, p / threshold_acc).",
default=ServerArgs.speculative_accept_threshold_acc,
)
parser.add_argument(
"--speculative-token-map",
type=str,
@@ -949,6 +984,35 @@ class ServerArgs:
help="Enable hierarchical cache",
)
# Server warmups
parser.add_argument(
"--warmups",
type=str,
required=False,
help="Specify custom warmup functions (csv) to run before server starts eg. --warmups=warmup_name1,warmup_name2 "
"will run the functions `warmup_name1` and `warmup_name2` specified in warmup.py before the server starts listening for requests",
)
# Debug tensor dumps
parser.add_argument(
"--debug-tensor-dump-output-folder",
type=str,
default=ServerArgs.debug_tensor_dump_output_folder,
help="The output folder for dumping tensors.",
)
parser.add_argument(
"--debug-tensor-dump-input-file",
type=str,
default=ServerArgs.debug_tensor_dump_input_file,
help="The input filename for dumping tensors",
)
parser.add_argument(
"--debug-tensor-dump-inject",
type=str,
default=ServerArgs.debug_tensor_dump_inject,
help="Inject the outputs from jax as the input of every layer.",
)
@classmethod
def from_cli_args(cls, args: argparse.Namespace):
args.tp_size = args.tensor_parallel_size

View File

@@ -32,13 +32,15 @@ import socket
import subprocess
import sys
import tempfile
import threading
import time
import warnings
from functools import lru_cache
from importlib.metadata import PackageNotFoundError, version
from io import BytesIO
from multiprocessing import Pool
from multiprocessing.reduction import ForkingPickler
from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Protocol, Set, Tuple, Union
import numpy as np
import psutil
@@ -480,6 +482,10 @@ def assert_pkg_version(pkg: str, min_version: str, message: str):
def kill_process_tree(parent_pid, include_parent: bool = True, skip_pid: int = None):
"""Kill the process and all its child processes."""
# Remove sigchld handler to avoid spammy logs.
if threading.current_thread() is threading.main_thread():
signal.signal(signal.SIGCHLD, signal.SIG_DFL)
if parent_pid is None:
parent_pid = os.getpid()
include_parent = False
@@ -499,17 +505,14 @@ def kill_process_tree(parent_pid, include_parent: bool = True, skip_pid: int = N
pass
if include_parent:
if parent_pid == os.getpid():
sys.exit(0)
else:
try:
itself.kill()
try:
itself.kill()
# Sometime processes cannot be killed with SIGKILL (e.g, PID=1 launched by kubernetes),
# so we send an additional signal to kill them.
itself.send_signal(signal.SIGQUIT)
except psutil.NoSuchProcess:
pass
# Sometime processes cannot be killed with SIGKILL (e.g, PID=1 launched by kubernetes),
# so we send an additional signal to kill them.
itself.send_signal(signal.SIGQUIT)
except psutil.NoSuchProcess:
pass
def monkey_patch_p2p_access_check():
@@ -1215,7 +1218,11 @@ def cuda_device_count_stateless() -> int:
return _cuda_device_count_stateless(os.environ.get("CUDA_VISIBLE_DEVICES", None))
def dataclass_to_string_truncated(data, max_length=2048):
def dataclass_to_string_truncated(
data, max_length=2048, skip_names: Optional[Set[str]] = None
):
if skip_names is None:
skip_names = set()
if isinstance(data, str):
if len(data) > max_length:
half_length = max_length // 2
@@ -1234,6 +1241,7 @@ def dataclass_to_string_truncated(data, max_length=2048):
+ ", ".join(
f"'{k}': {dataclass_to_string_truncated(v, max_length)}"
for k, v in data.items()
if k not in skip_names
)
+ "}"
)
@@ -1244,6 +1252,7 @@ def dataclass_to_string_truncated(data, max_length=2048):
+ ", ".join(
f"{f.name}={dataclass_to_string_truncated(getattr(data, f.name), max_length)}"
for f in fields
if f.name not in skip_names
)
+ ")"
)
@@ -1322,9 +1331,9 @@ def pyspy_dump_schedulers():
result = subprocess.run(
cmd, shell=True, capture_output=True, text=True, check=True
)
logger.info(f"Profile for PID {pid}:\n{result.stdout}")
logger.error(f"Pyspy dump for PID {pid}:\n{result.stdout}")
except subprocess.CalledProcessError as e:
logger.info(f"Failed to profile PID {pid}. Error: {e.stderr}")
logger.error(f"Pyspy failed to dump PID {pid}. Error: {e.stderr}")
def kill_itself_when_parent_died():
@@ -1448,6 +1457,10 @@ def launch_dummy_health_check_server(host, port):
)
def create_checksum(directory: str):
raise NotImplementedError()
def set_cuda_arch():
if is_flashinfer_available():
capability = torch.cuda.get_device_capability()

View File

@@ -0,0 +1,47 @@
import logging
from typing import List
import numpy as np
import tqdm
from sglang.srt.managers.io_struct import GenerateReqInput
from sglang.srt.managers.tokenizer_manager import TokenizerManager
logger = logging.getLogger(__file__)
_warmup_registry = {}
def warmup(name: str) -> callable:
def decorator(fn: callable):
_warmup_registry[name] = fn
return fn
return decorator
async def execute_warmups(warmup_names: List[str], tokenizer_manager: TokenizerManager):
for warmup_name in warmup_names:
if warmup_name not in _warmup_registry:
logger.warning(f"Could not find custom warmup {warmup_name}")
continue
logger.info(f"Running warmup {warmup_name}")
await _warmup_registry[warmup_name](tokenizer_manager)
@warmup("voice_chat")
async def voice_chat(tokenizer_manager: TokenizerManager):
# this warms up the fused_moe triton kernels and caches them
# if we don't do this we break real time inference for voice chat
for i in tqdm.trange(1, 512):
size = i * 4
generate_req_input = GenerateReqInput(
input_ids=(np.random.randint(2**16, size=[size])).tolist(),
sampling_params={
"max_new_tokens": 30,
"temperature": 0.8,
"stop_token_ids": [1],
"min_p": 0.0,
},
)
await tokenizer_manager.generate_request(generate_req_input, None).__anext__()

View File

@@ -15,7 +15,7 @@
import multiprocessing as mp
import os
from dataclasses import dataclass
from typing import List, Union
from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
@@ -56,6 +56,13 @@ def get_top_logprobs(logits, k):
return logprobs
def get_token_ids_logprobs(logits, token_ids):
logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
del logits
logprobs = logprobs[..., token_ids]
return logprobs
def _get_sentence_transformer_embedding_model(model_path, torch_dtype):
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import is_sentence_transformer_model
@@ -84,8 +91,13 @@ class ModelOutput:
output_ids: List[int] = None
top_input_logprobs: List[torch.Tensor] = None
top_output_logprobs: List[torch.Tensor] = None
top_output_logprob_idx: List[List[int]] = None
embed_logits: List[torch.Tensor] = None
scores: List[float] = None
input_token_logprobs_lst: List[List[Tuple[float, int, None]]] = None
output_token_logprobs_lst: List[List[Tuple[float, int, None]]] = None
token_ids_input_logprobs: List[torch.Tensor] = None
token_ids_output_logprobs: List[torch.Tensor] = None
class HFRunner:
@@ -157,7 +169,7 @@ class HFRunner:
# Run forward
while True:
prompts, max_new_tokens, lora_paths = in_queue.get()
prompts, max_new_tokens, lora_paths, token_ids_logprob = in_queue.get()
if lora_paths is not None:
assert len(prompts) == len(lora_paths)
@@ -165,16 +177,16 @@ class HFRunner:
if self.model_type == "generation":
out_queue.put(
self.forward_generation_raw(
base_model=self.base_model,
prompts=prompts,
max_new_tokens=max_new_tokens,
base_model=self.base_model,
tokenizer=self.tokenizer,
lora_paths=lora_paths,
torch_dtype=torch_dtype,
output_str_only=self.output_str_only,
token_ids_logprob=token_ids_logprob,
)
)
elif self.model_type == "embedding":
assert not self.output_str_only
logits = self.model.encode(prompts).tolist()
@@ -199,10 +211,11 @@ class HFRunner:
def forward(
self,
prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
max_new_tokens=8,
lora_paths=None,
max_new_tokens: int = 8,
lora_paths: Optional[List[str]] = None,
token_ids_logprob: Optional[int] = None,
):
self.in_queue.put((prompts, max_new_tokens, lora_paths))
self.in_queue.put((prompts, max_new_tokens, lora_paths, token_ids_logprob))
return self.out_queue.get()
def terminate(self):
@@ -218,17 +231,24 @@ class HFRunner:
@staticmethod
def forward_generation_raw(
prompts: Union[List[str], List[torch.Tensor]],
max_new_tokens,
base_model,
prompts: Union[List[str], List[torch.Tensor]],
max_new_tokens: int,
tokenizer,
lora_paths,
torch_dtype: torch.dtype,
output_str_only: bool,
lora_paths: Optional[List[str]] = None,
output_str_only: bool = False,
token_ids_logprob: Optional[int] = None,
) -> ModelOutput:
output_strs = []
top_input_logprobs = []
top_output_logprobs = []
if token_ids_logprob is not None:
token_ids_input_logprobs = []
token_ids_output_logprobs = []
else:
token_ids_input_logprobs = token_ids_output_logprobs = None
for i, p in enumerate(prompts):
if isinstance(p, str):
input_ids = tokenizer.encode(p, return_tensors="pt").cuda()
@@ -275,18 +295,33 @@ class HFRunner:
for logits in outputs.scores
]
)
if token_ids_logprob is not None:
token_ids_output_logprobs.append(
[
get_token_ids_logprobs(
logits[0], token_ids_logprob
).tolist()
for logits in outputs.scores
]
)
del outputs
input_logits = model.forward(input_ids).logits[0]
top_input_logprobs.append(
get_top_logprobs(input_logits, NUM_TOP_LOGPROBS).tolist()
)
if token_ids_logprob is not None:
token_ids_input_logprobs.append(
get_token_ids_logprobs(input_logits, token_ids_logprob).tolist()
)
del input_logits
return ModelOutput(
output_strs=output_strs,
top_input_logprobs=top_input_logprobs,
top_output_logprobs=top_output_logprobs,
token_ids_input_logprobs=token_ids_input_logprobs,
token_ids_output_logprobs=token_ids_output_logprobs,
)
@@ -303,11 +338,31 @@ class SRTRunner:
lora_backend: str = "triton",
disable_cuda_graph: bool = False,
disable_radix_cache: bool = False,
chunked_prefill_size: Optional[int] = None,
dp_size: int = 1,
tokenizer_path: Optional[str] = None,
enable_ep_moe: bool = False,
mem_fraction_static: float = 0.65,
trust_remote_code: bool = False,
speculative_draft_model_path: Optional[str] = None,
speculative_algorithm: Optional[str] = None,
speculative_num_steps: Optional[int] = None,
speculative_eagle_topk: Optional[int] = None,
speculative_num_draft_tokens: Optional[int] = None,
disable_overlap_schedule: bool = False,
):
self.model_type = model_type
self.is_generation = model_type == "generation"
enable_dp_attention = dp_size > 1
spec_kwargs = {}
if speculative_draft_model_path:
spec_kwargs["speculative_draft_model_path"] = speculative_draft_model_path
spec_kwargs["speculative_algorithm"] = speculative_algorithm
spec_kwargs["speculative_num_steps"] = speculative_num_steps
spec_kwargs["speculative_eagle_topk"] = speculative_eagle_topk
spec_kwargs["speculative_num_draft_tokens"] = speculative_num_draft_tokens
self.engine = Engine(
model_path=model_path,
tp_size=tp_size,
@@ -321,21 +376,41 @@ class SRTRunner:
lora_backend=lora_backend,
disable_cuda_graph=disable_cuda_graph,
disable_radix_cache=disable_radix_cache,
chunked_prefill_size=chunked_prefill_size,
enable_dp_attention=enable_dp_attention,
dp_size=dp_size,
tokenizer_path=tokenizer_path,
enable_ep_moe=enable_ep_moe,
disable_overlap_schedule=disable_overlap_schedule,
cuda_graph_max_bs=4,
**spec_kwargs,
)
self.tokenizer = get_tokenizer(model_path, trust_remote_code=trust_remote_code)
if tokenizer_path is None:
self.tokenizer = get_tokenizer(
model_path, trust_remote_code=trust_remote_code
)
else:
self.tokenizer = None
def forward(
self,
prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
max_new_tokens=8,
lora_paths=None,
max_new_tokens: int = 8,
lora_paths: Optional[List[str]] = None,
logprob_start_len: int = 0,
top_k: Optional[int] = None,
token_ids_logprob: Optional[List[int]] = None,
):
if self.is_generation:
return self.forward_generation_raw(
engine=self.engine,
prompts=prompts,
max_new_tokens=max_new_tokens,
lora_paths=lora_paths,
engine=self.engine,
logprob_start_len=logprob_start_len,
top_k=top_k,
token_ids_logprob=token_ids_logprob,
)
else:
response = self.engine.encode(prompts)
@@ -358,10 +433,10 @@ class SRTRunner:
"""
if self.is_generation:
return self.batch_forward_generation_raw(
engine=self.engine,
prompts=prompts,
max_new_tokens=max_new_tokens,
lora_paths=lora_paths,
engine=self.engine,
)
else:
response = self.engine.encode(prompts)
@@ -381,24 +456,43 @@ class SRTRunner:
@staticmethod
def forward_generation_raw(
engine: Engine,
prompts: Union[List[str], List[torch.Tensor]],
max_new_tokens,
lora_paths,
engine,
max_new_tokens: int = 8,
lora_paths: Optional[List[str]] = None,
logprob_start_len: int = 0,
top_k: Optional[int] = None,
token_ids_logprob: Optional[List[int]] = None,
):
# the return value contains logprobs from prefill
output_strs = []
output_ids = []
# Input logprobs. Note that the last item in input logprob is equivalent to
# the first item in the output logprob.
top_input_logprobs = []
input_token_logprobs_lst = []
top_output_logprobs = []
output_token_logprobs_lst = []
top_output_logprob_idx = []
if token_ids_logprob is not None:
token_ids_input_logprobs = []
token_ids_output_logprobs = []
else:
token_ids_input_logprobs = token_ids_output_logprobs = None
sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
if top_k:
sampling_params["top_k"] = top_k
for i, prompt in enumerate(prompts):
response = engine.generate(
prompt,
lora_path=lora_paths[i] if lora_paths else None,
sampling_params=sampling_params,
return_logprob=True,
logprob_start_len=0,
logprob_start_len=logprob_start_len,
top_logprobs_num=NUM_TOP_LOGPROBS,
token_ids_logprob=token_ids_logprob,
)
text = response["text"]
@@ -408,12 +502,36 @@ class SRTRunner:
"Received an empty text response. Please verify your input or model configuration."
)
output_strs.append(text)
# output_ids.append(response["output_ids"])
input_token_logprobs = response["meta_info"]["input_token_logprobs"]
output_token_logprobs = response["meta_info"]["output_token_logprobs"]
# print(i, input_token_logprobs)
# print(i, output_token_logprobs)
logprobs = response["meta_info"]["input_top_logprobs"]
if token_ids_logprob is not None:
input_token_ids_logprobs = response["meta_info"][
"input_token_ids_logprobs"
][1:]
else:
input_token_ids_logprobs = None
num_prompt_tokens = response["meta_info"]["prompt_tokens"]
assert len(input_token_logprobs) == num_prompt_tokens - logprob_start_len
assert len(logprobs) == num_prompt_tokens - logprob_start_len
# The first token logprob has no meaning in sglang.
input_token_logprobs = input_token_logprobs[1:]
logprobs = logprobs[1:]
assert len(input_token_logprobs) == len(logprobs)
input_token_logprobs_lst.append(
input_token_logprobs + [output_token_logprobs[0]]
)
output_token_logprobs_lst.append(output_token_logprobs)
top_input_logprobs.append(
[
[tup[0] for tup in x[:NUM_TOP_LOGPROBS]]
for x in response["meta_info"]["input_top_logprobs"][1:]
]
[[tup[0] for tup in x[:NUM_TOP_LOGPROBS]] for x in logprobs]
+ [
[
tup[0]
@@ -429,11 +547,41 @@ class SRTRunner:
for x in response["meta_info"]["output_top_logprobs"]
]
)
top_output_logprob_idx.append(
[
[tup[1] for tup in x[:NUM_TOP_LOGPROBS]]
for x in response["meta_info"]["output_top_logprobs"]
]
)
if token_ids_logprob is not None:
token_ids_input_logprobs.append(
[[tup[0] for tup in x] for x in input_token_ids_logprobs]
+ [
[
tup[0]
for tup in response["meta_info"][
"output_token_ids_logprobs"
][0]
]
]
)
token_ids_output_logprobs.append(
[
[tup[0] for tup in x]
for x in response["meta_info"]["output_token_ids_logprobs"]
]
)
return ModelOutput(
output_strs=output_strs,
output_ids=output_ids,
top_input_logprobs=top_input_logprobs,
top_output_logprobs=top_output_logprobs,
input_token_logprobs_lst=input_token_logprobs_lst,
output_token_logprobs_lst=output_token_logprobs_lst,
top_output_logprob_idx=top_output_logprob_idx,
token_ids_input_logprobs=token_ids_input_logprobs,
token_ids_output_logprobs=token_ids_output_logprobs,
)
@staticmethod

View File

@@ -0,0 +1,88 @@
"""
Run one test prompt.
Usage:
python3 -m sglang.test.send_one
"""
import argparse
import json
import requests
def send_one_prompt(args):
if args.image:
args.prompt = (
"Human: Describe this image in a very short sentence.\n\nAssistant:"
)
image_data = "https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png"
else:
image_data = None
response = requests.post(
"http://localhost:30000/generate",
json={
"text": args.prompt,
"image_data": image_data,
"sampling_params": {
"temperature": args.temperature,
"max_new_tokens": args.max_new_tokens,
"frequency_penalty": args.frequency_penalty,
"presence_penalty": args.presence_penalty,
},
"return_logprob": args.return_logprob,
"stream": args.stream,
},
stream=args.stream,
)
if args.stream:
for chunk in response.iter_lines(decode_unicode=False):
chunk = chunk.decode("utf-8")
if chunk and chunk.startswith("data:"):
if chunk == "data: [DONE]":
break
ret = json.loads(chunk[5:].strip("\n"))
else:
ret = response.json()
latency = ret["meta_info"]["e2e_latency"]
if "spec_verify_ct" in ret["meta_info"]:
acc_length = (
ret["meta_info"]["completion_tokens"] / ret["meta_info"]["spec_verify_ct"]
)
else:
acc_length = 1.0
speed = ret["meta_info"]["completion_tokens"] / latency
print(ret["text"])
print()
print(f"{acc_length=:.2f}")
print(f"{speed=:.2f} token/s")
return acc_length, speed
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--temperature", type=float, default=0.0)
parser.add_argument("--max-new-tokens", type=int, default=512)
parser.add_argument("--frequency-penalty", type=float, default=0.0)
parser.add_argument("--presence-penalty", type=float, default=0.0)
parser.add_argument("--return-logprob", action="store_true")
parser.add_argument(
"--prompt",
type=str,
default="Human: Give me a fully functional FastAPI server. Show the python code.\n\nAssistant:",
)
parser.add_argument(
"--image",
action="store_true",
)
parser.add_argument("--stream", action="store_true")
args = parser.parse_args()
send_one_prompt(args)

View File

@@ -8,10 +8,11 @@ import random
import subprocess
import threading
import time
import unittest
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from types import SimpleNamespace
from typing import Callable, List, Optional
from typing import Callable, List, Optional, Tuple
import numpy as np
import requests
@@ -408,26 +409,49 @@ def popen_launch_server(
other_args: list[str] = (),
env: Optional[dict] = None,
return_stdout_stderr: Optional[tuple] = None,
pd_seperated: bool = False,
):
_, host, port = base_url.split(":")
host = host[2:]
if pd_seperated:
command = "sglang.launch_pd_server"
else:
command = "sglang.launch_server"
command = [
"python3",
"-m",
"sglang.launch_server",
command,
"--model-path",
model,
"--host",
host,
"--port",
port,
*other_args,
*[str(x) for x in other_args],
]
if pd_seperated:
command.extend(
[
"--lb-host",
host,
"--lb-port",
port,
]
)
else:
command.extend(
[
"--host",
host,
"--port",
port,
]
)
if api_key:
command += ["--api-key", api_key]
print(f"command={' '.join(command)}")
if return_stdout_stderr:
process = subprocess.Popen(
command,
@@ -456,6 +480,8 @@ def popen_launch_server(
except requests.RequestException:
pass
time.sleep(10)
kill_process_tree(process.pid)
raise TimeoutError("Server failed to start within the timeout period.")
@@ -488,9 +514,11 @@ def run_unittest_files(files: List[str], timeout_per_file: float):
success = True
for filename in files:
global process
process = None
def run_one_file(filename):
nonlocal process
filename = os.path.join(os.getcwd(), filename)
print(f"\n\nRun:\npython3 {filename}\n\n", flush=True)
process = subprocess.Popen(
@@ -534,11 +562,14 @@ def get_benchmark_args(
dataset_path="",
tokenizer="",
num_prompts=500,
sharegpt_output_len=None,
random_input_len=4096,
random_output_len=2048,
sharegpt_context_len=None,
request_rate=float("inf"),
disable_stream=False,
disable_ignore_eos=False,
pd_seperated: bool = False,
):
return SimpleNamespace(
backend="sglang",
@@ -550,8 +581,8 @@ def get_benchmark_args(
model=None,
tokenizer=tokenizer,
num_prompts=num_prompts,
sharegpt_output_len=None,
sharegpt_context_len=None,
sharegpt_output_len=sharegpt_output_len,
sharegpt_context_len=sharegpt_context_len,
random_input_len=random_input_len,
random_output_len=random_output_len,
random_range_ratio=0.0,
@@ -567,6 +598,8 @@ def get_benchmark_args(
apply_chat_template=False,
profile=None,
lora_name=None,
prompt_suffix="",
pd_seperated=pd_seperated,
)
@@ -580,6 +613,7 @@ def run_bench_serving(
tokenizer=None,
random_input_len=4096,
random_output_len=2048,
sharegpt_context_len=None,
disable_stream=False,
disable_ignore_eos=False,
need_warmup=False,
@@ -602,6 +636,7 @@ def run_bench_serving(
num_prompts=num_prompts,
random_input_len=random_input_len,
random_output_len=random_output_len,
sharegpt_context_len=sharegpt_context_len,
request_rate=request_rate,
disable_stream=disable_stream,
disable_ignore_eos=disable_ignore_eos,
@@ -626,6 +661,7 @@ def run_bench_serving_multi(
other_server_args,
benchmark_args,
need_warmup=False,
pd_seperated=False,
):
# Launch the server
process = popen_launch_server(
@@ -633,6 +669,7 @@ def run_bench_serving_multi(
base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=other_server_args,
pd_seperated=pd_seperated,
)
# run benchmark for all
@@ -665,7 +702,7 @@ def run_bench_one_batch(model, other_args):
"128",
"--output",
"8",
*other_args,
*[str(x) for x in other_args],
]
process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
@@ -816,7 +853,7 @@ def run_command_and_capture_output(command, env: Optional[dict] = None):
stdout = open(STDOUT_FILENAME, "w")
stderr = open(STDERR_FILENAME, "w")
process = subprocess.Popen(
command, stdout=stdout, stderr=stderr, env=env, text=True
command, stdout=stdout, stderr=stdout, env=env, text=True
)
# Launch a thread to stream the output
@@ -914,3 +951,78 @@ def run_mulit_request_test(
def write_github_step_summary(content):
with open(os.environ["GITHUB_STEP_SUMMARY"], "a") as f:
f.write(content)
def run_logprob_check(self: unittest.TestCase, arg: Tuple):
(
input_len,
output_len,
temperature,
logprob_start_len,
return_logprob,
top_logprobs_num,
) = arg
input_ids = list(range(input_len))
response = requests.post(
self.base_url + "/generate",
json={
"input_ids": input_ids,
"sampling_params": {
"temperature": temperature,
"max_new_tokens": output_len,
"ignore_eos": True,
},
"return_logprob": return_logprob,
"logprob_start_len": logprob_start_len,
"top_logprobs_num": top_logprobs_num,
},
)
response_json = response.json()
res = response_json
self.assertEqual(res["meta_info"]["prompt_tokens"], input_len)
self.assertEqual(res["meta_info"]["completion_tokens"], output_len)
# Test the number of tokens are correct
if return_logprob:
self.assertEqual(
len(res["meta_info"]["input_token_logprobs"]) + logprob_start_len,
res["meta_info"]["prompt_tokens"],
)
self.assertEqual(len(res["meta_info"]["output_token_logprobs"]), output_len)
if top_logprobs_num:
self.assertEqual(
len(res["meta_info"]["input_top_logprobs"]) + logprob_start_len,
res["meta_info"]["prompt_tokens"],
)
self.assertEqual(len(res["meta_info"]["output_top_logprobs"]), output_len)
for i in range(output_len):
self.assertEqual(
len(res["meta_info"]["output_top_logprobs"][i]),
top_logprobs_num,
)
# Test the top-1 tokens are the same as output tokens if temperature == 0
if temperature == 0:
rank = 0
while rank < len(res["meta_info"]["output_top_logprobs"][i]):
try:
self.assertListEqual(
res["meta_info"]["output_token_logprobs"][i],
res["meta_info"]["output_top_logprobs"][i][rank],
)
break
except AssertionError:
# There's a tie. Allow the second item in this case.
if (
res["meta_info"]["output_top_logprobs"][i][rank][0]
== res["meta_info"]["output_top_logprobs"][i][rank + 1][
0
]
):
rank += 1
else:
raise