Support penalty in overlap mode; return logprob with chunked prefill; improve benchmark scripts (#3988)
Co-authored-by: SangBin Cho <rkooo567@gmail.com> Co-authored-by: dhou-xai <dhou@x.ai> Co-authored-by: Hanming Lu <hanming_lu@berkeley.edu>
This commit is contained in:
@@ -30,11 +30,20 @@ def get_model_config(model_name: str, tp_size: int):
|
|||||||
topk = config.num_experts_per_tok
|
topk = config.num_experts_per_tok
|
||||||
intermediate_size = config.moe_intermediate_size
|
intermediate_size = config.moe_intermediate_size
|
||||||
shard_intermediate_size = 2 * intermediate_size // tp_size
|
shard_intermediate_size = 2 * intermediate_size // tp_size
|
||||||
|
elif config.architectures[0] in [
|
||||||
|
"Grok1ForCausalLM",
|
||||||
|
"Grok1ImgGen",
|
||||||
|
"Grok1AForCausalLM",
|
||||||
|
]:
|
||||||
|
E = config.num_local_experts
|
||||||
|
topk = config.num_experts_per_tok
|
||||||
|
intermediate_size = config.moe_intermediate_size
|
||||||
|
shard_intermediate_size = 2 * intermediate_size // tp_size
|
||||||
elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]:
|
elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]:
|
||||||
E = config.n_routed_experts
|
E = config.n_routed_experts
|
||||||
topk = config.num_experts_per_tok
|
topk = config.num_experts_per_tok
|
||||||
intermediate_size = config.intermediate_size
|
intermediate_size = config.intermediate_size
|
||||||
shard_intermediate_size = 2 * intermediate_size // args.tp_size
|
shard_intermediate_size = 2 * intermediate_size // tp_size
|
||||||
else:
|
else:
|
||||||
# Default: Mixtral
|
# Default: Mixtral
|
||||||
E = config.num_local_experts
|
E = config.num_local_experts
|
||||||
|
|||||||
@@ -35,6 +35,15 @@ def get_model_config(model_name: str, tp_size: int):
|
|||||||
topk = config.num_experts_per_tok
|
topk = config.num_experts_per_tok
|
||||||
intermediate_size = config.moe_intermediate_size
|
intermediate_size = config.moe_intermediate_size
|
||||||
shard_intermediate_size = 2 * intermediate_size // tp_size
|
shard_intermediate_size = 2 * intermediate_size // tp_size
|
||||||
|
elif config.architectures[0] in [
|
||||||
|
"Grok1ForCausalLM",
|
||||||
|
"Grok1ImgGen",
|
||||||
|
"Grok1AForCausalLM",
|
||||||
|
]:
|
||||||
|
E = config.num_local_experts
|
||||||
|
topk = config.num_experts_per_tok
|
||||||
|
intermediate_size = config.moe_intermediate_size
|
||||||
|
shard_intermediate_size = 2 * intermediate_size // tp_size
|
||||||
else:
|
else:
|
||||||
# Default: Mixtral
|
# Default: Mixtral
|
||||||
E = config.num_local_experts
|
E = config.num_local_experts
|
||||||
|
|||||||
@@ -397,6 +397,15 @@ def main(args: argparse.Namespace):
|
|||||||
topk = config.num_experts_per_tok
|
topk = config.num_experts_per_tok
|
||||||
intermediate_size = config.moe_intermediate_size
|
intermediate_size = config.moe_intermediate_size
|
||||||
shard_intermediate_size = 2 * intermediate_size // args.tp_size
|
shard_intermediate_size = 2 * intermediate_size // args.tp_size
|
||||||
|
elif config.architectures[0] in [
|
||||||
|
"Grok1ForCausalLM",
|
||||||
|
"Grok1ImgGen",
|
||||||
|
"Grok1AForCausalLM",
|
||||||
|
]:
|
||||||
|
E = config.num_local_experts
|
||||||
|
topk = config.num_experts_per_tok
|
||||||
|
intermediate_size = config.moe_intermediate_size
|
||||||
|
shard_intermediate_size = 2 * intermediate_size // args.tp_size
|
||||||
else:
|
else:
|
||||||
# Default: Mixtral
|
# Default: Mixtral
|
||||||
E = config.num_local_experts
|
E = config.num_local_experts
|
||||||
|
|||||||
@@ -210,8 +210,7 @@
|
|||||||
"response = requests.post(url, json=data)\n",
|
"response = requests.post(url, json=data)\n",
|
||||||
"print_highlight(response.text)\n",
|
"print_highlight(response.text)\n",
|
||||||
"assert response.json()[\"success\"] is True\n",
|
"assert response.json()[\"success\"] is True\n",
|
||||||
"assert response.json()[\"message\"] == \"Succeeded to update model weights.\"\n",
|
"assert response.json()[\"message\"] == \"Succeeded to update model weights.\""
|
||||||
"assert response.json().keys() == {\"success\", \"message\"}"
|
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -411,7 +410,7 @@
|
|||||||
" },\n",
|
" },\n",
|
||||||
")\n",
|
")\n",
|
||||||
"output = response.json()\n",
|
"output = response.json()\n",
|
||||||
"output_tokens = output[\"token_ids\"]\n",
|
"output_tokens = output[\"output_ids\"]\n",
|
||||||
"\n",
|
"\n",
|
||||||
"output_text = tokenizer.decode(output_tokens, skip_special_tokens=False)\n",
|
"output_text = tokenizer.decode(output_tokens, skip_special_tokens=False)\n",
|
||||||
"print_highlight(f\"Tokenized Output: {output_tokens}\")\n",
|
"print_highlight(f\"Tokenized Output: {output_tokens}\")\n",
|
||||||
|
|||||||
@@ -96,7 +96,6 @@ Please consult the documentation below to learn more about the parameters you ma
|
|||||||
* `schedule_policy`: The scheduling policy to control the processing order of waiting prefill requests in a single engine.
|
* `schedule_policy`: The scheduling policy to control the processing order of waiting prefill requests in a single engine.
|
||||||
* `schedule_conservativeness`: Can be used to decrease/increase the conservativeness of the server when taking new requests. Highly conservative behavior leads to starvation, but low conservativeness leads to slowed-down performance.
|
* `schedule_conservativeness`: Can be used to decrease/increase the conservativeness of the server when taking new requests. Highly conservative behavior leads to starvation, but low conservativeness leads to slowed-down performance.
|
||||||
* `cpu_offload_gb`: Reserve this amount of RAM in GB for offloading of model parameters to the CPU.
|
* `cpu_offload_gb`: Reserve this amount of RAM in GB for offloading of model parameters to the CPU.
|
||||||
* `prefill_only_one_req`: When this flag is turned on, the engine prefills only one request at a time.
|
|
||||||
|
|
||||||
## Other runtime options
|
## Other runtime options
|
||||||
|
|
||||||
|
|||||||
@@ -96,7 +96,10 @@ dev_cpu = ["sglang[all_cpu]", "sglang[test]"]
|
|||||||
"Bug Tracker" = "https://github.com/sgl-project/sglang/issues"
|
"Bug Tracker" = "https://github.com/sgl-project/sglang/issues"
|
||||||
|
|
||||||
[tool.setuptools.package-data]
|
[tool.setuptools.package-data]
|
||||||
"sglang" = ["srt/layers/moe/fused_moe_triton/configs/*.json", "srt/layers/quantization/configs/*.json"]
|
"sglang" = [
|
||||||
|
"srt/layers/moe/fused_moe_triton/configs/*.json",
|
||||||
|
"srt/layers/quantization/configs/*.json",
|
||||||
|
]
|
||||||
|
|
||||||
[tool.setuptools.packages.find]
|
[tool.setuptools.packages.find]
|
||||||
exclude = [
|
exclude = [
|
||||||
|
|||||||
@@ -8,8 +8,10 @@
|
|||||||
- `bench_one_batch.py`: Benchmark the latency of running a single static batch without a server.
|
- `bench_one_batch.py`: Benchmark the latency of running a single static batch without a server.
|
||||||
- `bench_one_batch_server.py`: Benchmark the latency of running a single batch with a server.
|
- `bench_one_batch_server.py`: Benchmark the latency of running a single batch with a server.
|
||||||
- `bench_serving.py`: Benchmark online serving with dynamic requests.
|
- `bench_serving.py`: Benchmark online serving with dynamic requests.
|
||||||
- `check_env.py`: Check the environment variables.
|
- `check_env.py`: Check the environment variables and dependencies.
|
||||||
- `global_config.py`: The global configs and constants.
|
- `global_config.py`: The global configs and constants.
|
||||||
- `launch_server.py`: The entry point for launching the local server.
|
- `launch_server.py`: The entry point for launching the local server.
|
||||||
- `llama3_eval.py`: Evaluation of Llama 3 using the Meta Llama dataset.
|
- `llama3_eval.py`: Evaluation of Llama 3 using the Meta Llama dataset.
|
||||||
|
- `profiler.py`: Profile a running server.
|
||||||
- `utils.py`: Common utilities.
|
- `utils.py`: Common utilities.
|
||||||
|
- `version.py`: Version info.
|
||||||
|
|||||||
@@ -56,6 +56,7 @@ class BenchArgs:
|
|||||||
profile: bool = False
|
profile: bool = False
|
||||||
skip_warmup: bool = False
|
skip_warmup: bool = False
|
||||||
do_not_exit: bool = False
|
do_not_exit: bool = False
|
||||||
|
prompt_suffix: str = ""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def add_cli_args(parser: argparse.ArgumentParser):
|
def add_cli_args(parser: argparse.ArgumentParser):
|
||||||
@@ -177,6 +178,12 @@ class BenchArgs:
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="Do not exit the program. This is useful for nsys profile with --duration and --delay.",
|
help="Do not exit the program. This is useful for nsys profile with --duration and --delay.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--prompt-suffix",
|
||||||
|
type=str,
|
||||||
|
default="",
|
||||||
|
help="Suffix applied to the end of all user prompts, followed by assistant prompt suffix.",
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_cli_args(cls, args: argparse.Namespace):
|
def from_cli_args(cls, args: argparse.Namespace):
|
||||||
@@ -216,6 +223,10 @@ def throughput_test_once(
|
|||||||
]
|
]
|
||||||
|
|
||||||
if profile:
|
if profile:
|
||||||
|
assert (
|
||||||
|
"SGLANG_TORCH_PROFILER_DIR" in os.environ
|
||||||
|
), "Please set SGLANG_TORCH_PROFILER_DIR."
|
||||||
|
os.makedirs(os.environ["SGLANG_TORCH_PROFILER_DIR"], exist_ok=True)
|
||||||
backend.start_profile()
|
backend.start_profile()
|
||||||
|
|
||||||
st = time.perf_counter()
|
st = time.perf_counter()
|
||||||
@@ -229,6 +240,8 @@ def throughput_test_once(
|
|||||||
if backend_name == "runtime":
|
if backend_name == "runtime":
|
||||||
gen_out = json.loads(gen_out)
|
gen_out = json.loads(gen_out)
|
||||||
|
|
||||||
|
server_info = backend.get_server_info()
|
||||||
|
|
||||||
measurement_results["total_latency"] = latency
|
measurement_results["total_latency"] = latency
|
||||||
measurement_results["total_output_tokens"] = sum(
|
measurement_results["total_output_tokens"] = sum(
|
||||||
o["meta_info"]["completion_tokens"] for o in gen_out
|
o["meta_info"]["completion_tokens"] for o in gen_out
|
||||||
@@ -246,6 +259,7 @@ def throughput_test_once(
|
|||||||
measurement_results["total_input_tokens"]
|
measurement_results["total_input_tokens"]
|
||||||
+ measurement_results["total_output_tokens"]
|
+ measurement_results["total_output_tokens"]
|
||||||
) / latency
|
) / latency
|
||||||
|
measurement_results["last_gen_throughput"] = server_info["last_gen_throughput"]
|
||||||
|
|
||||||
return measurement_results
|
return measurement_results
|
||||||
|
|
||||||
@@ -361,6 +375,11 @@ def throughput_test(
|
|||||||
print(
|
print(
|
||||||
"{:<40} {:<10}".format("Total generated tokens:", result["total_output_tokens"])
|
"{:<40} {:<10}".format("Total generated tokens:", result["total_output_tokens"])
|
||||||
)
|
)
|
||||||
|
print(
|
||||||
|
"{:<40} {:<10.2f}".format(
|
||||||
|
"Last generation throughput (tok/s):", result["last_gen_throughput"]
|
||||||
|
)
|
||||||
|
)
|
||||||
print(
|
print(
|
||||||
"{:<40} {:<10.2f}".format(
|
"{:<40} {:<10.2f}".format(
|
||||||
"Request throughput (req/s):", result["request_throughput"]
|
"Request throughput (req/s):", result["request_throughput"]
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ Usage:
|
|||||||
python3 -m sglang.bench_serving --backend sglang --num-prompt 10
|
python3 -m sglang.bench_serving --backend sglang --num-prompt 10
|
||||||
|
|
||||||
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompts 3000 --random-input 1024 --random-output 1024 --random-range-ratio 0.5
|
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompts 3000 --random-input 1024 --random-output 1024 --random-range-ratio 0.5
|
||||||
python3 -m sglang.bench_serving --backend sglang --dataset-name random --request-rate-range 1,2,4,8,16,32 --random-input 4096 --random-output 1024 --random-range-ratio 0.125 --multi
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
@@ -71,6 +70,10 @@ def remove_prefix(text: str, prefix: str) -> str:
|
|||||||
return text[len(prefix) :] if text.startswith(prefix) else text
|
return text[len(prefix) :] if text.startswith(prefix) else text
|
||||||
|
|
||||||
|
|
||||||
|
def remove_suffix(text: str, suffix: str) -> str:
|
||||||
|
return text[: -len(suffix)] if text.endswith(suffix) else text
|
||||||
|
|
||||||
|
|
||||||
def get_auth_headers() -> Dict[str, str]:
|
def get_auth_headers() -> Dict[str, str]:
|
||||||
api_key = os.environ.get("OPENAI_API_KEY")
|
api_key = os.environ.get("OPENAI_API_KEY")
|
||||||
if api_key:
|
if api_key:
|
||||||
@@ -79,7 +82,7 @@ def get_auth_headers() -> Dict[str, str]:
|
|||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
|
||||||
# trt llm not support ignore_eos
|
# trt llm does not support ignore_eos
|
||||||
# https://github.com/triton-inference-server/tensorrtllm_backend/issues/505
|
# https://github.com/triton-inference-server/tensorrtllm_backend/issues/505
|
||||||
async def async_request_trt_llm(
|
async def async_request_trt_llm(
|
||||||
request_func_input: RequestFuncInput,
|
request_func_input: RequestFuncInput,
|
||||||
@@ -179,6 +182,7 @@ async def async_request_openai_completions(
|
|||||||
output.prompt_len = request_func_input.prompt_len
|
output.prompt_len = request_func_input.prompt_len
|
||||||
|
|
||||||
generated_text = ""
|
generated_text = ""
|
||||||
|
output_len = request_func_input.output_len
|
||||||
ttft = 0.0
|
ttft = 0.0
|
||||||
st = time.perf_counter()
|
st = time.perf_counter()
|
||||||
most_recent_timestamp = st
|
most_recent_timestamp = st
|
||||||
@@ -215,11 +219,14 @@ async def async_request_openai_completions(
|
|||||||
|
|
||||||
most_recent_timestamp = timestamp
|
most_recent_timestamp = timestamp
|
||||||
generated_text += data["choices"][0]["text"]
|
generated_text += data["choices"][0]["text"]
|
||||||
|
output_len = data.get("usage", {}).get(
|
||||||
|
"completion_tokens", output_len
|
||||||
|
)
|
||||||
|
|
||||||
output.generated_text = generated_text
|
output.generated_text = generated_text
|
||||||
output.success = True
|
output.success = True
|
||||||
output.latency = latency
|
output.latency = latency
|
||||||
output.output_len = request_func_input.output_len
|
output.output_len = output_len
|
||||||
else:
|
else:
|
||||||
output.error = response.reason or ""
|
output.error = response.reason or ""
|
||||||
output.success = False
|
output.success = False
|
||||||
@@ -339,9 +346,11 @@ async def async_request_sglang_generate(
|
|||||||
output.prompt_len = request_func_input.prompt_len
|
output.prompt_len = request_func_input.prompt_len
|
||||||
|
|
||||||
generated_text = ""
|
generated_text = ""
|
||||||
|
output_len = request_func_input.output_len
|
||||||
ttft = 0.0
|
ttft = 0.0
|
||||||
st = time.perf_counter()
|
st = time.perf_counter()
|
||||||
most_recent_timestamp = st
|
most_recent_timestamp = st
|
||||||
|
last_output_len = 0
|
||||||
try:
|
try:
|
||||||
async with session.post(
|
async with session.post(
|
||||||
url=api_url, json=payload, headers=headers
|
url=api_url, json=payload, headers=headers
|
||||||
@@ -365,6 +374,9 @@ async def async_request_sglang_generate(
|
|||||||
# want to check a token was generated
|
# want to check a token was generated
|
||||||
if data["text"]:
|
if data["text"]:
|
||||||
timestamp = time.perf_counter()
|
timestamp = time.perf_counter()
|
||||||
|
generated_text = data["text"]
|
||||||
|
output_len = data["meta_info"]["completion_tokens"]
|
||||||
|
|
||||||
# First token
|
# First token
|
||||||
if ttft == 0.0:
|
if ttft == 0.0:
|
||||||
ttft = time.perf_counter() - st
|
ttft = time.perf_counter() - st
|
||||||
@@ -372,7 +384,13 @@ async def async_request_sglang_generate(
|
|||||||
|
|
||||||
# Decoding phase
|
# Decoding phase
|
||||||
else:
|
else:
|
||||||
output.itl.append(timestamp - most_recent_timestamp)
|
num_new_tokens = output_len - last_output_len
|
||||||
|
if num_new_tokens == 0:
|
||||||
|
continue
|
||||||
|
adjust_itl = (
|
||||||
|
timestamp - most_recent_timestamp
|
||||||
|
) / num_new_tokens
|
||||||
|
output.itl.extend([adjust_itl] * num_new_tokens)
|
||||||
|
|
||||||
most_recent_timestamp = timestamp
|
most_recent_timestamp = timestamp
|
||||||
generated_text = data["text"]
|
generated_text = data["text"]
|
||||||
@@ -380,7 +398,7 @@ async def async_request_sglang_generate(
|
|||||||
output.generated_text = generated_text
|
output.generated_text = generated_text
|
||||||
output.success = True
|
output.success = True
|
||||||
output.latency = latency
|
output.latency = latency
|
||||||
output.output_len = request_func_input.output_len
|
output.output_len = output_len
|
||||||
else:
|
else:
|
||||||
output.error = response.reason or ""
|
output.error = response.reason or ""
|
||||||
output.success = False
|
output.success = False
|
||||||
@@ -388,6 +406,7 @@ async def async_request_sglang_generate(
|
|||||||
output.success = False
|
output.success = False
|
||||||
exc_info = sys.exc_info()
|
exc_info = sys.exc_info()
|
||||||
output.error = "".join(traceback.format_exception(*exc_info))
|
output.error = "".join(traceback.format_exception(*exc_info))
|
||||||
|
print(f"{output.error=}")
|
||||||
|
|
||||||
if pbar:
|
if pbar:
|
||||||
pbar.update(1)
|
pbar.update(1)
|
||||||
@@ -461,6 +480,7 @@ def get_dataset(args, tokenizer):
|
|||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
fixed_output_len=args.sharegpt_output_len,
|
fixed_output_len=args.sharegpt_output_len,
|
||||||
context_len=args.sharegpt_context_len,
|
context_len=args.sharegpt_context_len,
|
||||||
|
prompt_suffix=args.prompt_suffix,
|
||||||
apply_chat_template=args.apply_chat_template,
|
apply_chat_template=args.apply_chat_template,
|
||||||
)
|
)
|
||||||
elif args.dataset_name == "random":
|
elif args.dataset_name == "random":
|
||||||
@@ -521,7 +541,9 @@ class BenchmarkMetrics:
|
|||||||
mean_itl_ms: float
|
mean_itl_ms: float
|
||||||
median_itl_ms: float
|
median_itl_ms: float
|
||||||
std_itl_ms: float
|
std_itl_ms: float
|
||||||
|
p95_itl_ms: float
|
||||||
p99_itl_ms: float
|
p99_itl_ms: float
|
||||||
|
max_itl_ms: float
|
||||||
mean_e2e_latency_ms: float
|
mean_e2e_latency_ms: float
|
||||||
median_e2e_latency_ms: float
|
median_e2e_latency_ms: float
|
||||||
std_e2e_latency_ms: float
|
std_e2e_latency_ms: float
|
||||||
@@ -572,6 +594,7 @@ def sample_sharegpt_requests(
|
|||||||
tokenizer: PreTrainedTokenizerBase,
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
fixed_output_len: Optional[int] = None,
|
fixed_output_len: Optional[int] = None,
|
||||||
context_len: Optional[int] = None,
|
context_len: Optional[int] = None,
|
||||||
|
prompt_suffix: Optional[str] = "",
|
||||||
apply_chat_template=False,
|
apply_chat_template=False,
|
||||||
) -> List[Tuple[str, int, int]]:
|
) -> List[Tuple[str, int, int]]:
|
||||||
if fixed_output_len is not None and fixed_output_len < 4:
|
if fixed_output_len is not None and fixed_output_len < 4:
|
||||||
@@ -584,11 +607,19 @@ def sample_sharegpt_requests(
|
|||||||
# Load the dataset.
|
# Load the dataset.
|
||||||
with open(dataset_path) as f:
|
with open(dataset_path) as f:
|
||||||
dataset = json.load(f)
|
dataset = json.load(f)
|
||||||
|
|
||||||
# Filter out the conversations with less than 2 turns.
|
# Filter out the conversations with less than 2 turns.
|
||||||
dataset = [data for data in dataset if len(data["conversations"]) >= 2]
|
dataset = [
|
||||||
|
data
|
||||||
|
for data in dataset
|
||||||
|
if len(data.get("conversations", data.get("conversation", []))) >= 2
|
||||||
|
]
|
||||||
# Only keep the first two turns of each conversation.
|
# Only keep the first two turns of each conversation.
|
||||||
dataset = [
|
dataset = [
|
||||||
(data["conversations"][0]["value"], data["conversations"][1]["value"])
|
(
|
||||||
|
data.get("conversations", data.get("conversation", []))[0]["value"],
|
||||||
|
data.get("conversations", data.get("conversation", []))[1]["value"],
|
||||||
|
)
|
||||||
for data in dataset
|
for data in dataset
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -603,6 +634,8 @@ def sample_sharegpt_requests(
|
|||||||
|
|
||||||
# Tokenize the prompts and completions.
|
# Tokenize the prompts and completions.
|
||||||
prompt = dataset[i][0]
|
prompt = dataset[i][0]
|
||||||
|
if prompt_suffix:
|
||||||
|
prompt = prompt
|
||||||
|
|
||||||
if apply_chat_template:
|
if apply_chat_template:
|
||||||
prompt = tokenizer.apply_chat_template(
|
prompt = tokenizer.apply_chat_template(
|
||||||
@@ -666,10 +699,17 @@ def sample_random_requests(
|
|||||||
with open(dataset_path) as f:
|
with open(dataset_path) as f:
|
||||||
dataset = json.load(f)
|
dataset = json.load(f)
|
||||||
# Filter out the conversations with less than 2 turns.
|
# Filter out the conversations with less than 2 turns.
|
||||||
dataset = [data for data in dataset if len(data["conversations"]) >= 2]
|
dataset = [
|
||||||
|
data
|
||||||
|
for data in dataset
|
||||||
|
if len(data.get("conversations", data.get("conversation", []))) >= 2
|
||||||
|
]
|
||||||
# Only keep the first two turns of each conversation.
|
# Only keep the first two turns of each conversation.
|
||||||
dataset = [
|
dataset = [
|
||||||
(data["conversations"][0]["value"], data["conversations"][1]["value"])
|
(
|
||||||
|
data.get("conversations", data.get("conversation", []))[0]["value"],
|
||||||
|
data.get("conversations", data.get("conversation", []))[1]["value"],
|
||||||
|
)
|
||||||
for data in dataset
|
for data in dataset
|
||||||
]
|
]
|
||||||
# Shuffle the dataset.
|
# Shuffle the dataset.
|
||||||
@@ -895,7 +935,9 @@ def calculate_metrics(
|
|||||||
mean_itl_ms=np.mean(itls or 0) * 1000,
|
mean_itl_ms=np.mean(itls or 0) * 1000,
|
||||||
median_itl_ms=np.median(itls or 0) * 1000,
|
median_itl_ms=np.median(itls or 0) * 1000,
|
||||||
std_itl_ms=np.std(itls or 0) * 1000,
|
std_itl_ms=np.std(itls or 0) * 1000,
|
||||||
|
p95_itl_ms=np.percentile(itls or 0, 95) * 1000,
|
||||||
p99_itl_ms=np.percentile(itls or 0, 99) * 1000,
|
p99_itl_ms=np.percentile(itls or 0, 99) * 1000,
|
||||||
|
max_itl_ms=np.max(itls or 0) * 1000,
|
||||||
mean_e2e_latency_ms=np.mean(e2e_latencies) * 1000,
|
mean_e2e_latency_ms=np.mean(e2e_latencies) * 1000,
|
||||||
median_e2e_latency_ms=np.median(e2e_latencies) * 1000,
|
median_e2e_latency_ms=np.median(e2e_latencies) * 1000,
|
||||||
std_e2e_latency_ms=np.std(e2e_latencies) * 1000,
|
std_e2e_latency_ms=np.std(e2e_latencies) * 1000,
|
||||||
@@ -919,6 +961,7 @@ async def benchmark(
|
|||||||
lora_name: str,
|
lora_name: str,
|
||||||
extra_request_body: Dict[str, Any],
|
extra_request_body: Dict[str, Any],
|
||||||
profile: bool,
|
profile: bool,
|
||||||
|
pd_seperated: bool = False,
|
||||||
):
|
):
|
||||||
if backend in ASYNC_REQUEST_FUNCS:
|
if backend in ASYNC_REQUEST_FUNCS:
|
||||||
request_func = ASYNC_REQUEST_FUNCS[backend]
|
request_func = ASYNC_REQUEST_FUNCS[backend]
|
||||||
@@ -1004,6 +1047,17 @@ async def benchmark(
|
|||||||
if pbar is not None:
|
if pbar is not None:
|
||||||
pbar.close()
|
pbar.close()
|
||||||
|
|
||||||
|
if "sglang" in backend:
|
||||||
|
server_info = requests.get(base_url + "/get_server_info")
|
||||||
|
if pd_seperated:
|
||||||
|
accept_length = server_info.json()["decode"][0].get(
|
||||||
|
"avg_spec_accept_length", None
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
accept_length = server_info.json().get("avg_spec_accept_length", None)
|
||||||
|
else:
|
||||||
|
accept_length = None
|
||||||
|
|
||||||
# Compute metrics and print results
|
# Compute metrics and print results
|
||||||
benchmark_duration = time.perf_counter() - benchmark_start_time
|
benchmark_duration = time.perf_counter() - benchmark_start_time
|
||||||
metrics, output_lens = calculate_metrics(
|
metrics, output_lens = calculate_metrics(
|
||||||
@@ -1053,6 +1107,8 @@ async def benchmark(
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
print("{:<40} {:<10.2f}".format("Concurrency:", metrics.concurrency))
|
print("{:<40} {:<10.2f}".format("Concurrency:", metrics.concurrency))
|
||||||
|
if accept_length:
|
||||||
|
print("{:<40} {:<10.2f}".format("Accept length:", accept_length))
|
||||||
print("{s:{c}^{n}}".format(s="End-to-End Latency", n=50, c="-"))
|
print("{s:{c}^{n}}".format(s="End-to-End Latency", n=50, c="-"))
|
||||||
print(
|
print(
|
||||||
"{:<40} {:<10.2f}".format("Mean E2E Latency (ms):", metrics.mean_e2e_latency_ms)
|
"{:<40} {:<10.2f}".format("Mean E2E Latency (ms):", metrics.mean_e2e_latency_ms)
|
||||||
@@ -1066,16 +1122,12 @@ async def benchmark(
|
|||||||
print("{:<40} {:<10.2f}".format("Mean TTFT (ms):", metrics.mean_ttft_ms))
|
print("{:<40} {:<10.2f}".format("Mean TTFT (ms):", metrics.mean_ttft_ms))
|
||||||
print("{:<40} {:<10.2f}".format("Median TTFT (ms):", metrics.median_ttft_ms))
|
print("{:<40} {:<10.2f}".format("Median TTFT (ms):", metrics.median_ttft_ms))
|
||||||
print("{:<40} {:<10.2f}".format("P99 TTFT (ms):", metrics.p99_ttft_ms))
|
print("{:<40} {:<10.2f}".format("P99 TTFT (ms):", metrics.p99_ttft_ms))
|
||||||
print(
|
print("{s:{c}^{n}}".format(s="Inter-Token Latency", n=50, c="-"))
|
||||||
"{s:{c}^{n}}".format(s="Time per Output Token (excl. 1st token)", n=50, c="-")
|
|
||||||
)
|
|
||||||
print("{:<40} {:<10.2f}".format("Mean TPOT (ms):", metrics.mean_tpot_ms))
|
|
||||||
print("{:<40} {:<10.2f}".format("Median TPOT (ms):", metrics.median_tpot_ms))
|
|
||||||
print("{:<40} {:<10.2f}".format("P99 TPOT (ms):", metrics.p99_tpot_ms))
|
|
||||||
print("{s:{c}^{n}}".format(s="Inter-token Latency", n=50, c="-"))
|
|
||||||
print("{:<40} {:<10.2f}".format("Mean ITL (ms):", metrics.mean_itl_ms))
|
print("{:<40} {:<10.2f}".format("Mean ITL (ms):", metrics.mean_itl_ms))
|
||||||
print("{:<40} {:<10.2f}".format("Median ITL (ms):", metrics.median_itl_ms))
|
print("{:<40} {:<10.2f}".format("Median ITL (ms):", metrics.median_itl_ms))
|
||||||
|
print("{:<40} {:<10.2f}".format("P95 ITL (ms):", metrics.p95_itl_ms))
|
||||||
print("{:<40} {:<10.2f}".format("P99 ITL (ms):", metrics.p99_itl_ms))
|
print("{:<40} {:<10.2f}".format("P99 ITL (ms):", metrics.p99_itl_ms))
|
||||||
|
print("{:<40} {:<10.2f}".format("Max ITL (ms):", metrics.max_itl_ms))
|
||||||
print("=" * 50)
|
print("=" * 50)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
@@ -1117,8 +1169,10 @@ async def benchmark(
|
|||||||
"mean_itl_ms": metrics.mean_itl_ms,
|
"mean_itl_ms": metrics.mean_itl_ms,
|
||||||
"median_itl_ms": metrics.median_itl_ms,
|
"median_itl_ms": metrics.median_itl_ms,
|
||||||
"std_itl_ms": metrics.std_itl_ms,
|
"std_itl_ms": metrics.std_itl_ms,
|
||||||
|
"p95_itl_ms": metrics.p95_itl_ms,
|
||||||
"p99_itl_ms": metrics.p99_itl_ms,
|
"p99_itl_ms": metrics.p99_itl_ms,
|
||||||
"concurrency": metrics.concurrency,
|
"concurrency": metrics.concurrency,
|
||||||
|
"accept_length": accept_length,
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
print(f"Error running benchmark for request rate: {request_rate}")
|
print(f"Error running benchmark for request rate: {request_rate}")
|
||||||
@@ -1151,14 +1205,6 @@ async def benchmark(
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
def parse_request_rate_range(request_rate_range):
|
|
||||||
if len(request_rate_range.split(",")) == 3:
|
|
||||||
start, stop, step = map(int, request_rate_range.split(","))
|
|
||||||
return list(range(start, stop, step))
|
|
||||||
else:
|
|
||||||
return list(map(int, request_rate_range.split(",")))
|
|
||||||
|
|
||||||
|
|
||||||
def check_chat_template(model_path):
|
def check_chat_template(model_path):
|
||||||
try:
|
try:
|
||||||
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
||||||
@@ -1168,6 +1214,12 @@ def check_chat_template(model_path):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def set_global_args(args_: argparse.Namespace):
|
||||||
|
"""Set the global args."""
|
||||||
|
global args
|
||||||
|
args = args_
|
||||||
|
|
||||||
|
|
||||||
def run_benchmark(args_: argparse.Namespace):
|
def run_benchmark(args_: argparse.Namespace):
|
||||||
global args
|
global args
|
||||||
args = args_
|
args = args_
|
||||||
@@ -1176,6 +1228,8 @@ def run_benchmark(args_: argparse.Namespace):
|
|||||||
if not hasattr(args, "max_concurrency"):
|
if not hasattr(args, "max_concurrency"):
|
||||||
args.max_concurrency = None
|
args.max_concurrency = None
|
||||||
|
|
||||||
|
print(f"benchmark_args={args}")
|
||||||
|
|
||||||
# Set global environments
|
# Set global environments
|
||||||
set_ulimit()
|
set_ulimit()
|
||||||
random.seed(args.seed)
|
random.seed(args.seed)
|
||||||
@@ -1272,12 +1326,9 @@ def run_benchmark(args_: argparse.Namespace):
|
|||||||
backend = args.backend
|
backend = args.backend
|
||||||
model_id = args.model
|
model_id = args.model
|
||||||
tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model
|
tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model
|
||||||
|
|
||||||
tokenizer = get_tokenizer(tokenizer_id)
|
tokenizer = get_tokenizer(tokenizer_id)
|
||||||
|
|
||||||
input_requests = get_dataset(args, tokenizer)
|
input_requests = get_dataset(args, tokenizer)
|
||||||
|
|
||||||
if not args.multi:
|
|
||||||
return asyncio.run(
|
return asyncio.run(
|
||||||
benchmark(
|
benchmark(
|
||||||
backend=backend,
|
backend=backend,
|
||||||
@@ -1292,27 +1343,7 @@ def run_benchmark(args_: argparse.Namespace):
|
|||||||
lora_name=args.lora_name,
|
lora_name=args.lora_name,
|
||||||
extra_request_body=extra_request_body,
|
extra_request_body=extra_request_body,
|
||||||
profile=args.profile,
|
profile=args.profile,
|
||||||
)
|
pd_seperated=args.pd_seperated,
|
||||||
)
|
|
||||||
else:
|
|
||||||
# Benchmark multiple rps. TODO: use a fixed duration to compute num_prompts
|
|
||||||
request_rates = parse_request_rate_range(args.request_rate_range)
|
|
||||||
|
|
||||||
for rate in request_rates:
|
|
||||||
asyncio.run(
|
|
||||||
benchmark(
|
|
||||||
backend=backend,
|
|
||||||
api_url=api_url,
|
|
||||||
base_url=base_url,
|
|
||||||
model_id=model_id,
|
|
||||||
tokenizer=tokenizer,
|
|
||||||
input_requests=input_requests,
|
|
||||||
request_rate=rate,
|
|
||||||
max_concurrency=args.max_concurrency,
|
|
||||||
disable_tqdm=args.disable_tqdm,
|
|
||||||
lora_name=args.lora_name,
|
|
||||||
extra_request_body=extra_request_body,
|
|
||||||
profile=args.profile,
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1428,17 +1459,6 @@ if __name__ == "__main__":
|
|||||||
"actual request rate may be lower than specified with --request-rate, "
|
"actual request rate may be lower than specified with --request-rate, "
|
||||||
"if the server is not processing requests fast enough to keep up.",
|
"if the server is not processing requests fast enough to keep up.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
|
||||||
"--multi",
|
|
||||||
action="store_true",
|
|
||||||
help="Use request rate range rather than single value.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--request-rate-range",
|
|
||||||
type=str,
|
|
||||||
default="2,34,2",
|
|
||||||
help="Range of request rates in the format start,stop,step. Default is 2,34,2. It also supports a list of request rates, requiring the parameters to not equal three.",
|
|
||||||
)
|
|
||||||
parser.add_argument("--output-file", type=str, help="Output JSONL file name.")
|
parser.add_argument("--output-file", type=str, help="Output JSONL file name.")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--disable-tqdm",
|
"--disable-tqdm",
|
||||||
@@ -1485,6 +1505,17 @@ if __name__ == "__main__":
|
|||||||
default=None,
|
default=None,
|
||||||
help="The name of LoRA adapter",
|
help="The name of LoRA adapter",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--prompt-suffix",
|
||||||
|
type=str,
|
||||||
|
default="",
|
||||||
|
help="Suffix applied to the end of all user prompts, followed by assistant prompt suffix.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--pd-seperated",
|
||||||
|
action="store_true",
|
||||||
|
help="Benchmark PD disaggregation server",
|
||||||
|
)
|
||||||
|
|
||||||
group = parser.add_argument_group("generated-shared-prefix dataset arguments")
|
group = parser.add_argument_group("generated-shared-prefix dataset arguments")
|
||||||
group.add_argument(
|
group.add_argument(
|
||||||
|
|||||||
@@ -34,11 +34,9 @@ class GlobalConfig:
|
|||||||
self.skip_special_tokens_in_output = True
|
self.skip_special_tokens_in_output = True
|
||||||
self.spaces_between_special_tokens_in_out = True
|
self.spaces_between_special_tokens_in_out = True
|
||||||
|
|
||||||
# Interpreter optimization configs
|
# Language frontend interpreter optimization configs
|
||||||
self.enable_precache_with_tracing = True
|
self.enable_precache_with_tracing = True
|
||||||
self.enable_parallel_encoding = True
|
self.enable_parallel_encoding = True
|
||||||
|
|
||||||
self.enable_flashinfer_mla = False
|
|
||||||
|
|
||||||
|
|
||||||
global_config = GlobalConfig()
|
global_config = GlobalConfig()
|
||||||
|
|||||||
@@ -329,7 +329,12 @@ class RuntimeEndpoint(BaseBackend):
|
|||||||
|
|
||||||
def compute_normalized_prompt_logprobs(input_logprobs):
|
def compute_normalized_prompt_logprobs(input_logprobs):
|
||||||
values = [x[0] for x in input_logprobs if x[0]]
|
values = [x[0] for x in input_logprobs if x[0]]
|
||||||
|
try:
|
||||||
return sum(values) / len(values)
|
return sum(values) / len(values)
|
||||||
|
except TypeError:
|
||||||
|
print(f"{input_logprobs=}", flush=True)
|
||||||
|
print(f"{input_logprobs[0]=}", flush=True)
|
||||||
|
exit(-1)
|
||||||
|
|
||||||
|
|
||||||
class Runtime:
|
class Runtime:
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ class LoadFormat(str, enum.Enum):
|
|||||||
BITSANDBYTES = "bitsandbytes"
|
BITSANDBYTES = "bitsandbytes"
|
||||||
MISTRAL = "mistral"
|
MISTRAL = "mistral"
|
||||||
LAYERED = "layered"
|
LAYERED = "layered"
|
||||||
|
JAX = "jax"
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -42,13 +43,15 @@ class LoadConfig:
|
|||||||
ignore_patterns: The list of patterns to ignore when loading the model.
|
ignore_patterns: The list of patterns to ignore when loading the model.
|
||||||
Default to "original/**/*" to avoid repeated loading of llama's
|
Default to "original/**/*" to avoid repeated loading of llama's
|
||||||
checkpoints.
|
checkpoints.
|
||||||
|
decryption_key_file: If set, decrypts the output files with a password read
|
||||||
|
from this file (after PBKDF2).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
load_format: Union[str, LoadFormat] = LoadFormat.AUTO
|
load_format: Union[str, LoadFormat] = LoadFormat.AUTO
|
||||||
download_dir: Optional[str] = None
|
download_dir: Optional[str] = None
|
||||||
model_loader_extra_config: Optional[Union[str, dict]] = field(default_factory=dict)
|
model_loader_extra_config: Optional[Union[str, dict]] = field(default_factory=dict)
|
||||||
ignore_patterns: Optional[Union[List[str], str]] = None
|
ignore_patterns: Optional[Union[List[str], str]] = None
|
||||||
|
decryption_key_file: Optional[str] = None
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
model_loader_extra_config = self.model_loader_extra_config or {}
|
model_loader_extra_config = self.model_loader_extra_config or {}
|
||||||
|
|||||||
@@ -44,6 +44,7 @@ class ModelConfig:
|
|||||||
is_embedding: Optional[bool] = None,
|
is_embedding: Optional[bool] = None,
|
||||||
dtype: str = "auto",
|
dtype: str = "auto",
|
||||||
quantization: Optional[str] = None,
|
quantization: Optional[str] = None,
|
||||||
|
override_config_file: Optional[str] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.model_path = model_path
|
self.model_path = model_path
|
||||||
self.revision = revision
|
self.revision = revision
|
||||||
@@ -51,11 +52,16 @@ class ModelConfig:
|
|||||||
|
|
||||||
# Parse args
|
# Parse args
|
||||||
self.model_override_args = json.loads(model_override_args)
|
self.model_override_args = json.loads(model_override_args)
|
||||||
|
kwargs = {}
|
||||||
|
if override_config_file and override_config_file.strip():
|
||||||
|
kwargs["_configuration_file"] = override_config_file.strip()
|
||||||
|
|
||||||
self.hf_config = get_config(
|
self.hf_config = get_config(
|
||||||
model_path,
|
model_path,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
revision=revision,
|
revision=revision,
|
||||||
model_override_args=self.model_override_args,
|
model_override_args=self.model_override_args,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
self.hf_text_config = get_hf_text_config(self.hf_config)
|
self.hf_text_config = get_hf_text_config(self.hf_config)
|
||||||
|
|
||||||
@@ -64,6 +70,9 @@ class ModelConfig:
|
|||||||
self.hf_config.architectures, is_embedding
|
self.hf_config.architectures, is_embedding
|
||||||
)
|
)
|
||||||
self.is_multimodal = is_multimodal_model(self.hf_config.architectures)
|
self.is_multimodal = is_multimodal_model(self.hf_config.architectures)
|
||||||
|
self.is_multimodal_gen = is_multimodal_gen_model(self.hf_config.architectures)
|
||||||
|
self.is_image_gen = is_image_gen_model(self.hf_config.architectures)
|
||||||
|
self.is_audio_model = is_audio_model(self.hf_config.architectures)
|
||||||
self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures)
|
self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures)
|
||||||
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
|
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
|
||||||
|
|
||||||
@@ -71,7 +80,9 @@ class ModelConfig:
|
|||||||
derived_context_len = get_context_length(self.hf_text_config)
|
derived_context_len = get_context_length(self.hf_text_config)
|
||||||
if context_length is not None:
|
if context_length is not None:
|
||||||
if context_length > derived_context_len:
|
if context_length > derived_context_len:
|
||||||
if get_bool_env_var("SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN"):
|
if get_bool_env_var(
|
||||||
|
"SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN", default="False"
|
||||||
|
):
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Warning: User-specified context_length ({context_length}) is greater than the derived context_length ({derived_context_len}). "
|
f"Warning: User-specified context_length ({context_length}) is greater than the derived context_length ({derived_context_len}). "
|
||||||
f"This may lead to incorrect model outputs or CUDA errors."
|
f"This may lead to incorrect model outputs or CUDA errors."
|
||||||
@@ -416,6 +427,8 @@ def is_multimodal_model(model_architectures: List[str]):
|
|||||||
or "LlavaQwenForCausalLM" in model_architectures
|
or "LlavaQwenForCausalLM" in model_architectures
|
||||||
or "LlavaMistralForCausalLM" in model_architectures
|
or "LlavaMistralForCausalLM" in model_architectures
|
||||||
or "LlavaVidForCausalLM" in model_architectures
|
or "LlavaVidForCausalLM" in model_architectures
|
||||||
|
or "Grok1VForCausalLM" in model_architectures
|
||||||
|
or "Grok1AForCausalLM" in model_architectures
|
||||||
or "MllamaForConditionalGeneration" in model_architectures
|
or "MllamaForConditionalGeneration" in model_architectures
|
||||||
or "Qwen2VLForConditionalGeneration" in model_architectures
|
or "Qwen2VLForConditionalGeneration" in model_architectures
|
||||||
or "Qwen2_5_VLForConditionalGeneration" in model_architectures
|
or "Qwen2_5_VLForConditionalGeneration" in model_architectures
|
||||||
@@ -426,6 +439,18 @@ def is_multimodal_model(model_architectures: List[str]):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def is_multimodal_gen_model(model_architectures: List[str]):
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def is_image_gen_model(model_architectures: List[str]):
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def is_audio_model(model_architectures: List[str]):
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
def is_encoder_decoder_model(model_architectures: List[str]):
|
def is_encoder_decoder_model(model_architectures: List[str]):
|
||||||
return "MllamaForConditionalGeneration" in model_architectures
|
return "MllamaForConditionalGeneration" in model_architectures
|
||||||
|
|
||||||
|
|||||||
@@ -15,7 +15,7 @@
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from typing import List, Tuple
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from xgrammar import (
|
from xgrammar import (
|
||||||
@@ -42,11 +42,16 @@ MAX_ROLLBACK_TOKENS = 200
|
|||||||
class XGrammarGrammar(BaseGrammarObject):
|
class XGrammarGrammar(BaseGrammarObject):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, matcher: GrammarMatcher, vocab_size: int, ctx: CompiledGrammar
|
self,
|
||||||
|
matcher: GrammarMatcher,
|
||||||
|
vocab_size: int,
|
||||||
|
ctx: CompiledGrammar,
|
||||||
|
override_stop_tokens: Optional[Union[List[int], int]],
|
||||||
) -> None:
|
) -> None:
|
||||||
self.matcher = matcher
|
self.matcher = matcher
|
||||||
self.vocab_size = vocab_size
|
self.vocab_size = vocab_size
|
||||||
self.ctx = ctx
|
self.ctx = ctx
|
||||||
|
self.override_stop_tokens = override_stop_tokens
|
||||||
self.finished = False
|
self.finished = False
|
||||||
|
|
||||||
def accept_token(self, token: int):
|
def accept_token(self, token: int):
|
||||||
@@ -96,8 +101,14 @@ class XGrammarGrammar(BaseGrammarObject):
|
|||||||
apply_token_bitmask_inplace(logits, vocab_mask)
|
apply_token_bitmask_inplace(logits, vocab_mask)
|
||||||
|
|
||||||
def copy(self):
|
def copy(self):
|
||||||
matcher = GrammarMatcher(self.ctx, max_rollback_tokens=MAX_ROLLBACK_TOKENS)
|
matcher = GrammarMatcher(
|
||||||
return XGrammarGrammar(matcher, self.vocab_size, self.ctx)
|
self.ctx,
|
||||||
|
max_rollback_tokens=MAX_ROLLBACK_TOKENS,
|
||||||
|
override_stop_tokens=self.override_stop_tokens,
|
||||||
|
)
|
||||||
|
return XGrammarGrammar(
|
||||||
|
matcher, self.vocab_size, self.ctx, self.override_stop_tokens
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class XGrammarGrammarBackend(BaseGrammarBackend):
|
class XGrammarGrammarBackend(BaseGrammarBackend):
|
||||||
@@ -111,8 +122,11 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
|
|||||||
tokenizer_info = TokenizerInfo.from_huggingface(
|
tokenizer_info = TokenizerInfo.from_huggingface(
|
||||||
tokenizer, vocab_size=vocab_size
|
tokenizer, vocab_size=vocab_size
|
||||||
)
|
)
|
||||||
|
override_stop_tokens = None
|
||||||
|
|
||||||
self.grammar_compiler = GrammarCompiler(tokenizer_info=tokenizer_info)
|
self.grammar_compiler = GrammarCompiler(tokenizer_info=tokenizer_info)
|
||||||
self.vocab_size = vocab_size
|
self.vocab_size = vocab_size
|
||||||
|
self.override_stop_tokens = override_stop_tokens
|
||||||
|
|
||||||
def init_value_impl(self, key: Tuple[str, str]) -> XGrammarGrammar:
|
def init_value_impl(self, key: Tuple[str, str]) -> XGrammarGrammar:
|
||||||
|
|
||||||
@@ -161,7 +175,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
|
|||||||
raise ValueError(f"Invalid key_type: {key_type}")
|
raise ValueError(f"Invalid key_type: {key_type}")
|
||||||
|
|
||||||
matcher = GrammarMatcher(ctx, max_rollback_tokens=MAX_ROLLBACK_TOKENS)
|
matcher = GrammarMatcher(ctx, max_rollback_tokens=MAX_ROLLBACK_TOKENS)
|
||||||
return XGrammarGrammar(matcher, self.vocab_size, ctx)
|
return XGrammarGrammar(matcher, self.vocab_size, ctx, self.override_stop_tokens)
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
if self.grammar_compiler:
|
if self.grammar_compiler:
|
||||||
|
|||||||
@@ -121,6 +121,7 @@ class Engine:
|
|||||||
return_logprob: Optional[Union[List[bool], bool]] = False,
|
return_logprob: Optional[Union[List[bool], bool]] = False,
|
||||||
logprob_start_len: Optional[Union[List[int], int]] = None,
|
logprob_start_len: Optional[Union[List[int], int]] = None,
|
||||||
top_logprobs_num: Optional[Union[List[int], int]] = None,
|
top_logprobs_num: Optional[Union[List[int], int]] = None,
|
||||||
|
token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None,
|
||||||
lora_path: Optional[List[Optional[str]]] = None,
|
lora_path: Optional[List[Optional[str]]] = None,
|
||||||
custom_logit_processor: Optional[Union[List[str], str]] = None,
|
custom_logit_processor: Optional[Union[List[str], str]] = None,
|
||||||
return_hidden_states: bool = False,
|
return_hidden_states: bool = False,
|
||||||
@@ -142,6 +143,7 @@ class Engine:
|
|||||||
return_logprob=return_logprob,
|
return_logprob=return_logprob,
|
||||||
logprob_start_len=logprob_start_len,
|
logprob_start_len=logprob_start_len,
|
||||||
top_logprobs_num=top_logprobs_num,
|
top_logprobs_num=top_logprobs_num,
|
||||||
|
token_ids_logprob=token_ids_logprob,
|
||||||
lora_path=lora_path,
|
lora_path=lora_path,
|
||||||
modalities=modalities_list,
|
modalities=modalities_list,
|
||||||
custom_logit_processor=custom_logit_processor,
|
custom_logit_processor=custom_logit_processor,
|
||||||
@@ -179,6 +181,7 @@ class Engine:
|
|||||||
return_logprob: Optional[Union[List[bool], bool]] = False,
|
return_logprob: Optional[Union[List[bool], bool]] = False,
|
||||||
logprob_start_len: Optional[Union[List[int], int]] = None,
|
logprob_start_len: Optional[Union[List[int], int]] = None,
|
||||||
top_logprobs_num: Optional[Union[List[int], int]] = None,
|
top_logprobs_num: Optional[Union[List[int], int]] = None,
|
||||||
|
token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None,
|
||||||
lora_path: Optional[List[Optional[str]]] = None,
|
lora_path: Optional[List[Optional[str]]] = None,
|
||||||
custom_logit_processor: Optional[Union[List[str], str]] = None,
|
custom_logit_processor: Optional[Union[List[str], str]] = None,
|
||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
@@ -195,6 +198,7 @@ class Engine:
|
|||||||
return_logprob=return_logprob,
|
return_logprob=return_logprob,
|
||||||
logprob_start_len=logprob_start_len,
|
logprob_start_len=logprob_start_len,
|
||||||
top_logprobs_num=top_logprobs_num,
|
top_logprobs_num=top_logprobs_num,
|
||||||
|
token_ids_logprob=token_ids_logprob,
|
||||||
lora_path=lora_path,
|
lora_path=lora_path,
|
||||||
stream=stream,
|
stream=stream,
|
||||||
custom_logit_processor=custom_logit_processor,
|
custom_logit_processor=custom_logit_processor,
|
||||||
@@ -226,15 +230,22 @@ class Engine:
|
|||||||
kill_process_tree(os.getpid(), include_parent=False)
|
kill_process_tree(os.getpid(), include_parent=False)
|
||||||
|
|
||||||
def start_profile(self):
|
def start_profile(self):
|
||||||
self.tokenizer_manager.start_profile()
|
loop = asyncio.get_event_loop()
|
||||||
|
loop.run_until_complete(self.tokenizer_manager.start_profile())
|
||||||
|
|
||||||
def stop_profile(self):
|
def stop_profile(self):
|
||||||
self.tokenizer_manager.stop_profile()
|
self.tokenizer_manager.stop_profile()
|
||||||
|
|
||||||
def get_server_info(self):
|
def get_server_info(self):
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
internal_states = loop.run_until_complete(
|
||||||
|
self.tokenizer_manager.get_internal_state()
|
||||||
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
**dataclasses.asdict(self.tokenizer_manager.server_args), # server args
|
**dataclasses.asdict(self.tokenizer_manager.server_args),
|
||||||
**self.scheduler_info,
|
**self.scheduler_info,
|
||||||
|
**internal_states,
|
||||||
"version": __version__,
|
"version": __version__,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -323,6 +334,7 @@ def _set_envs_and_config(server_args: ServerArgs):
|
|||||||
os.environ["NCCL_NVLS_ENABLE"] = str(int(server_args.enable_nccl_nvls))
|
os.environ["NCCL_NVLS_ENABLE"] = str(int(server_args.enable_nccl_nvls))
|
||||||
os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
|
os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
|
||||||
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4"
|
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4"
|
||||||
|
os.environ["CUDA_MODULE_LOADING"] = "AUTO"
|
||||||
|
|
||||||
# Set prometheus env vars
|
# Set prometheus env vars
|
||||||
if server_args.enable_metrics:
|
if server_args.enable_metrics:
|
||||||
@@ -346,12 +358,23 @@ def _set_envs_and_config(server_args: ServerArgs):
|
|||||||
"at https://docs.flashinfer.ai/installation.html.",
|
"at https://docs.flashinfer.ai/installation.html.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def sigchld_handler(signum, frame):
|
||||||
|
pid, exitcode = os.waitpid(0, os.WNOHANG)
|
||||||
|
if exitcode != 0:
|
||||||
|
logger.warning(
|
||||||
|
"Child process unexpectedly failed with an exit code %d. pid=%d",
|
||||||
|
exitcode,
|
||||||
|
pid,
|
||||||
|
)
|
||||||
|
|
||||||
|
signal.signal(signal.SIGCHLD, sigchld_handler)
|
||||||
|
|
||||||
# Register the signal handler.
|
# Register the signal handler.
|
||||||
# The child processes will send SIGQUIT to this process when any error happens
|
# The child processes will send SIGQUIT to this process when any error happens
|
||||||
# This process then clean up the whole process tree
|
# This process then clean up the whole process tree
|
||||||
def sigquit_handler(signum, frame):
|
def sigquit_handler(signum, frame):
|
||||||
logger.error(
|
logger.error(
|
||||||
"Received sigquit from a child proces. It usually means the child failed."
|
"Received sigquit from a child process. It usually means the child failed."
|
||||||
)
|
)
|
||||||
kill_process_tree(os.getpid())
|
kill_process_tree(os.getpid())
|
||||||
|
|
||||||
|
|||||||
@@ -25,11 +25,14 @@ import os
|
|||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from typing import AsyncIterator, Dict, Optional
|
from typing import AsyncIterator, Callable, Dict, Optional
|
||||||
|
|
||||||
# Fix a bug of Python threading
|
# Fix a bug of Python threading
|
||||||
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
|
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
|
||||||
|
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import orjson
|
import orjson
|
||||||
import requests
|
import requests
|
||||||
import uvicorn
|
import uvicorn
|
||||||
@@ -49,8 +52,10 @@ from sglang.srt.managers.io_struct import (
|
|||||||
InitWeightsUpdateGroupReqInput,
|
InitWeightsUpdateGroupReqInput,
|
||||||
OpenSessionReqInput,
|
OpenSessionReqInput,
|
||||||
ParseFunctionCallReq,
|
ParseFunctionCallReq,
|
||||||
|
ProfileReqInput,
|
||||||
ReleaseMemoryOccupationReqInput,
|
ReleaseMemoryOccupationReqInput,
|
||||||
ResumeMemoryOccupationReqInput,
|
ResumeMemoryOccupationReqInput,
|
||||||
|
SetInternalStateReq,
|
||||||
UpdateWeightFromDiskReqInput,
|
UpdateWeightFromDiskReqInput,
|
||||||
UpdateWeightsFromDistributedReqInput,
|
UpdateWeightsFromDistributedReqInput,
|
||||||
VertexGenerateReqInput,
|
VertexGenerateReqInput,
|
||||||
@@ -78,22 +83,13 @@ from sglang.srt.utils import (
|
|||||||
kill_process_tree,
|
kill_process_tree,
|
||||||
set_uvicorn_logging_configs,
|
set_uvicorn_logging_configs,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.warmup import execute_warmups
|
||||||
from sglang.utils import get_exception_traceback
|
from sglang.utils import get_exception_traceback
|
||||||
from sglang.version import __version__
|
from sglang.version import __version__
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
||||||
|
|
||||||
# Fast API
|
|
||||||
app = FastAPI()
|
|
||||||
app.add_middleware(
|
|
||||||
CORSMiddleware,
|
|
||||||
allow_origins=["*"],
|
|
||||||
allow_credentials=True,
|
|
||||||
allow_methods=["*"],
|
|
||||||
allow_headers=["*"],
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# Store global states
|
# Store global states
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
@@ -110,6 +106,34 @@ def set_global_state(global_state: _GlobalState):
|
|||||||
_global_state = global_state
|
_global_state = global_state
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(fast_api_app: FastAPI):
|
||||||
|
server_args: ServerArgs = fast_api_app.server_args
|
||||||
|
if server_args.warmups is not None:
|
||||||
|
await execute_warmups(
|
||||||
|
server_args.warmups.split(","), _global_state.tokenizer_manager
|
||||||
|
)
|
||||||
|
logger.info("Warmup ended")
|
||||||
|
|
||||||
|
warmup_thread = getattr(fast_api_app, "warmup_thread", None)
|
||||||
|
if warmup_thread is not None:
|
||||||
|
warmup_thread.start()
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
|
# Fast API
|
||||||
|
app = FastAPI(lifespan=lifespan)
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=["*"],
|
||||||
|
allow_credentials=True,
|
||||||
|
allow_methods=["*"],
|
||||||
|
allow_headers=["*"],
|
||||||
|
)
|
||||||
|
|
||||||
|
HEALTH_CHECK_TIMEOUT = int(os.getenv("SGLANG_HEALTH_CHECK_TIMEOUT", 20))
|
||||||
|
|
||||||
|
|
||||||
##### Native API endpoints #####
|
##### Native API endpoints #####
|
||||||
|
|
||||||
|
|
||||||
@@ -123,23 +147,47 @@ async def health() -> Response:
|
|||||||
async def health_generate(request: Request) -> Response:
|
async def health_generate(request: Request) -> Response:
|
||||||
"""Check the health of the inference server by generating one token."""
|
"""Check the health of the inference server by generating one token."""
|
||||||
|
|
||||||
sampling_params = {"max_new_tokens": 1, "temperature": 0.7}
|
sampling_params = {"max_new_tokens": 1, "temperature": 0.0}
|
||||||
|
rid = f"HEALTH_CHECK_{time.time()}"
|
||||||
|
|
||||||
if _global_state.tokenizer_manager.is_generation:
|
if _global_state.tokenizer_manager.is_image_gen:
|
||||||
|
raise NotImplementedError()
|
||||||
|
elif _global_state.tokenizer_manager.is_generation:
|
||||||
gri = GenerateReqInput(
|
gri = GenerateReqInput(
|
||||||
input_ids=[0], sampling_params=sampling_params, log_metrics=False
|
rid=rid,
|
||||||
|
input_ids=[0],
|
||||||
|
sampling_params=sampling_params,
|
||||||
|
log_metrics=False,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
gri = EmbeddingReqInput(
|
gri = EmbeddingReqInput(
|
||||||
input_ids=[0], sampling_params=sampling_params, log_metrics=False
|
rid=rid, input_ids=[0], sampling_params=sampling_params, log_metrics=False
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
async def gen():
|
||||||
async for _ in _global_state.tokenizer_manager.generate_request(gri, request):
|
async for _ in _global_state.tokenizer_manager.generate_request(gri, request):
|
||||||
break
|
break
|
||||||
|
|
||||||
|
tic = time.time()
|
||||||
|
task = asyncio.create_task(gen())
|
||||||
|
while time.time() < tic + HEALTH_CHECK_TIMEOUT:
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
if _global_state.tokenizer_manager.last_receive_tstamp > tic:
|
||||||
|
task.cancel()
|
||||||
|
_global_state.tokenizer_manager.rid_to_state.pop(rid, None)
|
||||||
return Response(status_code=200)
|
return Response(status_code=200)
|
||||||
except Exception as e:
|
|
||||||
logger.exception(e)
|
task.cancel()
|
||||||
|
tic_time = time.strftime("%H:%M:%S", time.localtime(tic))
|
||||||
|
last_receive_time = time.strftime(
|
||||||
|
"%H:%M:%S", time.localtime(_global_state.tokenizer_manager.last_receive_tstamp)
|
||||||
|
)
|
||||||
|
logger.error(
|
||||||
|
f"Health check failed. Server couldn't get a response from detokenizer for last "
|
||||||
|
f"{HEALTH_CHECK_TIMEOUT} seconds. tic start time: {tic_time}. "
|
||||||
|
f"last_heartbeat time: {last_receive_time}"
|
||||||
|
)
|
||||||
|
_global_state.tokenizer_manager.rid_to_state.pop(rid, None)
|
||||||
return Response(status_code=503)
|
return Response(status_code=503)
|
||||||
|
|
||||||
|
|
||||||
@@ -156,13 +204,21 @@ async def get_model_info():
|
|||||||
|
|
||||||
@app.get("/get_server_info")
|
@app.get("/get_server_info")
|
||||||
async def get_server_info():
|
async def get_server_info():
|
||||||
|
internal_states = await _global_state.tokenizer_manager.get_internal_state()
|
||||||
return {
|
return {
|
||||||
**dataclasses.asdict(_global_state.tokenizer_manager.server_args),
|
**dataclasses.asdict(_global_state.tokenizer_manager.server_args),
|
||||||
**_global_state.scheduler_info,
|
**_global_state.scheduler_info,
|
||||||
|
**internal_states,
|
||||||
"version": __version__,
|
"version": __version__,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@app.api_route("/set_internal_state", methods=["POST", "PUT"])
|
||||||
|
async def set_internal_state(obj: SetInternalStateReq, request: Request):
|
||||||
|
res = await _global_state.tokenizer_manager.set_internal_state(obj)
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
# fastapi implicitly converts json in the request to obj (dataclass)
|
# fastapi implicitly converts json in the request to obj (dataclass)
|
||||||
@app.api_route("/generate", methods=["POST", "PUT"])
|
@app.api_route("/generate", methods=["POST", "PUT"])
|
||||||
async def generate_request(obj: GenerateReqInput, request: Request):
|
async def generate_request(obj: GenerateReqInput, request: Request):
|
||||||
@@ -179,6 +235,7 @@ async def generate_request(obj: GenerateReqInput, request: Request):
|
|||||||
) + b"\n\n"
|
) + b"\n\n"
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
out = {"error": {"message": str(e)}}
|
out = {"error": {"message": str(e)}}
|
||||||
|
logger.error(f"Error: {e}")
|
||||||
yield b"data: " + orjson.dumps(
|
yield b"data: " + orjson.dumps(
|
||||||
out, option=orjson.OPT_NON_STR_KEYS
|
out, option=orjson.OPT_NON_STR_KEYS
|
||||||
) + b"\n\n"
|
) + b"\n\n"
|
||||||
@@ -236,9 +293,14 @@ async def flush_cache():
|
|||||||
|
|
||||||
|
|
||||||
@app.api_route("/start_profile", methods=["GET", "POST"])
|
@app.api_route("/start_profile", methods=["GET", "POST"])
|
||||||
async def start_profile_async():
|
async def start_profile_async(obj: Optional[ProfileReqInput] = None):
|
||||||
"""Start profiling."""
|
"""Start profiling."""
|
||||||
_global_state.tokenizer_manager.start_profile()
|
if obj is None:
|
||||||
|
obj = ProfileReqInput()
|
||||||
|
|
||||||
|
await _global_state.tokenizer_manager.start_profile(
|
||||||
|
obj.output_dir, obj.num_steps, obj.activities
|
||||||
|
)
|
||||||
return Response(
|
return Response(
|
||||||
content="Start profiling.\n",
|
content="Start profiling.\n",
|
||||||
status_code=200,
|
status_code=200,
|
||||||
@@ -257,11 +319,15 @@ async def stop_profile_async():
|
|||||||
|
|
||||||
@app.post("/update_weights_from_disk")
|
@app.post("/update_weights_from_disk")
|
||||||
async def update_weights_from_disk(obj: UpdateWeightFromDiskReqInput, request: Request):
|
async def update_weights_from_disk(obj: UpdateWeightFromDiskReqInput, request: Request):
|
||||||
"""Update the weights from disk in-place without re-launching the server."""
|
"""Update the weights from disk inplace without re-launching the server."""
|
||||||
success, message = await _global_state.tokenizer_manager.update_weights_from_disk(
|
success, message, num_paused_requests = (
|
||||||
obj, request
|
await _global_state.tokenizer_manager.update_weights_from_disk(obj, request)
|
||||||
)
|
)
|
||||||
content = {"success": success, "message": message}
|
content = {
|
||||||
|
"success": success,
|
||||||
|
"message": message,
|
||||||
|
"num_paused_requests": num_paused_requests,
|
||||||
|
}
|
||||||
if success:
|
if success:
|
||||||
return ORJSONResponse(
|
return ORJSONResponse(
|
||||||
content,
|
content,
|
||||||
@@ -323,7 +389,7 @@ async def get_weights_by_name(obj: GetWeightsByNameReqInput, request: Request):
|
|||||||
async def release_memory_occupation(
|
async def release_memory_occupation(
|
||||||
obj: ReleaseMemoryOccupationReqInput, request: Request
|
obj: ReleaseMemoryOccupationReqInput, request: Request
|
||||||
):
|
):
|
||||||
"""Release GPU occupation temporarily"""
|
"""Release GPU memory occupation temporarily."""
|
||||||
try:
|
try:
|
||||||
await _global_state.tokenizer_manager.release_memory_occupation(obj, request)
|
await _global_state.tokenizer_manager.release_memory_occupation(obj, request)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -334,7 +400,7 @@ async def release_memory_occupation(
|
|||||||
async def resume_memory_occupation(
|
async def resume_memory_occupation(
|
||||||
obj: ResumeMemoryOccupationReqInput, request: Request
|
obj: ResumeMemoryOccupationReqInput, request: Request
|
||||||
):
|
):
|
||||||
"""Resume GPU occupation"""
|
"""Resume GPU memory occupation."""
|
||||||
try:
|
try:
|
||||||
await _global_state.tokenizer_manager.resume_memory_occupation(obj, request)
|
await _global_state.tokenizer_manager.resume_memory_occupation(obj, request)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -357,7 +423,7 @@ async def open_session(obj: OpenSessionReqInput, request: Request):
|
|||||||
|
|
||||||
@app.api_route("/close_session", methods=["GET", "POST"])
|
@app.api_route("/close_session", methods=["GET", "POST"])
|
||||||
async def close_session(obj: CloseSessionReqInput, request: Request):
|
async def close_session(obj: CloseSessionReqInput, request: Request):
|
||||||
"""Close the session"""
|
"""Close the session."""
|
||||||
try:
|
try:
|
||||||
await _global_state.tokenizer_manager.close_session(obj, request)
|
await _global_state.tokenizer_manager.close_session(obj, request)
|
||||||
return Response(status_code=200)
|
return Response(status_code=200)
|
||||||
@@ -367,7 +433,7 @@ async def close_session(obj: CloseSessionReqInput, request: Request):
|
|||||||
|
|
||||||
@app.api_route("/configure_logging", methods=["GET", "POST"])
|
@app.api_route("/configure_logging", methods=["GET", "POST"])
|
||||||
async def configure_logging(obj: ConfigureLoggingReq, request: Request):
|
async def configure_logging(obj: ConfigureLoggingReq, request: Request):
|
||||||
"""Close the session"""
|
"""Configure the request logging options."""
|
||||||
_global_state.tokenizer_manager.configure_logging(obj)
|
_global_state.tokenizer_manager.configure_logging(obj)
|
||||||
return Response(status_code=200)
|
return Response(status_code=200)
|
||||||
|
|
||||||
@@ -511,6 +577,7 @@ def _create_error_response(e):
|
|||||||
def launch_server(
|
def launch_server(
|
||||||
server_args: ServerArgs,
|
server_args: ServerArgs,
|
||||||
pipe_finish_writer: Optional[multiprocessing.connection.Connection] = None,
|
pipe_finish_writer: Optional[multiprocessing.connection.Connection] = None,
|
||||||
|
launch_callback: Optional[Callable[[], None]] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Launch SRT (SGLang Runtime) Server.
|
Launch SRT (SGLang Runtime) Server.
|
||||||
@@ -544,21 +611,23 @@ def launch_server(
|
|||||||
add_prometheus_middleware(app)
|
add_prometheus_middleware(app)
|
||||||
enable_func_timer()
|
enable_func_timer()
|
||||||
|
|
||||||
# Send a warmup request
|
# Send a warmup request - we will create the thread launch it
|
||||||
t = threading.Thread(
|
# in the lifespan after all other warmups have fired.
|
||||||
|
warmup_thread = threading.Thread(
|
||||||
target=_wait_and_warmup,
|
target=_wait_and_warmup,
|
||||||
args=(
|
args=(
|
||||||
server_args,
|
server_args,
|
||||||
pipe_finish_writer,
|
pipe_finish_writer,
|
||||||
_global_state.tokenizer_manager.image_token_id,
|
_global_state.tokenizer_manager.image_token_id,
|
||||||
|
launch_callback,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
t.start()
|
app.warmup_thread = warmup_thread
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Update logging configs
|
# Update logging configs
|
||||||
set_uvicorn_logging_configs()
|
set_uvicorn_logging_configs()
|
||||||
|
app.server_args = server_args
|
||||||
# Listen for HTTP requests
|
# Listen for HTTP requests
|
||||||
uvicorn.run(
|
uvicorn.run(
|
||||||
app,
|
app,
|
||||||
@@ -569,10 +638,15 @@ def launch_server(
|
|||||||
loop="uvloop",
|
loop="uvloop",
|
||||||
)
|
)
|
||||||
finally:
|
finally:
|
||||||
t.join()
|
warmup_thread.join()
|
||||||
|
|
||||||
|
|
||||||
def _wait_and_warmup(server_args, pipe_finish_writer, image_token_text):
|
def _wait_and_warmup(
|
||||||
|
server_args: ServerArgs,
|
||||||
|
pipe_finish_writer: Optional[multiprocessing.connection.Connection],
|
||||||
|
image_token_text: str,
|
||||||
|
launch_callback: Optional[Callable[[], None]] = None,
|
||||||
|
):
|
||||||
headers = {}
|
headers = {}
|
||||||
url = server_args.url()
|
url = server_args.url()
|
||||||
if server_args.api_key:
|
if server_args.api_key:
|
||||||
@@ -614,8 +688,16 @@ def _wait_and_warmup(server_args, pipe_finish_writer, image_token_text):
|
|||||||
else:
|
else:
|
||||||
json_data["text"] = "The capital city of France is"
|
json_data["text"] = "The capital city of France is"
|
||||||
|
|
||||||
|
# Debug dumping
|
||||||
|
if server_args.debug_tensor_dump_input_file:
|
||||||
|
json_data.pop("text", None)
|
||||||
|
json_data["input_ids"] = np.load(
|
||||||
|
server_args.debug_tensor_dump_input_file
|
||||||
|
).tolist()
|
||||||
|
json_data["sampling_params"]["max_new_tokens"] = 0
|
||||||
|
|
||||||
try:
|
try:
|
||||||
for _ in range(server_args.dp_size):
|
for i in range(server_args.dp_size):
|
||||||
res = requests.post(
|
res = requests.post(
|
||||||
url + request_name,
|
url + request_name,
|
||||||
json=json_data,
|
json=json_data,
|
||||||
@@ -640,3 +722,9 @@ def _wait_and_warmup(server_args, pipe_finish_writer, image_token_text):
|
|||||||
|
|
||||||
if server_args.delete_ckpt_after_loading:
|
if server_args.delete_ckpt_after_loading:
|
||||||
delete_directory(server_args.model_path)
|
delete_directory(server_args.model_path)
|
||||||
|
|
||||||
|
if server_args.debug_tensor_dump_input_file:
|
||||||
|
kill_process_tree(os.getpid())
|
||||||
|
|
||||||
|
if launch_callback is not None:
|
||||||
|
launch_callback()
|
||||||
|
|||||||
@@ -60,6 +60,7 @@ class VerlEngine:
|
|||||||
return_logprob: Optional[Union[List[bool], bool]] = False,
|
return_logprob: Optional[Union[List[bool], bool]] = False,
|
||||||
logprob_start_len: Optional[Union[List[int], int]] = None,
|
logprob_start_len: Optional[Union[List[int], int]] = None,
|
||||||
top_logprobs_num: Optional[Union[List[int], int]] = None,
|
top_logprobs_num: Optional[Union[List[int], int]] = None,
|
||||||
|
token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None,
|
||||||
lora_path: Optional[List[Optional[str]]] = None,
|
lora_path: Optional[List[Optional[str]]] = None,
|
||||||
custom_logit_processor: Optional[Union[List[str], str]] = None,
|
custom_logit_processor: Optional[Union[List[str], str]] = None,
|
||||||
) -> Dict:
|
) -> Dict:
|
||||||
@@ -76,6 +77,7 @@ class VerlEngine:
|
|||||||
return_logprob=return_logprob,
|
return_logprob=return_logprob,
|
||||||
logprob_start_len=logprob_start_len,
|
logprob_start_len=logprob_start_len,
|
||||||
top_logprobs_num=top_logprobs_num,
|
top_logprobs_num=top_logprobs_num,
|
||||||
|
token_ids_logprob=token_ids_logprob,
|
||||||
lora_path=lora_path,
|
lora_path=lora_path,
|
||||||
custom_logit_processor=custom_logit_processor,
|
custom_logit_processor=custom_logit_processor,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,14 +1,14 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import TYPE_CHECKING, Optional
|
from typing import TYPE_CHECKING, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||||
from sglang.srt.speculative.spec_info import SpecInfo
|
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
||||||
|
|
||||||
|
|
||||||
class AttentionBackend(ABC):
|
class AttentionBackend(ABC):
|
||||||
@@ -31,7 +31,7 @@ class AttentionBackend(ABC):
|
|||||||
seq_lens: torch.Tensor,
|
seq_lens: torch.Tensor,
|
||||||
encoder_lens: Optional[torch.Tensor],
|
encoder_lens: Optional[torch.Tensor],
|
||||||
forward_mode: ForwardMode,
|
forward_mode: ForwardMode,
|
||||||
spec_info: Optional[SpecInfo],
|
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||||
):
|
):
|
||||||
"""Init the metadata for a forward pass for capturing a cuda graph."""
|
"""Init the metadata for a forward pass for capturing a cuda graph."""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
@@ -44,7 +44,7 @@ class AttentionBackend(ABC):
|
|||||||
seq_lens_sum: int,
|
seq_lens_sum: int,
|
||||||
encoder_lens: Optional[torch.Tensor],
|
encoder_lens: Optional[torch.Tensor],
|
||||||
forward_mode: ForwardMode,
|
forward_mode: ForwardMode,
|
||||||
spec_info: Optional[SpecInfo],
|
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||||
):
|
):
|
||||||
"""Init the metadata for a forward pass for replying a cuda graph."""
|
"""Init the metadata for a forward pass for replying a cuda graph."""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
@@ -64,7 +64,14 @@ class AttentionBackend(ABC):
|
|||||||
):
|
):
|
||||||
"""Run forward on an attention layer."""
|
"""Run forward on an attention layer."""
|
||||||
if forward_batch.forward_mode.is_decode():
|
if forward_batch.forward_mode.is_decode():
|
||||||
return self.forward_decode(q, k, v, layer, forward_batch, save_kv_cache)
|
return self.forward_decode(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
layer,
|
||||||
|
forward_batch,
|
||||||
|
save_kv_cache=save_kv_cache,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
return self.forward_extend(
|
return self.forward_extend(
|
||||||
q,
|
q,
|
||||||
@@ -72,7 +79,7 @@ class AttentionBackend(ABC):
|
|||||||
v,
|
v,
|
||||||
layer,
|
layer,
|
||||||
forward_batch,
|
forward_batch,
|
||||||
save_kv_cache,
|
save_kv_cache=save_kv_cache,
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward_decode(
|
def forward_decode(
|
||||||
|
|||||||
@@ -68,6 +68,7 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
model_runner: ModelRunner,
|
model_runner: ModelRunner,
|
||||||
skip_prefill: bool = False,
|
skip_prefill: bool = False,
|
||||||
kv_indptr_buf: Optional[torch.Tensor] = None,
|
kv_indptr_buf: Optional[torch.Tensor] = None,
|
||||||
|
kv_last_page_len_buf: Optional[torch.Tensor] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@@ -125,9 +126,14 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
assert self.num_wrappers == 1
|
assert self.num_wrappers == 1
|
||||||
self.kv_indptr = [kv_indptr_buf]
|
self.kv_indptr = [kv_indptr_buf]
|
||||||
|
|
||||||
|
if kv_last_page_len_buf is None:
|
||||||
self.kv_last_page_len = torch.ones(
|
self.kv_last_page_len = torch.ones(
|
||||||
(max_bs,), dtype=torch.int32, device=model_runner.device
|
(max_bs,), dtype=torch.int32, device=model_runner.device
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
assert self.num_wrappers == 1
|
||||||
|
self.kv_last_page_len = kv_last_page_len_buf
|
||||||
|
|
||||||
self.qo_indptr = [
|
self.qo_indptr = [
|
||||||
torch.zeros((max_bs + 1,), dtype=torch.int32, device=model_runner.device)
|
torch.zeros((max_bs + 1,), dtype=torch.int32, device=model_runner.device)
|
||||||
for _ in range(self.num_wrappers)
|
for _ in range(self.num_wrappers)
|
||||||
@@ -922,6 +928,9 @@ class FlashInferMultiStepDraftBackend:
|
|||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
device=model_runner.device,
|
device=model_runner.device,
|
||||||
)
|
)
|
||||||
|
self.kv_last_page_len = torch.ones(
|
||||||
|
(max_bs,), dtype=torch.int32, device=model_runner.device
|
||||||
|
)
|
||||||
self.attn_backends = []
|
self.attn_backends = []
|
||||||
for i in range(self.speculative_num_steps):
|
for i in range(self.speculative_num_steps):
|
||||||
self.attn_backends.append(
|
self.attn_backends.append(
|
||||||
@@ -929,6 +938,7 @@ class FlashInferMultiStepDraftBackend:
|
|||||||
model_runner,
|
model_runner,
|
||||||
skip_prefill=True,
|
skip_prefill=True,
|
||||||
kv_indptr_buf=self.kv_indptr[i],
|
kv_indptr_buf=self.kv_indptr[i],
|
||||||
|
kv_last_page_len_buf=self.kv_last_page_len,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.max_context_len = self.attn_backends[0].max_context_len
|
self.max_context_len = self.attn_backends[0].max_context_len
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import TYPE_CHECKING, Optional
|
from typing import TYPE_CHECKING, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import triton
|
import triton
|
||||||
@@ -15,7 +15,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMo
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||||
from sglang.srt.speculative.spec_info import SpecInfo
|
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
||||||
|
|
||||||
|
|
||||||
class TritonAttnBackend(AttentionBackend):
|
class TritonAttnBackend(AttentionBackend):
|
||||||
@@ -232,7 +232,7 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
seq_lens: torch.Tensor,
|
seq_lens: torch.Tensor,
|
||||||
encoder_lens: Optional[torch.Tensor],
|
encoder_lens: Optional[torch.Tensor],
|
||||||
forward_mode: ForwardMode,
|
forward_mode: ForwardMode,
|
||||||
spec_info: Optional[SpecInfo],
|
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||||
):
|
):
|
||||||
assert encoder_lens is None, "Not supported"
|
assert encoder_lens is None, "Not supported"
|
||||||
|
|
||||||
@@ -310,7 +310,7 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
seq_lens_sum: int,
|
seq_lens_sum: int,
|
||||||
encoder_lens: Optional[torch.Tensor],
|
encoder_lens: Optional[torch.Tensor],
|
||||||
forward_mode: ForwardMode,
|
forward_mode: ForwardMode,
|
||||||
spec_info: Optional[SpecInfo],
|
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||||
):
|
):
|
||||||
# NOTE: encoder_lens expected to be zeros or None
|
# NOTE: encoder_lens expected to be zeros or None
|
||||||
if forward_mode.is_decode_or_idle():
|
if forward_mode.is_decode_or_idle():
|
||||||
|
|||||||
@@ -1,6 +1,21 @@
|
|||||||
import torch
|
from __future__ import annotations
|
||||||
|
|
||||||
from sglang.srt.distributed import GroupCoordinator, get_tp_group
|
import functools
|
||||||
|
from typing import TYPE_CHECKING, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import triton
|
||||||
|
import triton.language as tl
|
||||||
|
|
||||||
|
from sglang.srt.distributed import (
|
||||||
|
GroupCoordinator,
|
||||||
|
get_tensor_model_parallel_world_size,
|
||||||
|
get_tp_group,
|
||||||
|
tensor_model_parallel_all_reduce,
|
||||||
|
)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
|
|
||||||
_ATTN_TP_GROUP = None
|
_ATTN_TP_GROUP = None
|
||||||
_ATTN_TP_RANK = None
|
_ATTN_TP_RANK = None
|
||||||
@@ -69,3 +84,129 @@ def get_attention_dp_rank():
|
|||||||
def get_attention_dp_size():
|
def get_attention_dp_size():
|
||||||
assert _DP_SIZE is not None, "dp attention not initialized!"
|
assert _DP_SIZE is not None, "dp attention not initialized!"
|
||||||
return _DP_SIZE
|
return _DP_SIZE
|
||||||
|
|
||||||
|
|
||||||
|
def get_dp_local_info(forward_batch: ForwardBatch):
|
||||||
|
dp_rank = get_attention_dp_rank()
|
||||||
|
|
||||||
|
if forward_batch.dp_local_start_pos is None:
|
||||||
|
cumtokens = torch.cumsum(forward_batch.global_num_tokens_gpu, dim=0)
|
||||||
|
if dp_rank == 0:
|
||||||
|
local_start_pos = torch.zeros_like(cumtokens[0])
|
||||||
|
else:
|
||||||
|
local_start_pos = cumtokens[dp_rank - 1]
|
||||||
|
local_num_tokens = forward_batch.global_num_tokens_gpu[dp_rank]
|
||||||
|
|
||||||
|
forward_batch.dp_local_start_pos = local_start_pos
|
||||||
|
forward_batch.dp_local_num_tokens = local_num_tokens
|
||||||
|
|
||||||
|
return forward_batch.dp_local_start_pos, forward_batch.dp_local_num_tokens
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def memcpy_triton_kernel(
|
||||||
|
dst_ptr,
|
||||||
|
src_ptr,
|
||||||
|
offset_ptr,
|
||||||
|
sz_ptr,
|
||||||
|
offset_src,
|
||||||
|
chunk_size, # multiplied for offset and sz
|
||||||
|
BLOCK_SIZE: tl.constexpr,
|
||||||
|
):
|
||||||
|
pid = tl.program_id(axis=0).to(tl.int64)
|
||||||
|
offset = tl.load(offset_ptr).to(tl.int64) * chunk_size
|
||||||
|
sz = tl.load(sz_ptr).to(tl.int64) * chunk_size
|
||||||
|
|
||||||
|
start_index = pid * BLOCK_SIZE
|
||||||
|
offs = tl.arange(0, BLOCK_SIZE)
|
||||||
|
mask = start_index + offs < sz
|
||||||
|
|
||||||
|
if offset_src:
|
||||||
|
data = tl.load(src_ptr + offset + start_index + offs, mask=mask)
|
||||||
|
tl.store(dst_ptr + start_index + offs, data, mask=mask)
|
||||||
|
else:
|
||||||
|
data = tl.load(src_ptr + start_index + offs, mask=mask)
|
||||||
|
tl.store(dst_ptr + offset + start_index + offs, data, mask=mask)
|
||||||
|
|
||||||
|
|
||||||
|
def prod(x):
|
||||||
|
return functools.reduce(lambda a, b: a * b, x, 1)
|
||||||
|
|
||||||
|
|
||||||
|
def memcpy_triton(dst, src, dim, offset, sz, offset_src):
|
||||||
|
max_size = min(src.numel(), dst.numel())
|
||||||
|
assert dim == 0, "dim != 0 unsupported"
|
||||||
|
assert src.shape[1:] == dst.shape[1:], "src and dst must have same shape"
|
||||||
|
chunk_size = prod(src.shape[1:])
|
||||||
|
BLOCK_SIZE = 8192
|
||||||
|
grid = (triton.cdiv(max_size, BLOCK_SIZE),)
|
||||||
|
|
||||||
|
memcpy_triton_kernel[grid](dst, src, offset, sz, offset_src, chunk_size, BLOCK_SIZE)
|
||||||
|
|
||||||
|
|
||||||
|
def dp_gather(
|
||||||
|
global_tokens: torch.Tensor,
|
||||||
|
local_tokens: torch.Tensor,
|
||||||
|
forward_batch: ForwardBatch,
|
||||||
|
layer_id: Union[str, int],
|
||||||
|
):
|
||||||
|
local_start_pos, local_num_tokens = get_dp_local_info(forward_batch)
|
||||||
|
|
||||||
|
global_tokens.fill_(0)
|
||||||
|
assert local_tokens.is_contiguous()
|
||||||
|
assert global_tokens.is_contiguous()
|
||||||
|
if local_tokens.shape[0] > 0 and (
|
||||||
|
layer_id != "embedding" or get_attention_tp_rank() == 0
|
||||||
|
):
|
||||||
|
assert (
|
||||||
|
global_tokens.storage().data_ptr() != local_tokens.storage().data_ptr()
|
||||||
|
), "aliasing between global_tokens and local_tokens not allowed"
|
||||||
|
memcpy_triton(
|
||||||
|
global_tokens, local_tokens, 0, local_start_pos, local_num_tokens, False
|
||||||
|
)
|
||||||
|
|
||||||
|
# Input IDs are in int 32. We should use inplace_all_reduce for local case becaues of custom all reduce.
|
||||||
|
NUM_GPUS_PER_NODE = 8
|
||||||
|
if (
|
||||||
|
not local_tokens.dtype.is_floating_point
|
||||||
|
and get_tensor_model_parallel_world_size() <= NUM_GPUS_PER_NODE
|
||||||
|
):
|
||||||
|
torch.ops.sglang.inplace_all_reduce(
|
||||||
|
global_tokens, group_name=get_tp_group().unique_name
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
global_tokens = tensor_model_parallel_all_reduce(global_tokens)
|
||||||
|
|
||||||
|
|
||||||
|
def dp_scatter(
|
||||||
|
local_tokens: torch.Tensor, # output
|
||||||
|
global_tokens: torch.Tensor, # input
|
||||||
|
forward_batch: ForwardBatch,
|
||||||
|
):
|
||||||
|
# local_num_tokens is not necessarily the same as local_tokens.shape[0],
|
||||||
|
# since local_tokens may be padded for cuda graph
|
||||||
|
local_start_pos, local_num_tokens = get_dp_local_info(forward_batch)
|
||||||
|
local_tokens.fill_(0)
|
||||||
|
assert local_tokens.is_contiguous()
|
||||||
|
assert global_tokens.is_contiguous()
|
||||||
|
if local_tokens.shape[0] > 0:
|
||||||
|
assert (
|
||||||
|
local_tokens.untyped_storage().data_ptr()
|
||||||
|
!= global_tokens.untyped_storage().data_ptr()
|
||||||
|
), "aliasing between local_tokens and global_tokens not allowed"
|
||||||
|
memcpy_triton(
|
||||||
|
local_tokens, global_tokens, 0, local_start_pos, local_num_tokens, True
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_do_logits_dp_scatter(forward_batch: ForwardBatch):
|
||||||
|
def do_logits_dp_scatter(logits: torch.Tensor):
|
||||||
|
local_logits = torch.empty(
|
||||||
|
(forward_batch.input_ids.shape[0], *logits.shape[1:]),
|
||||||
|
dtype=logits.dtype,
|
||||||
|
device=logits.device,
|
||||||
|
)
|
||||||
|
dp_scatter(local_logits, logits, forward_batch)
|
||||||
|
return local_logits
|
||||||
|
|
||||||
|
return do_logits_dp_scatter
|
||||||
|
|||||||
@@ -69,7 +69,7 @@ class RMSNorm(CustomOp):
|
|||||||
|
|
||||||
variance = x.pow(2).mean(dim=-1, keepdim=True)
|
variance = x.pow(2).mean(dim=-1, keepdim=True)
|
||||||
x = x * torch.rsqrt(variance + self.variance_epsilon)
|
x = x * torch.rsqrt(variance + self.variance_epsilon)
|
||||||
x = x.to(orig_dtype) * self.weight
|
x = (x * self.weight).to(orig_dtype)
|
||||||
if residual is None:
|
if residual is None:
|
||||||
return x
|
return x
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -426,13 +426,14 @@ class ColumnParallelLinear(LinearBase):
|
|||||||
from sglang.srt.layers.parameter import _ColumnvLLMParameter
|
from sglang.srt.layers.parameter import _ColumnvLLMParameter
|
||||||
|
|
||||||
if isinstance(param, _ColumnvLLMParameter):
|
if isinstance(param, _ColumnvLLMParameter):
|
||||||
# FIXME: why would we need this special case?
|
|
||||||
param.load_column_parallel_weight(
|
param.load_column_parallel_weight(
|
||||||
loaded_weight,
|
loaded_weight,
|
||||||
tp_rank=self.tp_rank,
|
tp_rank=self.tp_rank,
|
||||||
use_presharded_weights=self.use_presharded_weights,
|
use_presharded_weights=self.use_presharded_weights,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
# FIXME: This branch is needed to load deepseek v3 awq.
|
||||||
|
# However, we should fix this and avoid the branching here.
|
||||||
param.load_column_parallel_weight(loaded_weight)
|
param.load_column_parallel_weight(loaded_weight)
|
||||||
|
|
||||||
def forward(self, input_):
|
def forward(self, input_):
|
||||||
|
|||||||
@@ -26,12 +26,19 @@ from sglang.srt.distributed import (
|
|||||||
get_tensor_model_parallel_world_size,
|
get_tensor_model_parallel_world_size,
|
||||||
tensor_model_parallel_all_gather,
|
tensor_model_parallel_all_gather,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.layers.dp_attention import (
|
||||||
|
dp_gather,
|
||||||
|
dp_scatter,
|
||||||
|
get_attention_dp_rank,
|
||||||
|
get_attention_dp_size,
|
||||||
|
)
|
||||||
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
||||||
from sglang.srt.model_executor.forward_batch_info import (
|
from sglang.srt.model_executor.forward_batch_info import (
|
||||||
CaptureHiddenMode,
|
CaptureHiddenMode,
|
||||||
ForwardBatch,
|
ForwardBatch,
|
||||||
ForwardMode,
|
ForwardMode,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.utils import dump_to_file
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -51,6 +58,9 @@ class LogitsProcessorOutput:
|
|||||||
# The logprobs and ids of the top-k tokens in output positions. shape: [#seq, k]
|
# The logprobs and ids of the top-k tokens in output positions. shape: [#seq, k]
|
||||||
next_token_top_logprobs_val: Optional[List] = None
|
next_token_top_logprobs_val: Optional[List] = None
|
||||||
next_token_top_logprobs_idx: Optional[List] = None
|
next_token_top_logprobs_idx: Optional[List] = None
|
||||||
|
# The logprobs and ids of the requested token ids in output positions. shape: [#seq, n] (n is the number of requested token ids)
|
||||||
|
next_token_token_ids_logprobs_val: Optional[List] = None
|
||||||
|
next_token_token_ids_logprobs_idx: Optional[List] = None
|
||||||
|
|
||||||
## Part 3: Prefill-only. This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor
|
## Part 3: Prefill-only. This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor
|
||||||
# The logprobs of input tokens. shape: [#token]
|
# The logprobs of input tokens. shape: [#token]
|
||||||
@@ -58,6 +68,9 @@ class LogitsProcessorOutput:
|
|||||||
# The logprobs and ids of the top-k tokens in input positions. shape: [#seq, #token, k]
|
# The logprobs and ids of the top-k tokens in input positions. shape: [#seq, #token, k]
|
||||||
input_top_logprobs_val: List = None
|
input_top_logprobs_val: List = None
|
||||||
input_top_logprobs_idx: List = None
|
input_top_logprobs_idx: List = None
|
||||||
|
# The logprobs and ids of the requested token ids in input positions. shape: [#seq, n] (n is the number of requested token ids)
|
||||||
|
input_token_ids_logprobs_val: Optional[List] = None
|
||||||
|
input_token_ids_logprobs_idx: Optional[List] = None
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
@@ -67,43 +80,107 @@ class LogitsMetadata:
|
|||||||
|
|
||||||
extend_return_logprob: bool = False
|
extend_return_logprob: bool = False
|
||||||
extend_return_top_logprob: bool = False
|
extend_return_top_logprob: bool = False
|
||||||
|
extend_token_ids_logprob: bool = False
|
||||||
extend_seq_lens: Optional[torch.Tensor] = None
|
extend_seq_lens: Optional[torch.Tensor] = None
|
||||||
extend_seq_lens_cpu: Optional[List[int]] = None
|
extend_seq_lens_cpu: Optional[List[int]] = None
|
||||||
extend_logprob_start_lens_cpu: Optional[List[int]] = None
|
extend_logprob_start_lens_cpu: Optional[List[int]] = None
|
||||||
extend_logprob_pruned_lens_cpu: Optional[List[int]] = None
|
extend_logprob_pruned_lens_cpu: Optional[List[int]] = None
|
||||||
top_logprobs_nums: Optional[List[int]] = None
|
top_logprobs_nums: Optional[List[int]] = None
|
||||||
|
extend_input_logprob_token_ids_gpu: Optional[torch.Tensor] = None
|
||||||
|
token_ids_logprobs: Optional[List[List[int]]] = None
|
||||||
|
|
||||||
|
# logits and logprobs post processing
|
||||||
|
temp_scaled_logprobs: bool = False
|
||||||
|
temperature: torch.Tensor = None
|
||||||
|
top_p_normalized_logprobs: bool = False
|
||||||
|
top_p: torch.Tensor = None
|
||||||
|
|
||||||
|
# DP attention metadata. Not needed when DP attention is not used.
|
||||||
|
# Number of tokens in the request.
|
||||||
|
global_num_tokens_gpu: Optional[torch.Tensor] = None
|
||||||
|
# The start position of local hidden states.
|
||||||
|
dp_local_start_pos: Optional[torch.Tensor] = None
|
||||||
|
dp_local_num_tokens: Optional[torch.Tensor] = None
|
||||||
|
gathered_buffer: Optional[torch.Tensor] = None
|
||||||
|
# Buffer to gather logits from all ranks.
|
||||||
|
forward_batch_gathered_buffer: Optional[torch.Tensor] = None
|
||||||
|
# Number of tokens to sample per DP rank
|
||||||
|
global_num_tokens_for_logprob_cpu: Optional[torch.Tensor] = None
|
||||||
|
global_num_tokens_for_logprob_gpu: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
|
# for padding
|
||||||
|
padded_static_len: int = -1
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_forward_batch(cls, forward_batch: ForwardBatch):
|
def from_forward_batch(cls, forward_batch: ForwardBatch):
|
||||||
if forward_batch.forward_mode.is_extend() and forward_batch.return_logprob:
|
if (
|
||||||
extend_return_logprob = True
|
forward_batch.forward_mode.is_extend()
|
||||||
|
and forward_batch.return_logprob
|
||||||
|
and not forward_batch.forward_mode.is_target_verify()
|
||||||
|
):
|
||||||
extend_return_top_logprob = any(
|
extend_return_top_logprob = any(
|
||||||
x > 0 for x in forward_batch.top_logprobs_nums
|
x > 0 for x in forward_batch.top_logprobs_nums
|
||||||
)
|
)
|
||||||
extend_logprob_pruned_lens_cpu = [
|
extend_token_ids_logprob = any(
|
||||||
extend_len - start_len
|
x is not None for x in forward_batch.token_ids_logprobs
|
||||||
|
)
|
||||||
|
extend_return_logprob = False
|
||||||
|
extend_logprob_pruned_lens_cpu = []
|
||||||
for extend_len, start_len in zip(
|
for extend_len, start_len in zip(
|
||||||
forward_batch.extend_seq_lens_cpu,
|
forward_batch.extend_seq_lens_cpu,
|
||||||
forward_batch.extend_logprob_start_lens_cpu,
|
forward_batch.extend_logprob_start_lens_cpu,
|
||||||
)
|
):
|
||||||
]
|
if extend_len - start_len > 0:
|
||||||
|
extend_return_logprob = True
|
||||||
|
extend_logprob_pruned_lens_cpu.append(extend_len - start_len)
|
||||||
else:
|
else:
|
||||||
extend_return_logprob = extend_return_top_logprob = (
|
extend_return_logprob = extend_return_top_logprob = (
|
||||||
extend_logprob_pruned_lens_cpu
|
extend_token_ids_logprob
|
||||||
) = False
|
) = extend_logprob_pruned_lens_cpu = False
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
forward_mode=forward_batch.forward_mode,
|
forward_mode=forward_batch.forward_mode,
|
||||||
capture_hidden_mode=forward_batch.capture_hidden_mode,
|
capture_hidden_mode=forward_batch.capture_hidden_mode,
|
||||||
extend_return_logprob=extend_return_logprob,
|
extend_return_logprob=extend_return_logprob,
|
||||||
extend_return_top_logprob=extend_return_top_logprob,
|
extend_return_top_logprob=extend_return_top_logprob,
|
||||||
|
extend_token_ids_logprob=extend_token_ids_logprob,
|
||||||
extend_seq_lens=forward_batch.extend_seq_lens,
|
extend_seq_lens=forward_batch.extend_seq_lens,
|
||||||
extend_seq_lens_cpu=forward_batch.extend_seq_lens_cpu,
|
extend_seq_lens_cpu=forward_batch.extend_seq_lens_cpu,
|
||||||
extend_logprob_start_lens_cpu=forward_batch.extend_logprob_start_lens_cpu,
|
extend_logprob_start_lens_cpu=forward_batch.extend_logprob_start_lens_cpu,
|
||||||
extend_logprob_pruned_lens_cpu=extend_logprob_pruned_lens_cpu,
|
extend_logprob_pruned_lens_cpu=extend_logprob_pruned_lens_cpu,
|
||||||
top_logprobs_nums=forward_batch.top_logprobs_nums,
|
top_logprobs_nums=forward_batch.top_logprobs_nums,
|
||||||
|
token_ids_logprobs=forward_batch.token_ids_logprobs,
|
||||||
|
extend_input_logprob_token_ids_gpu=forward_batch.extend_input_logprob_token_ids_gpu,
|
||||||
|
padded_static_len=forward_batch.padded_static_len,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def compute_dp_attention_metadata(self, hidden_states: torch.Tensor):
|
||||||
|
if self.global_num_tokens_for_logprob_cpu is None:
|
||||||
|
# we are capturing cuda graph
|
||||||
|
return
|
||||||
|
|
||||||
|
cumtokens = torch.cumsum(self.global_num_tokens_for_logprob_gpu, dim=0)
|
||||||
|
dp_rank = get_attention_dp_rank()
|
||||||
|
if dp_rank == 0:
|
||||||
|
dp_local_start_pos = torch.zeros_like(
|
||||||
|
self.global_num_tokens_for_logprob_gpu[0]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
dp_local_start_pos = cumtokens[dp_rank - 1]
|
||||||
|
dp_local_num_tokens = self.global_num_tokens_for_logprob_gpu[dp_rank]
|
||||||
|
gathered_buffer = torch.zeros(
|
||||||
|
(
|
||||||
|
sum(self.global_num_tokens_for_logprob_cpu),
|
||||||
|
hidden_states.shape[1],
|
||||||
|
),
|
||||||
|
dtype=hidden_states.dtype,
|
||||||
|
device=hidden_states.device,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.dp_local_start_pos = dp_local_start_pos
|
||||||
|
self.dp_local_num_tokens = dp_local_num_tokens
|
||||||
|
self.gathered_buffer = gathered_buffer
|
||||||
|
|
||||||
|
|
||||||
class LogitsProcessor(nn.Module):
|
class LogitsProcessor(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -115,6 +192,9 @@ class LogitsProcessor(nn.Module):
|
|||||||
self.do_tensor_parallel_all_gather = (
|
self.do_tensor_parallel_all_gather = (
|
||||||
not skip_all_gather and get_tensor_model_parallel_world_size() > 1
|
not skip_all_gather and get_tensor_model_parallel_world_size() > 1
|
||||||
)
|
)
|
||||||
|
self.do_tensor_parallel_all_gather_dp_attn = (
|
||||||
|
self.do_tensor_parallel_all_gather and get_attention_dp_size() != 1
|
||||||
|
)
|
||||||
self.final_logit_softcapping = getattr(
|
self.final_logit_softcapping = getattr(
|
||||||
self.config, "final_logit_softcapping", None
|
self.config, "final_logit_softcapping", None
|
||||||
)
|
)
|
||||||
@@ -124,6 +204,12 @@ class LogitsProcessor(nn.Module):
|
|||||||
):
|
):
|
||||||
self.final_logit_softcapping = None
|
self.final_logit_softcapping = None
|
||||||
|
|
||||||
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||||
|
|
||||||
|
self.debug_tensor_dump_output_folder = global_server_args_dict[
|
||||||
|
"debug_tensor_dump_output_folder"
|
||||||
|
]
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids,
|
input_ids,
|
||||||
@@ -141,30 +227,74 @@ class LogitsProcessor(nn.Module):
|
|||||||
):
|
):
|
||||||
pruned_states = hidden_states
|
pruned_states = hidden_states
|
||||||
sample_indices = None
|
sample_indices = None
|
||||||
|
input_logprob_indices = None
|
||||||
elif (
|
elif (
|
||||||
logits_metadata.forward_mode.is_extend()
|
logits_metadata.forward_mode.is_extend()
|
||||||
and not logits_metadata.extend_return_logprob
|
and not logits_metadata.extend_return_logprob
|
||||||
):
|
):
|
||||||
# Prefill without input logprobs.
|
# Prefill without input logprobs.
|
||||||
|
if logits_metadata.padded_static_len < 0:
|
||||||
last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1
|
last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1
|
||||||
|
else:
|
||||||
|
# If padding_static length is 5 and extended_seq_lens is [2, 3],
|
||||||
|
# then our batch looks like [t00, t01, p, p, p, t10, t11, t12, p, p]
|
||||||
|
# and this retrieves t01 and t12, which are the valid last tokens
|
||||||
|
idx = torch.arange(
|
||||||
|
len(logits_metadata.extend_seq_lens),
|
||||||
|
device=logits_metadata.extend_seq_lens.device,
|
||||||
|
)
|
||||||
|
last_index = (
|
||||||
|
idx * logits_metadata.padded_static_len
|
||||||
|
+ logits_metadata.extend_seq_lens
|
||||||
|
- 1
|
||||||
|
)
|
||||||
pruned_states = hidden_states[last_index]
|
pruned_states = hidden_states[last_index]
|
||||||
sample_indices = None
|
sample_indices = None
|
||||||
|
input_logprob_indices = None
|
||||||
else:
|
else:
|
||||||
# Slice the requested tokens to compute logprob
|
# Input logprobs are required.
|
||||||
|
# Find 3 different indices.
|
||||||
|
# 1. pruned_states: hidden states that we want logprobs from.
|
||||||
|
# 2. sample_indices: Indices that have sampled tokens.
|
||||||
|
# 3. input_logprob_indices: Indices that have input logprob tokens.
|
||||||
sample_index_pt = -1
|
sample_index_pt = -1
|
||||||
sample_indices = []
|
sample_indices = []
|
||||||
pt, pruned_states, pruned_input_ids = 0, [], []
|
input_logprob_indices_pt = 0
|
||||||
for start_len, extend_len in zip(
|
input_logprob_indices = []
|
||||||
|
pt, pruned_states = 0, []
|
||||||
|
for extend_logprob_start_len, extend_len in zip(
|
||||||
logits_metadata.extend_logprob_start_lens_cpu,
|
logits_metadata.extend_logprob_start_lens_cpu,
|
||||||
logits_metadata.extend_seq_lens_cpu,
|
logits_metadata.extend_seq_lens_cpu,
|
||||||
):
|
):
|
||||||
|
# It can happen in chunked prefill. We still need to sample 1 token,
|
||||||
|
# But we don't want to include it in input logprob.
|
||||||
|
if extend_len == extend_logprob_start_len:
|
||||||
|
start_len = extend_logprob_start_len - 1
|
||||||
|
else:
|
||||||
|
start_len = extend_logprob_start_len
|
||||||
|
|
||||||
|
# We always need at least 1 token to sample because that's required
|
||||||
|
# by a caller.
|
||||||
|
assert extend_len > start_len
|
||||||
pruned_states.append(hidden_states[pt + start_len : pt + extend_len])
|
pruned_states.append(hidden_states[pt + start_len : pt + extend_len])
|
||||||
|
pt += extend_len
|
||||||
sample_index_pt += extend_len - start_len
|
sample_index_pt += extend_len - start_len
|
||||||
sample_indices.append(sample_index_pt)
|
sample_indices.append(sample_index_pt)
|
||||||
pruned_input_ids.append(input_ids[pt + start_len : pt + extend_len])
|
input_logprob_indices.extend(
|
||||||
pt += extend_len
|
[
|
||||||
|
input_logprob_indices_pt + i
|
||||||
|
for i in range(extend_len - extend_logprob_start_len)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
input_logprob_indices_pt += extend_len - start_len
|
||||||
|
|
||||||
pruned_states = torch.cat(pruned_states)
|
pruned_states = torch.cat(pruned_states)
|
||||||
|
sample_indices = torch.tensor(
|
||||||
|
sample_indices, device=pruned_states.device, dtype=torch.int64
|
||||||
|
)
|
||||||
|
input_logprob_indices = torch.tensor(
|
||||||
|
input_logprob_indices, device=pruned_states.device, dtype=torch.int64
|
||||||
|
)
|
||||||
|
|
||||||
# Compute logits for both input and sampled tokens.
|
# Compute logits for both input and sampled tokens.
|
||||||
logits = self._get_logits(pruned_states, lm_head, logits_metadata)
|
logits = self._get_logits(pruned_states, lm_head, logits_metadata)
|
||||||
@@ -172,28 +302,51 @@ class LogitsProcessor(nn.Module):
|
|||||||
logits[sample_indices] if sample_indices is not None else logits
|
logits[sample_indices] if sample_indices is not None else logits
|
||||||
)
|
)
|
||||||
|
|
||||||
if (
|
if self.debug_tensor_dump_output_folder:
|
||||||
not logits_metadata.extend_return_logprob
|
assert (
|
||||||
or logits_metadata.capture_hidden_mode.need_capture()
|
not self.do_tensor_parallel_all_gather or get_attention_dp_size() == 1
|
||||||
):
|
), "dp attention + sharded lm_head doesn't support full logits"
|
||||||
|
full_logits = self._get_logits(hidden_states, lm_head, logits_metadata)
|
||||||
|
dump_to_file(self.debug_tensor_dump_output_folder, "logits", full_logits)
|
||||||
|
|
||||||
|
hidden_states_to_store: Optional[torch.Tensor] = None
|
||||||
|
if logits_metadata.capture_hidden_mode.need_capture():
|
||||||
|
if logits_metadata.capture_hidden_mode.is_full():
|
||||||
|
hidden_states_to_store = hidden_states
|
||||||
|
elif logits_metadata.capture_hidden_mode.is_last():
|
||||||
|
# Get the last token hidden states. If sample_indices is None,
|
||||||
|
# pruned states only contain the last tokens already.
|
||||||
|
hidden_states_to_store = (
|
||||||
|
pruned_states[sample_indices] if sample_indices else pruned_states
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
assert False, "Should never reach"
|
||||||
|
|
||||||
|
if not logits_metadata.extend_return_logprob:
|
||||||
# Decode mode or extend mode without return_logprob.
|
# Decode mode or extend mode without return_logprob.
|
||||||
return LogitsProcessorOutput(
|
return LogitsProcessorOutput(
|
||||||
next_token_logits=sampled_logits,
|
next_token_logits=sampled_logits,
|
||||||
hidden_states=(
|
hidden_states=hidden_states_to_store,
|
||||||
hidden_states
|
|
||||||
if logits_metadata.capture_hidden_mode.is_full()
|
|
||||||
else (
|
|
||||||
pruned_states
|
|
||||||
if logits_metadata.capture_hidden_mode.is_last()
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
input_logprobs = logits
|
input_logprobs = logits[input_logprob_indices]
|
||||||
del hidden_states, logits
|
del hidden_states, logits
|
||||||
|
|
||||||
# Normalize the logprob w/o temperature, top-p
|
# Normalize the logprob w/o temperature, top-p
|
||||||
|
pruned_lens = torch.tensor(
|
||||||
|
logits_metadata.extend_logprob_pruned_lens_cpu,
|
||||||
|
device=input_logprobs.device,
|
||||||
|
)
|
||||||
|
if logits_metadata.temp_scaled_logprobs:
|
||||||
|
logits_metadata.temperature = torch.repeat_interleave(
|
||||||
|
logits_metadata.temperature.view(-1),
|
||||||
|
pruned_lens,
|
||||||
|
).view(-1, 1)
|
||||||
|
if logits_metadata.top_p_normalized_logprobs:
|
||||||
|
logits_metadata.top_p = torch.repeat_interleave(
|
||||||
|
logits_metadata.top_p,
|
||||||
|
pruned_lens,
|
||||||
|
)
|
||||||
input_logprobs = self.compute_temp_top_p_normalized_logprobs(
|
input_logprobs = self.compute_temp_top_p_normalized_logprobs(
|
||||||
input_logprobs, logits_metadata
|
input_logprobs, logits_metadata
|
||||||
)
|
)
|
||||||
@@ -207,14 +360,18 @@ class LogitsProcessor(nn.Module):
|
|||||||
else:
|
else:
|
||||||
input_top_logprobs_val = input_top_logprobs_idx = None
|
input_top_logprobs_val = input_top_logprobs_idx = None
|
||||||
|
|
||||||
|
# Get the logprob of given token id
|
||||||
|
if logits_metadata.extend_token_ids_logprob:
|
||||||
|
(
|
||||||
|
input_token_ids_logprobs_val,
|
||||||
|
input_token_ids_logprobs_idx,
|
||||||
|
) = self.get_token_ids_logprobs(input_logprobs, logits_metadata)
|
||||||
|
else:
|
||||||
|
input_token_ids_logprobs_val = input_token_ids_logprobs_idx = None
|
||||||
|
|
||||||
input_token_logprobs = input_logprobs[
|
input_token_logprobs = input_logprobs[
|
||||||
torch.arange(input_logprobs.shape[0], device=input_logprobs.device),
|
torch.arange(input_logprobs.shape[0], device=input_logprobs.device),
|
||||||
torch.cat(
|
logits_metadata.extend_input_logprob_token_ids_gpu,
|
||||||
[
|
|
||||||
torch.cat(pruned_input_ids)[1:],
|
|
||||||
torch.tensor([0], device=input_logprobs.device),
|
|
||||||
]
|
|
||||||
),
|
|
||||||
]
|
]
|
||||||
|
|
||||||
return LogitsProcessorOutput(
|
return LogitsProcessorOutput(
|
||||||
@@ -222,6 +379,9 @@ class LogitsProcessor(nn.Module):
|
|||||||
input_token_logprobs=input_token_logprobs,
|
input_token_logprobs=input_token_logprobs,
|
||||||
input_top_logprobs_val=input_top_logprobs_val,
|
input_top_logprobs_val=input_top_logprobs_val,
|
||||||
input_top_logprobs_idx=input_top_logprobs_idx,
|
input_top_logprobs_idx=input_top_logprobs_idx,
|
||||||
|
hidden_states=hidden_states_to_store,
|
||||||
|
input_token_ids_logprobs_val=input_token_ids_logprobs_val,
|
||||||
|
input_token_ids_logprobs_idx=input_token_ids_logprobs_idx,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_logits(
|
def _get_logits(
|
||||||
@@ -231,10 +391,24 @@ class LogitsProcessor(nn.Module):
|
|||||||
logits_metadata: LogitsMetadata,
|
logits_metadata: LogitsMetadata,
|
||||||
embedding_bias: Optional[torch.Tensor] = None,
|
embedding_bias: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Get logits from hidden_states."""
|
"""Get logits from hidden_states.
|
||||||
|
|
||||||
|
If sampled_logits_only is True, it means hidden_states only contain the
|
||||||
|
last position (e.g., extend without input logprobs). The caller should
|
||||||
|
guarantee the given hidden_states follow this constraint.
|
||||||
|
"""
|
||||||
|
if self.do_tensor_parallel_all_gather_dp_attn:
|
||||||
|
logits_metadata.compute_dp_attention_metadata(hidden_states)
|
||||||
|
hidden_states, local_hidden_states = (
|
||||||
|
logits_metadata.gathered_buffer,
|
||||||
|
hidden_states.clone(),
|
||||||
|
)
|
||||||
|
dp_gather(hidden_states, local_hidden_states, logits_metadata, "embedding")
|
||||||
|
|
||||||
if hasattr(lm_head, "weight"):
|
if hasattr(lm_head, "weight"):
|
||||||
logits = torch.matmul(hidden_states, lm_head.weight.T)
|
logits = torch.matmul(
|
||||||
|
hidden_states.to(lm_head.weight.dtype), lm_head.weight.T
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# GGUF models
|
# GGUF models
|
||||||
logits = lm_head.linear_method.apply(lm_head, hidden_states, embedding_bias)
|
logits = lm_head.linear_method.apply(lm_head, hidden_states, embedding_bias)
|
||||||
@@ -245,6 +419,17 @@ class LogitsProcessor(nn.Module):
|
|||||||
if self.do_tensor_parallel_all_gather:
|
if self.do_tensor_parallel_all_gather:
|
||||||
logits = tensor_model_parallel_all_gather(logits)
|
logits = tensor_model_parallel_all_gather(logits)
|
||||||
|
|
||||||
|
if self.do_tensor_parallel_all_gather_dp_attn:
|
||||||
|
logits, global_logits = (
|
||||||
|
torch.empty(
|
||||||
|
(local_hidden_states.shape[0], logits.shape[1]),
|
||||||
|
device=logits.device,
|
||||||
|
dtype=logits.dtype,
|
||||||
|
),
|
||||||
|
logits,
|
||||||
|
)
|
||||||
|
dp_scatter(logits, global_logits, logits_metadata)
|
||||||
|
|
||||||
logits = logits[:, : self.config.vocab_size].float()
|
logits = logits[:, : self.config.vocab_size].float()
|
||||||
|
|
||||||
if self.final_logit_softcapping:
|
if self.final_logit_softcapping:
|
||||||
@@ -272,20 +457,65 @@ class LogitsProcessor(nn.Module):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
input_top_logprobs_val.append(
|
input_top_logprobs_val.append(
|
||||||
[values[pt + j][:k] for j in range(pruned_len - 1)]
|
[values[pt + j][:k] for j in range(pruned_len)]
|
||||||
)
|
)
|
||||||
input_top_logprobs_idx.append(
|
input_top_logprobs_idx.append(
|
||||||
[indices[pt + j][:k] for j in range(pruned_len - 1)]
|
[indices[pt + j][:k] for j in range(pruned_len)]
|
||||||
)
|
)
|
||||||
pt += pruned_len
|
pt += pruned_len
|
||||||
|
|
||||||
return input_top_logprobs_val, input_top_logprobs_idx
|
return input_top_logprobs_val, input_top_logprobs_idx
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_token_ids_logprobs(
|
||||||
|
all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata
|
||||||
|
):
|
||||||
|
input_token_ids_logprobs_val, input_token_ids_logprobs_idx = [], []
|
||||||
|
pt = 0
|
||||||
|
for token_ids, pruned_len in zip(
|
||||||
|
logits_metadata.token_ids_logprobs,
|
||||||
|
logits_metadata.extend_logprob_pruned_lens_cpu,
|
||||||
|
):
|
||||||
|
if pruned_len <= 0:
|
||||||
|
input_token_ids_logprobs_val.append([])
|
||||||
|
input_token_ids_logprobs_idx.append([])
|
||||||
|
continue
|
||||||
|
|
||||||
|
input_token_ids_logprobs_val.append(
|
||||||
|
[all_logprobs[pt + j, token_ids].tolist() for j in range(pruned_len)]
|
||||||
|
)
|
||||||
|
input_token_ids_logprobs_idx.append([token_ids for _ in range(pruned_len)])
|
||||||
|
pt += pruned_len
|
||||||
|
|
||||||
|
return input_token_ids_logprobs_val, input_token_ids_logprobs_idx
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def compute_temp_top_p_normalized_logprobs(
|
def compute_temp_top_p_normalized_logprobs(
|
||||||
last_logits: torch.Tensor, logits_metadata: LogitsMetadata
|
last_logits: torch.Tensor, logits_metadata: LogitsMetadata
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
# TODO: Implement the temp and top-p normalization
|
"""
|
||||||
|
compute logprobs for the output token from the given logits.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: logprobs from logits
|
||||||
|
"""
|
||||||
|
# Scale logits if temperature scaling is enabled
|
||||||
|
if logits_metadata.temp_scaled_logprobs:
|
||||||
|
last_logits = last_logits / logits_metadata.temperature
|
||||||
|
|
||||||
|
# Normalize logprobs if top_p normalization is enabled
|
||||||
|
# NOTE: only normalize logprobs when top_p is set and not equal to 1.0
|
||||||
|
if (
|
||||||
|
logits_metadata.top_p_normalized_logprobs
|
||||||
|
and (logits_metadata.top_p != 1.0).any()
|
||||||
|
):
|
||||||
|
from sglang.srt.layers.sampler import top_p_normalize_probs_torch
|
||||||
|
|
||||||
|
probs = torch.softmax(last_logits, dim=-1)
|
||||||
|
del last_logits
|
||||||
|
probs = top_p_normalize_probs_torch(probs, logits_metadata.top_p)
|
||||||
|
return torch.log(probs)
|
||||||
|
else:
|
||||||
return torch.nn.functional.log_softmax(last_logits, dim=-1)
|
return torch.nn.functional.log_softmax(last_logits, dim=-1)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -144,6 +144,73 @@ def silu_and_mul_triton_kernel(
|
|||||||
tl.store(down_input_ptr + offset, silu_mul_output, mask=mask)
|
tl.store(down_input_ptr + offset, silu_mul_output, mask=mask)
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def tanh(x):
|
||||||
|
return 2 * tl.sigmoid(2 * x) - 1
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def gelu_and_mul_triton_kernel(
|
||||||
|
gateup_output,
|
||||||
|
down_input,
|
||||||
|
hidden_size,
|
||||||
|
reorder_topk_ids,
|
||||||
|
scales,
|
||||||
|
start_expert_id,
|
||||||
|
end_expert_id,
|
||||||
|
BLOCK_SIZE: tl.constexpr,
|
||||||
|
):
|
||||||
|
InDtype = gateup_output.dtype.element_ty
|
||||||
|
OutDtype = down_input.dtype.element_ty
|
||||||
|
|
||||||
|
half_hidden_size = hidden_size // 2
|
||||||
|
|
||||||
|
pid = tl.program_id(0)
|
||||||
|
expert_id = tl.load(reorder_topk_ids + pid)
|
||||||
|
if expert_id >= start_expert_id and expert_id <= end_expert_id:
|
||||||
|
gateup_output_ptr = gateup_output + pid * hidden_size
|
||||||
|
gate_output_ptr = gateup_output_ptr
|
||||||
|
up_output_ptr = gateup_output_ptr + half_hidden_size
|
||||||
|
down_input_ptr = down_input + pid * half_hidden_size
|
||||||
|
|
||||||
|
if scales is not None:
|
||||||
|
scale = tl.load(scales + expert_id - start_expert_id)
|
||||||
|
scale = (1 / scale).to(InDtype)
|
||||||
|
else:
|
||||||
|
scale = 1
|
||||||
|
|
||||||
|
for start_offset in tl.range(0, half_hidden_size, BLOCK_SIZE):
|
||||||
|
offset = start_offset + tl.arange(0, BLOCK_SIZE)
|
||||||
|
mask = offset < half_hidden_size
|
||||||
|
|
||||||
|
gate_output = tl.load(gate_output_ptr + offset, mask=mask).to(tl.float32)
|
||||||
|
up_output = tl.load(up_output_ptr + offset, mask=mask)
|
||||||
|
|
||||||
|
# gelu & mul & quantize
|
||||||
|
# https://pytorch.org/docs/stable/generated/torch.nn.GELU.html
|
||||||
|
# sqrt(2/pi)
|
||||||
|
kAlpha = 0.7978845608028654
|
||||||
|
gate_output = (
|
||||||
|
0.5
|
||||||
|
* gate_output
|
||||||
|
* (
|
||||||
|
1
|
||||||
|
+ tanh(
|
||||||
|
kAlpha
|
||||||
|
* (
|
||||||
|
gate_output
|
||||||
|
+ 0.044715 * gate_output * gate_output * gate_output
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
gate_output = gate_output.to(InDtype)
|
||||||
|
|
||||||
|
gelu_mul_output = gate_output * up_output * scale
|
||||||
|
gelu_mul_output = gelu_mul_output.to(OutDtype)
|
||||||
|
tl.store(down_input_ptr + offset, gelu_mul_output, mask=mask)
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def post_reorder_triton_kernel(
|
def post_reorder_triton_kernel(
|
||||||
down_output_ptr,
|
down_output_ptr,
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ from sglang.srt.distributed import (
|
|||||||
get_tensor_model_parallel_world_size,
|
get_tensor_model_parallel_world_size,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.moe.ep_moe.kernels import (
|
from sglang.srt.layers.moe.ep_moe.kernels import (
|
||||||
|
gelu_and_mul_triton_kernel,
|
||||||
grouped_gemm_triton,
|
grouped_gemm_triton,
|
||||||
post_reorder_triton_kernel,
|
post_reorder_triton_kernel,
|
||||||
pre_reorder_triton_kernel,
|
pre_reorder_triton_kernel,
|
||||||
@@ -296,6 +297,17 @@ class EPMoE(torch.nn.Module):
|
|||||||
self.end_expert_id,
|
self.end_expert_id,
|
||||||
BLOCK_SIZE=512,
|
BLOCK_SIZE=512,
|
||||||
)
|
)
|
||||||
|
elif self.activation == "gelu":
|
||||||
|
gelu_and_mul_triton_kernel[(gateup_output.shape[0],)](
|
||||||
|
gateup_output,
|
||||||
|
down_input,
|
||||||
|
gateup_output.shape[1],
|
||||||
|
reorder_topk_ids,
|
||||||
|
self.w2_input_scale,
|
||||||
|
self.start_expert_id,
|
||||||
|
self.end_expert_id,
|
||||||
|
BLOCK_SIZE=512,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported activation: {self.activation=}")
|
raise ValueError(f"Unsupported activation: {self.activation=}")
|
||||||
|
|
||||||
|
|||||||
@@ -24,6 +24,8 @@ def fused_moe_forward_native(
|
|||||||
custom_routing_function: Optional[Callable] = None,
|
custom_routing_function: Optional[Callable] = None,
|
||||||
correction_bias: Optional[torch.Tensor] = None,
|
correction_bias: Optional[torch.Tensor] = None,
|
||||||
activation: str = "silu",
|
activation: str = "silu",
|
||||||
|
inplace: bool = True,
|
||||||
|
no_combine: bool = False,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
topk_weights, topk_ids = select_experts(
|
topk_weights, topk_ids = select_experts(
|
||||||
hidden_states=x,
|
hidden_states=x,
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ from sglang.srt.utils import (
|
|||||||
is_hip,
|
is_hip,
|
||||||
)
|
)
|
||||||
|
|
||||||
is_hip_flag = is_hip()
|
is_hip_ = is_hip()
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -487,6 +487,7 @@ def invoke_fused_moe_kernel(
|
|||||||
use_int8_w8a8: bool,
|
use_int8_w8a8: bool,
|
||||||
use_int8_w8a16: bool,
|
use_int8_w8a16: bool,
|
||||||
block_shape: Optional[List[int]] = None,
|
block_shape: Optional[List[int]] = None,
|
||||||
|
no_combine: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
assert topk_weights.stride(1) == 1
|
assert topk_weights.stride(1) == 1
|
||||||
assert sorted_token_ids.stride(0) == 1
|
assert sorted_token_ids.stride(0) == 1
|
||||||
@@ -646,7 +647,7 @@ def get_default_config(
|
|||||||
"BLOCK_SIZE_K": 128,
|
"BLOCK_SIZE_K": 128,
|
||||||
"GROUP_SIZE_M": 32,
|
"GROUP_SIZE_M": 32,
|
||||||
"num_warps": 8,
|
"num_warps": 8,
|
||||||
"num_stages": 2 if is_hip_flag else 4,
|
"num_stages": 2 if is_hip_ else 4,
|
||||||
}
|
}
|
||||||
if M <= E:
|
if M <= E:
|
||||||
config = {
|
config = {
|
||||||
@@ -655,7 +656,7 @@ def get_default_config(
|
|||||||
"BLOCK_SIZE_K": 128,
|
"BLOCK_SIZE_K": 128,
|
||||||
"GROUP_SIZE_M": 1,
|
"GROUP_SIZE_M": 1,
|
||||||
"num_warps": 4,
|
"num_warps": 4,
|
||||||
"num_stages": 2 if is_hip_flag else 4,
|
"num_stages": 2 if is_hip_ else 4,
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
# Block-wise quant: BLOCK_SIZE_K must be divisable by block_shape[1]
|
# Block-wise quant: BLOCK_SIZE_K must be divisable by block_shape[1]
|
||||||
@@ -665,7 +666,7 @@ def get_default_config(
|
|||||||
"BLOCK_SIZE_K": block_shape[1],
|
"BLOCK_SIZE_K": block_shape[1],
|
||||||
"GROUP_SIZE_M": 32,
|
"GROUP_SIZE_M": 32,
|
||||||
"num_warps": 4,
|
"num_warps": 4,
|
||||||
"num_stages": 2 if is_hip_flag else 3,
|
"num_stages": 2 if is_hip_ else 3,
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
config = {
|
config = {
|
||||||
@@ -814,6 +815,7 @@ def outplace_fused_experts(
|
|||||||
a1_scale: Optional[torch.Tensor] = None,
|
a1_scale: Optional[torch.Tensor] = None,
|
||||||
a2_scale: Optional[torch.Tensor] = None,
|
a2_scale: Optional[torch.Tensor] = None,
|
||||||
block_shape: Optional[List[int]] = None,
|
block_shape: Optional[List[int]] = None,
|
||||||
|
no_combine: bool = False,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
return fused_experts_impl(
|
return fused_experts_impl(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
@@ -831,6 +833,7 @@ def outplace_fused_experts(
|
|||||||
a1_scale,
|
a1_scale,
|
||||||
a2_scale,
|
a2_scale,
|
||||||
block_shape,
|
block_shape,
|
||||||
|
no_combine=no_combine,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -849,6 +852,7 @@ def outplace_fused_experts_fake(
|
|||||||
a1_scale: Optional[torch.Tensor] = None,
|
a1_scale: Optional[torch.Tensor] = None,
|
||||||
a2_scale: Optional[torch.Tensor] = None,
|
a2_scale: Optional[torch.Tensor] = None,
|
||||||
block_shape: Optional[List[int]] = None,
|
block_shape: Optional[List[int]] = None,
|
||||||
|
no_combine: bool = False,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
return torch.empty_like(hidden_states)
|
return torch.empty_like(hidden_states)
|
||||||
|
|
||||||
@@ -877,8 +881,10 @@ def fused_experts(
|
|||||||
a1_scale: Optional[torch.Tensor] = None,
|
a1_scale: Optional[torch.Tensor] = None,
|
||||||
a2_scale: Optional[torch.Tensor] = None,
|
a2_scale: Optional[torch.Tensor] = None,
|
||||||
block_shape: Optional[List[int]] = None,
|
block_shape: Optional[List[int]] = None,
|
||||||
|
no_combine: bool = False,
|
||||||
):
|
):
|
||||||
if inplace:
|
if inplace:
|
||||||
|
assert not no_combine, "no combine + inplace makes no sense"
|
||||||
torch.ops.sglang.inplace_fused_experts(
|
torch.ops.sglang.inplace_fused_experts(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
w1,
|
w1,
|
||||||
@@ -912,6 +918,7 @@ def fused_experts(
|
|||||||
a1_scale,
|
a1_scale,
|
||||||
a2_scale,
|
a2_scale,
|
||||||
block_shape,
|
block_shape,
|
||||||
|
no_combine=no_combine,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -931,6 +938,7 @@ def fused_experts_impl(
|
|||||||
a1_scale: Optional[torch.Tensor] = None,
|
a1_scale: Optional[torch.Tensor] = None,
|
||||||
a2_scale: Optional[torch.Tensor] = None,
|
a2_scale: Optional[torch.Tensor] = None,
|
||||||
block_shape: Optional[List[int]] = None,
|
block_shape: Optional[List[int]] = None,
|
||||||
|
no_combine: bool = False,
|
||||||
):
|
):
|
||||||
padded_size = padding_size
|
padded_size = padding_size
|
||||||
if not use_fp8_w8a8 or not use_int8_w8a8 or block_shape is not None:
|
if not use_fp8_w8a8 or not use_int8_w8a8 or block_shape is not None:
|
||||||
@@ -987,7 +995,14 @@ def fused_experts_impl(
|
|||||||
|
|
||||||
compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16
|
compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16
|
||||||
|
|
||||||
if inplace:
|
if no_combine:
|
||||||
|
assert not inplace
|
||||||
|
out_hidden_states = torch.empty(
|
||||||
|
(num_tokens, topk_ids.shape[1], w2.shape[1]),
|
||||||
|
device=hidden_states.device,
|
||||||
|
dtype=hidden_states.dtype,
|
||||||
|
)
|
||||||
|
elif inplace:
|
||||||
out_hidden_states = hidden_states
|
out_hidden_states = hidden_states
|
||||||
else:
|
else:
|
||||||
out_hidden_states = torch.empty_like(hidden_states)
|
out_hidden_states = torch.empty_like(hidden_states)
|
||||||
@@ -1057,7 +1072,11 @@ def fused_experts_impl(
|
|||||||
invoke_fused_moe_kernel(
|
invoke_fused_moe_kernel(
|
||||||
intermediate_cache2,
|
intermediate_cache2,
|
||||||
w2,
|
w2,
|
||||||
intermediate_cache3,
|
(
|
||||||
|
intermediate_cache3
|
||||||
|
if not no_combine and topk_ids.shape[1] != 1
|
||||||
|
else out_hidden_states[begin_chunk_idx:end_chunk_idx]
|
||||||
|
),
|
||||||
a2_scale,
|
a2_scale,
|
||||||
w2_scale,
|
w2_scale,
|
||||||
curr_topk_weights,
|
curr_topk_weights,
|
||||||
@@ -1075,16 +1094,16 @@ def fused_experts_impl(
|
|||||||
block_shape=block_shape,
|
block_shape=block_shape,
|
||||||
)
|
)
|
||||||
|
|
||||||
if is_hip_flag:
|
if no_combine:
|
||||||
|
pass
|
||||||
|
elif is_hip_:
|
||||||
ops.moe_sum(
|
ops.moe_sum(
|
||||||
intermediate_cache3.view(*intermediate_cache3.shape),
|
intermediate_cache3.view(*intermediate_cache3.shape),
|
||||||
out_hidden_states[begin_chunk_idx:end_chunk_idx],
|
out_hidden_states[begin_chunk_idx:end_chunk_idx],
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if topk_ids.shape[1] == 1:
|
if topk_ids.shape[1] == 1:
|
||||||
out_hidden_states[begin_chunk_idx:end_chunk_idx].copy_(
|
pass # we write directly into out_hidden_states
|
||||||
intermediate_cache3[:, 0]
|
|
||||||
)
|
|
||||||
elif topk_ids.shape[1] == 2:
|
elif topk_ids.shape[1] == 2:
|
||||||
torch.add(
|
torch.add(
|
||||||
intermediate_cache3[:, 0],
|
intermediate_cache3[:, 0],
|
||||||
@@ -1122,6 +1141,7 @@ def fused_moe(
|
|||||||
a1_scale: Optional[torch.Tensor] = None,
|
a1_scale: Optional[torch.Tensor] = None,
|
||||||
a2_scale: Optional[torch.Tensor] = None,
|
a2_scale: Optional[torch.Tensor] = None,
|
||||||
block_shape: Optional[List[int]] = None,
|
block_shape: Optional[List[int]] = None,
|
||||||
|
no_combine: bool = False,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
This function computes a Mixture of Experts (MoE) layer using two sets of
|
This function computes a Mixture of Experts (MoE) layer using two sets of
|
||||||
@@ -1191,4 +1211,5 @@ def fused_moe(
|
|||||||
a1_scale=a1_scale,
|
a1_scale=a1_scale,
|
||||||
a2_scale=a2_scale,
|
a2_scale=a2_scale,
|
||||||
block_shape=block_shape,
|
block_shape=block_shape,
|
||||||
|
no_combine=no_combine,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -125,6 +125,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
|||||||
custom_routing_function: Optional[Callable] = None,
|
custom_routing_function: Optional[Callable] = None,
|
||||||
correction_bias: Optional[torch.Tensor] = None,
|
correction_bias: Optional[torch.Tensor] = None,
|
||||||
activation: str = "silu",
|
activation: str = "silu",
|
||||||
|
inplace: bool = True,
|
||||||
|
no_combine: bool = False,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
return self.forward(
|
return self.forward(
|
||||||
x=x,
|
x=x,
|
||||||
@@ -138,6 +140,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
|||||||
custom_routing_function=custom_routing_function,
|
custom_routing_function=custom_routing_function,
|
||||||
correction_bias=correction_bias,
|
correction_bias=correction_bias,
|
||||||
activation=activation,
|
activation=activation,
|
||||||
|
inplace=inplace,
|
||||||
|
no_combine=no_combine,
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward_cuda(
|
def forward_cuda(
|
||||||
@@ -153,6 +157,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
|||||||
custom_routing_function: Optional[Callable] = None,
|
custom_routing_function: Optional[Callable] = None,
|
||||||
correction_bias: Optional[torch.Tensor] = None,
|
correction_bias: Optional[torch.Tensor] = None,
|
||||||
activation: str = "silu",
|
activation: str = "silu",
|
||||||
|
inplace: bool = True,
|
||||||
|
no_combine: bool = False,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
topk_weights, topk_ids = select_experts(
|
topk_weights, topk_ids = select_experts(
|
||||||
hidden_states=x,
|
hidden_states=x,
|
||||||
@@ -171,6 +177,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
|||||||
from aiter.fused_moe import fused_experts_ck
|
from aiter.fused_moe import fused_experts_ck
|
||||||
|
|
||||||
assert activation == "silu", f"{activation=} is not supported."
|
assert activation == "silu", f"{activation=} is not supported."
|
||||||
|
assert not no_combine, "unsupported"
|
||||||
|
|
||||||
return fused_experts_ck(
|
return fused_experts_ck(
|
||||||
hidden_states=x,
|
hidden_states=x,
|
||||||
@@ -186,8 +193,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
|||||||
w2=layer.w2_weight,
|
w2=layer.w2_weight,
|
||||||
topk_weights=topk_weights,
|
topk_weights=topk_weights,
|
||||||
topk_ids=topk_ids,
|
topk_ids=topk_ids,
|
||||||
inplace=True,
|
inplace=inplace and not no_combine,
|
||||||
activation=activation,
|
activation=activation,
|
||||||
|
no_combine=no_combine,
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward_cpu(
|
def forward_cpu(
|
||||||
@@ -202,6 +210,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
|||||||
num_expert_group: Optional[int] = None,
|
num_expert_group: Optional[int] = None,
|
||||||
custom_routing_function: Optional[Callable] = None,
|
custom_routing_function: Optional[Callable] = None,
|
||||||
correction_bias: Optional[torch.Tensor] = None,
|
correction_bias: Optional[torch.Tensor] = None,
|
||||||
|
inplace: bool = True,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
return moe_forward_native(
|
return moe_forward_native(
|
||||||
layer,
|
layer,
|
||||||
@@ -241,6 +250,7 @@ class FusedMoE(torch.nn.Module):
|
|||||||
reduce_results: Whether to all all_reduce on the output of the layer
|
reduce_results: Whether to all all_reduce on the output of the layer
|
||||||
renomalize: Whether to renormalize the logits in the fused_moe kernel
|
renomalize: Whether to renormalize the logits in the fused_moe kernel
|
||||||
quant_config: Quantization configure.
|
quant_config: Quantization configure.
|
||||||
|
inplace: suggestion to compute inplace (modify input activation).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -262,6 +272,8 @@ class FusedMoE(torch.nn.Module):
|
|||||||
correction_bias: Optional[torch.Tensor] = None,
|
correction_bias: Optional[torch.Tensor] = None,
|
||||||
activation: str = "silu",
|
activation: str = "silu",
|
||||||
use_presharded_weights: bool = False,
|
use_presharded_weights: bool = False,
|
||||||
|
inplace: bool = True,
|
||||||
|
no_combine: bool = False,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@@ -285,6 +297,9 @@ class FusedMoE(torch.nn.Module):
|
|||||||
self.custom_routing_function = custom_routing_function
|
self.custom_routing_function = custom_routing_function
|
||||||
self.correction_bias = correction_bias
|
self.correction_bias = correction_bias
|
||||||
self.activation = activation
|
self.activation = activation
|
||||||
|
self.use_presharded_weights = use_presharded_weights
|
||||||
|
self.inplace = inplace
|
||||||
|
self.no_combine = no_combine
|
||||||
|
|
||||||
if quant_config is None:
|
if quant_config is None:
|
||||||
self.quant_method: Optional[QuantizeMethodBase] = (
|
self.quant_method: Optional[QuantizeMethodBase] = (
|
||||||
@@ -304,7 +319,6 @@ class FusedMoE(torch.nn.Module):
|
|||||||
params_dtype=params_dtype,
|
params_dtype=params_dtype,
|
||||||
weight_loader=self.weight_loader,
|
weight_loader=self.weight_loader,
|
||||||
)
|
)
|
||||||
self.use_presharded_weights = use_presharded_weights
|
|
||||||
|
|
||||||
def _load_per_tensor_weight_scale(
|
def _load_per_tensor_weight_scale(
|
||||||
self,
|
self,
|
||||||
@@ -598,6 +612,8 @@ class FusedMoE(torch.nn.Module):
|
|||||||
custom_routing_function=self.custom_routing_function,
|
custom_routing_function=self.custom_routing_function,
|
||||||
correction_bias=self.correction_bias,
|
correction_bias=self.correction_bias,
|
||||||
activation=self.activation,
|
activation=self.activation,
|
||||||
|
inplace=self.inplace,
|
||||||
|
no_combine=self.no_combine,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.reduce_results and self.tp_size > 1:
|
if self.reduce_results and self.tp_size > 1:
|
||||||
|
|||||||
@@ -771,6 +771,8 @@ class Fp8MoEMethod:
|
|||||||
custom_routing_function: Optional[Callable] = None,
|
custom_routing_function: Optional[Callable] = None,
|
||||||
correction_bias: Optional[torch.Tensor] = None,
|
correction_bias: Optional[torch.Tensor] = None,
|
||||||
activation: str = "silu",
|
activation: str = "silu",
|
||||||
|
inplace: bool = True,
|
||||||
|
no_combine: bool = False,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
|
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
|
||||||
from sglang.srt.layers.moe.topk import select_experts
|
from sglang.srt.layers.moe.topk import select_experts
|
||||||
@@ -793,6 +795,7 @@ class Fp8MoEMethod:
|
|||||||
from aiter.fused_moe import fused_experts_ck
|
from aiter.fused_moe import fused_experts_ck
|
||||||
|
|
||||||
assert activation == "silu", f"{activation=} is not supported."
|
assert activation == "silu", f"{activation=} is not supported."
|
||||||
|
assert not no_combine, f"{no_combine=} is not supported."
|
||||||
|
|
||||||
return fused_experts_ck(
|
return fused_experts_ck(
|
||||||
x,
|
x,
|
||||||
@@ -823,7 +826,7 @@ class Fp8MoEMethod:
|
|||||||
layer.w2_weight,
|
layer.w2_weight,
|
||||||
topk_weights=topk_weights,
|
topk_weights=topk_weights,
|
||||||
topk_ids=topk_ids,
|
topk_ids=topk_ids,
|
||||||
inplace=True,
|
inplace=inplace and not no_combine,
|
||||||
activation=activation,
|
activation=activation,
|
||||||
use_fp8_w8a8=True,
|
use_fp8_w8a8=True,
|
||||||
w1_scale=(
|
w1_scale=(
|
||||||
@@ -839,6 +842,7 @@ class Fp8MoEMethod:
|
|||||||
a1_scale=layer.w13_input_scale,
|
a1_scale=layer.w13_input_scale,
|
||||||
a2_scale=layer.w2_input_scale,
|
a2_scale=layer.w2_input_scale,
|
||||||
block_shape=self.quant_config.weight_block_size,
|
block_shape=self.quant_config.weight_block_size,
|
||||||
|
no_combine=no_combine,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -707,7 +707,6 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
|
|||||||
cos = freqs.cos() * self.mscale
|
cos = freqs.cos() * self.mscale
|
||||||
sin = freqs.sin() * self.mscale
|
sin = freqs.sin() * self.mscale
|
||||||
cache = torch.cat((cos, sin), dim=-1)
|
cache = torch.cat((cos, sin), dim=-1)
|
||||||
print("Cache shape", cache.shape)
|
|
||||||
return cache
|
return cache
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import List
|
from typing import List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
@@ -41,7 +41,21 @@ class Sampler(nn.Module):
|
|||||||
sampling_info: SamplingBatchInfo,
|
sampling_info: SamplingBatchInfo,
|
||||||
return_logprob: bool,
|
return_logprob: bool,
|
||||||
top_logprobs_nums: List[int],
|
top_logprobs_nums: List[int],
|
||||||
|
token_ids_logprobs: List[List[int]],
|
||||||
|
batch_next_token_ids: Optional[torch.Tensor] = None,
|
||||||
):
|
):
|
||||||
|
"""Run a sampler & compute logprobs and update logits_output accordingly.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
logits_output: The logits from the model forward
|
||||||
|
sampling_info: Metadata for sampling
|
||||||
|
return_logprob: If set, store the output logprob information to
|
||||||
|
logits_output
|
||||||
|
top_logprobs_nums: Number of top lobprobs per sequence in a batch
|
||||||
|
batch_next_token_ids: next token IDs. If set, skip sampling and only
|
||||||
|
compute output logprobs It is used for speculative decoding which
|
||||||
|
performs sampling in draft workers.
|
||||||
|
"""
|
||||||
logits = logits_output.next_token_logits
|
logits = logits_output.next_token_logits
|
||||||
|
|
||||||
# Apply the custom logit processors if registered in the sampling info.
|
# Apply the custom logit processors if registered in the sampling info.
|
||||||
@@ -58,13 +72,15 @@ class Sampler(nn.Module):
|
|||||||
|
|
||||||
if sampling_info.is_all_greedy:
|
if sampling_info.is_all_greedy:
|
||||||
# Use torch.argmax if all requests use greedy sampling
|
# Use torch.argmax if all requests use greedy sampling
|
||||||
|
if batch_next_token_ids is None:
|
||||||
batch_next_token_ids = torch.argmax(logits, -1)
|
batch_next_token_ids = torch.argmax(logits, -1)
|
||||||
if return_logprob:
|
if return_logprob:
|
||||||
logprobs = torch.nn.functional.log_softmax(logits, dim=-1)
|
logprobs = torch.nn.functional.log_softmax(logits, dim=-1)
|
||||||
else:
|
else:
|
||||||
# Post process logits
|
# Post process logits
|
||||||
logits.div_(sampling_info.temperatures)
|
logits.div_(sampling_info.temperatures)
|
||||||
probs = torch.softmax(logits, dim=-1)
|
logits[:] = torch.softmax(logits, dim=-1)
|
||||||
|
probs = logits
|
||||||
del logits
|
del logits
|
||||||
|
|
||||||
if global_server_args_dict["sampling_backend"] == "flashinfer":
|
if global_server_args_dict["sampling_backend"] == "flashinfer":
|
||||||
@@ -78,6 +94,7 @@ class Sampler(nn.Module):
|
|||||||
top_p_normalize_probs_torch(probs, sampling_info.top_ps)
|
top_p_normalize_probs_torch(probs, sampling_info.top_ps)
|
||||||
).clamp(min=torch.finfo(probs.dtype).min)
|
).clamp(min=torch.finfo(probs.dtype).min)
|
||||||
|
|
||||||
|
if batch_next_token_ids is None:
|
||||||
max_top_k_round, batch_size = 32, probs.shape[0]
|
max_top_k_round, batch_size = 32, probs.shape[0]
|
||||||
uniform_samples = torch.rand(
|
uniform_samples = torch.rand(
|
||||||
(max_top_k_round, batch_size), device=probs.device
|
(max_top_k_round, batch_size), device=probs.device
|
||||||
@@ -99,9 +116,12 @@ class Sampler(nn.Module):
|
|||||||
|
|
||||||
if self.use_nan_detection and not torch.all(success):
|
if self.use_nan_detection and not torch.all(success):
|
||||||
logger.warning("Detected errors during sampling!")
|
logger.warning("Detected errors during sampling!")
|
||||||
batch_next_token_ids = torch.zeros_like(batch_next_token_ids)
|
batch_next_token_ids = torch.zeros_like(
|
||||||
|
batch_next_token_ids
|
||||||
|
)
|
||||||
|
|
||||||
elif global_server_args_dict["sampling_backend"] == "pytorch":
|
elif global_server_args_dict["sampling_backend"] == "pytorch":
|
||||||
|
if batch_next_token_ids is None:
|
||||||
# A slower fallback implementation with torch native operations.
|
# A slower fallback implementation with torch native operations.
|
||||||
batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
|
batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
|
||||||
probs,
|
probs,
|
||||||
@@ -110,6 +130,7 @@ class Sampler(nn.Module):
|
|||||||
sampling_info.min_ps,
|
sampling_info.min_ps,
|
||||||
sampling_info.need_min_p_sampling,
|
sampling_info.need_min_p_sampling,
|
||||||
)
|
)
|
||||||
|
|
||||||
if return_logprob:
|
if return_logprob:
|
||||||
# clamp to avoid -inf
|
# clamp to avoid -inf
|
||||||
logprobs = torch.log(
|
logprobs = torch.log(
|
||||||
@@ -128,6 +149,12 @@ class Sampler(nn.Module):
|
|||||||
logits_output.next_token_top_logprobs_idx,
|
logits_output.next_token_top_logprobs_idx,
|
||||||
) = get_top_logprobs(logprobs, top_logprobs_nums)
|
) = get_top_logprobs(logprobs, top_logprobs_nums)
|
||||||
|
|
||||||
|
if any(x is not None for x in token_ids_logprobs):
|
||||||
|
(
|
||||||
|
logits_output.next_token_token_ids_logprobs_val,
|
||||||
|
logits_output.next_token_token_ids_logprobs_idx,
|
||||||
|
) = get_token_ids_logprobs(logprobs, token_ids_logprobs)
|
||||||
|
|
||||||
logits_output.next_token_logprobs = logprobs[
|
logits_output.next_token_logprobs = logprobs[
|
||||||
torch.arange(len(batch_next_token_ids), device=sampling_info.device),
|
torch.arange(len(batch_next_token_ids), device=sampling_info.device),
|
||||||
batch_next_token_ids,
|
batch_next_token_ids,
|
||||||
@@ -223,6 +250,10 @@ def top_p_normalize_probs_torch(
|
|||||||
|
|
||||||
|
|
||||||
def get_top_logprobs(logprobs: torch.Tensor, top_logprobs_nums: List[int]):
|
def get_top_logprobs(logprobs: torch.Tensor, top_logprobs_nums: List[int]):
|
||||||
|
assert len(top_logprobs_nums) == logprobs.shape[0], (
|
||||||
|
len(top_logprobs_nums),
|
||||||
|
logprobs.shape[0],
|
||||||
|
)
|
||||||
max_k = max(top_logprobs_nums)
|
max_k = max(top_logprobs_nums)
|
||||||
ret = logprobs.topk(max_k, dim=1)
|
ret = logprobs.topk(max_k, dim=1)
|
||||||
values = ret.values.tolist()
|
values = ret.values.tolist()
|
||||||
@@ -234,3 +265,17 @@ def get_top_logprobs(logprobs: torch.Tensor, top_logprobs_nums: List[int]):
|
|||||||
output_top_logprobs_val.append(values[i][:k])
|
output_top_logprobs_val.append(values[i][:k])
|
||||||
output_top_logprobs_idx.append(indices[i][:k])
|
output_top_logprobs_idx.append(indices[i][:k])
|
||||||
return output_top_logprobs_val, output_top_logprobs_idx
|
return output_top_logprobs_val, output_top_logprobs_idx
|
||||||
|
|
||||||
|
|
||||||
|
def get_token_ids_logprobs(logprobs: torch.Tensor, token_ids_logprobs: List[List[int]]):
|
||||||
|
output_token_ids_logprobs_val = []
|
||||||
|
output_token_ids_logprobs_idx = []
|
||||||
|
for i, token_ids in enumerate(token_ids_logprobs):
|
||||||
|
if token_ids is not None:
|
||||||
|
output_token_ids_logprobs_val.append(logprobs[i, token_ids].tolist())
|
||||||
|
output_token_ids_logprobs_idx.append(token_ids)
|
||||||
|
else:
|
||||||
|
output_token_ids_logprobs_val.append([])
|
||||||
|
output_token_ids_logprobs_idx.append([])
|
||||||
|
|
||||||
|
return output_token_ids_logprobs_val, output_token_ids_logprobs_idx
|
||||||
|
|||||||
@@ -457,7 +457,7 @@ class VocabParallelEmbedding(torch.nn.Module):
|
|||||||
assert loaded_weight.shape[output_dim] == (
|
assert loaded_weight.shape[output_dim] == (
|
||||||
self.org_vocab_size
|
self.org_vocab_size
|
||||||
// (self.tp_size if self.use_presharded_weights else 1)
|
// (self.tp_size if self.use_presharded_weights else 1)
|
||||||
)
|
), f"{self.org_vocab_size=} {self.use_presharded_weights=} {loaded_weight.shape[output_dim]=}"
|
||||||
|
|
||||||
# Copy the data.
|
# Copy the data.
|
||||||
if not self.use_presharded_weights:
|
if not self.use_presharded_weights:
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ if __name__ == "__main__":
|
|||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--url", type=str, default="http://localhost:30000")
|
parser.add_argument("--url", type=str, default="http://localhost:30000")
|
||||||
parser.add_argument("--log-requests", action="store_true")
|
parser.add_argument("--log-requests", action="store_true")
|
||||||
|
parser.add_argument("--log-requests-level", type=int, default=2)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--dump-requests-folder", type=str, default="/tmp/sglang_request_dump"
|
"--dump-requests-folder", type=str, default="/tmp/sglang_request_dump"
|
||||||
)
|
)
|
||||||
@@ -38,7 +39,7 @@ if __name__ == "__main__":
|
|||||||
args.url + "/configure_logging",
|
args.url + "/configure_logging",
|
||||||
json={
|
json={
|
||||||
"log_requests": args.log_requests,
|
"log_requests": args.log_requests,
|
||||||
"log_requests_level": 1, # Log full requests
|
"log_requests_level": args.log_requests_level, # Log full requests
|
||||||
"dump_requests_folder": args.dump_requests_folder,
|
"dump_requests_folder": args.dump_requests_folder,
|
||||||
"dump_requests_threshold": args.dump_requests_threshold,
|
"dump_requests_threshold": args.dump_requests_threshold,
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -198,6 +198,8 @@ class DataParallelController:
|
|||||||
self.max_total_num_tokens = scheduler_info[0]["max_total_num_tokens"]
|
self.max_total_num_tokens = scheduler_info[0]["max_total_num_tokens"]
|
||||||
self.max_req_input_len = scheduler_info[0]["max_req_input_len"]
|
self.max_req_input_len = scheduler_info[0]["max_req_input_len"]
|
||||||
|
|
||||||
|
print(f"{scheduler_info=}")
|
||||||
|
|
||||||
def round_robin_scheduler(self, req):
|
def round_robin_scheduler(self, req):
|
||||||
self.workers[self.round_robin_counter].send_pyobj(req)
|
self.workers[self.round_robin_counter].send_pyobj(req)
|
||||||
self.round_robin_counter = (self.round_robin_counter + 1) % len(self.workers)
|
self.round_robin_counter = (self.round_robin_counter + 1) % len(self.workers)
|
||||||
@@ -220,6 +222,7 @@ class DataParallelController:
|
|||||||
TokenizedEmbeddingReqInput,
|
TokenizedEmbeddingReqInput,
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
|
logger.info("dispatching")
|
||||||
self.dispatching(recv_req)
|
self.dispatching(recv_req)
|
||||||
else:
|
else:
|
||||||
# Send other control messages to first worker of tp group
|
# Send other control messages to first worker of tp group
|
||||||
|
|||||||
@@ -14,6 +14,7 @@
|
|||||||
"""DetokenizerManager is a process that detokenizes the token ids."""
|
"""DetokenizerManager is a process that detokenizes the token ids."""
|
||||||
|
|
||||||
import dataclasses
|
import dataclasses
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import signal
|
import signal
|
||||||
@@ -27,11 +28,16 @@ import zmq
|
|||||||
from sglang.srt.hf_transformers_utils import get_tokenizer
|
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||||
from sglang.srt.managers.io_struct import (
|
from sglang.srt.managers.io_struct import (
|
||||||
BatchEmbeddingOut,
|
BatchEmbeddingOut,
|
||||||
|
BatchMultimodalDecodeReq,
|
||||||
BatchStrOut,
|
BatchStrOut,
|
||||||
BatchTokenIDOut,
|
BatchTokenIDOut,
|
||||||
)
|
)
|
||||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||||
from sglang.srt.utils import configure_logger, get_zmq_socket
|
from sglang.srt.utils import (
|
||||||
|
configure_logger,
|
||||||
|
get_zmq_socket,
|
||||||
|
kill_itself_when_parent_died,
|
||||||
|
)
|
||||||
from sglang.utils import (
|
from sglang.utils import (
|
||||||
TypeBasedDispatcher,
|
TypeBasedDispatcher,
|
||||||
find_printable_text,
|
find_printable_text,
|
||||||
@@ -86,14 +92,23 @@ class DetokenizerManager:
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.decode_status = LimitedCapacityDict(capacity=DETOKENIZER_MAX_STATES)
|
self.decode_status = LimitedCapacityDict(capacity=DETOKENIZER_MAX_STATES)
|
||||||
|
self.is_dummy = server_args.load_format == "dummy"
|
||||||
|
|
||||||
self._request_dispatcher = TypeBasedDispatcher(
|
self._request_dispatcher = TypeBasedDispatcher(
|
||||||
[
|
[
|
||||||
(BatchEmbeddingOut, self.handle_batch_embedding_out),
|
(BatchEmbeddingOut, self.handle_batch_embedding_out),
|
||||||
(BatchTokenIDOut, self.handle_batch_token_id_out),
|
(BatchTokenIDOut, self.handle_batch_token_id_out),
|
||||||
|
(BatchMultimodalDecodeReq, self.handle_multimodal_decode_req),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def event_loop(self):
|
||||||
|
"""The event loop that handles requests"""
|
||||||
|
while True:
|
||||||
|
recv_obj = self.recv_from_scheduler.recv_pyobj()
|
||||||
|
output = self._request_dispatcher(recv_obj)
|
||||||
|
self.send_to_tokenizer.send_pyobj(output)
|
||||||
|
|
||||||
def trim_matched_stop(
|
def trim_matched_stop(
|
||||||
self, output: Union[str, List[int]], finished_reason: Dict, no_stop_trim: bool
|
self, output: Union[str, List[int]], finished_reason: Dict, no_stop_trim: bool
|
||||||
):
|
):
|
||||||
@@ -117,14 +132,6 @@ class DetokenizerManager:
|
|||||||
return output[:-1]
|
return output[:-1]
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def event_loop(self):
|
|
||||||
"""The event loop that handles requests"""
|
|
||||||
|
|
||||||
while True:
|
|
||||||
recv_obj = self.recv_from_scheduler.recv_pyobj()
|
|
||||||
output = self._request_dispatcher(recv_obj)
|
|
||||||
self.send_to_tokenizer.send_pyobj(output)
|
|
||||||
|
|
||||||
def handle_batch_embedding_out(self, recv_obj: BatchEmbeddingOut):
|
def handle_batch_embedding_out(self, recv_obj: BatchEmbeddingOut):
|
||||||
# If it is embedding model, no detokenization is needed.
|
# If it is embedding model, no detokenization is needed.
|
||||||
return recv_obj
|
return recv_obj
|
||||||
@@ -173,7 +180,6 @@ class DetokenizerManager:
|
|||||||
|
|
||||||
# Incremental decoding
|
# Incremental decoding
|
||||||
output_strs = []
|
output_strs = []
|
||||||
finished_reqs = []
|
|
||||||
for i in range(bs):
|
for i in range(bs):
|
||||||
try:
|
try:
|
||||||
s = self.decode_status[recv_obj.rids[i]]
|
s = self.decode_status[recv_obj.rids[i]]
|
||||||
@@ -196,8 +202,6 @@ class DetokenizerManager:
|
|||||||
new_text = ""
|
new_text = ""
|
||||||
else:
|
else:
|
||||||
new_text = find_printable_text(new_text)
|
new_text = find_printable_text(new_text)
|
||||||
else:
|
|
||||||
finished_reqs.append(recv_obj.rids[i])
|
|
||||||
|
|
||||||
output_strs.append(
|
output_strs.append(
|
||||||
self.trim_matched_stop(
|
self.trim_matched_stop(
|
||||||
@@ -207,7 +211,7 @@ class DetokenizerManager:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
out = BatchStrOut(
|
return BatchStrOut(
|
||||||
rids=recv_obj.rids,
|
rids=recv_obj.rids,
|
||||||
finished_reasons=recv_obj.finished_reasons,
|
finished_reasons=recv_obj.finished_reasons,
|
||||||
output_strs=output_strs,
|
output_strs=output_strs,
|
||||||
@@ -223,14 +227,15 @@ class DetokenizerManager:
|
|||||||
input_top_logprobs_idx=recv_obj.input_top_logprobs_idx,
|
input_top_logprobs_idx=recv_obj.input_top_logprobs_idx,
|
||||||
output_top_logprobs_val=recv_obj.output_top_logprobs_val,
|
output_top_logprobs_val=recv_obj.output_top_logprobs_val,
|
||||||
output_top_logprobs_idx=recv_obj.output_top_logprobs_idx,
|
output_top_logprobs_idx=recv_obj.output_top_logprobs_idx,
|
||||||
|
input_token_ids_logprobs_val=recv_obj.input_token_ids_logprobs_val,
|
||||||
|
input_token_ids_logprobs_idx=recv_obj.input_token_ids_logprobs_idx,
|
||||||
|
output_token_ids_logprobs_val=recv_obj.output_token_ids_logprobs_val,
|
||||||
|
output_token_ids_logprobs_idx=recv_obj.output_token_ids_logprobs_idx,
|
||||||
output_hidden_states=recv_obj.output_hidden_states,
|
output_hidden_states=recv_obj.output_hidden_states,
|
||||||
)
|
)
|
||||||
|
|
||||||
# remove decodestatus for completed requests
|
def handle_multimodal_decode_req(self, recv_obj: BatchMultimodalDecodeReq):
|
||||||
for rid in finished_reqs:
|
raise NotImplementedError()
|
||||||
self.decode_status.pop(rid)
|
|
||||||
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
class LimitedCapacityDict(OrderedDict):
|
class LimitedCapacityDict(OrderedDict):
|
||||||
@@ -250,6 +255,7 @@ def run_detokenizer_process(
|
|||||||
server_args: ServerArgs,
|
server_args: ServerArgs,
|
||||||
port_args: PortArgs,
|
port_args: PortArgs,
|
||||||
):
|
):
|
||||||
|
kill_itself_when_parent_died()
|
||||||
setproctitle.setproctitle("sglang::detokenizer")
|
setproctitle.setproctitle("sglang::detokenizer")
|
||||||
configure_logger(server_args)
|
configure_logger(server_args)
|
||||||
parent_process = psutil.Process().parent()
|
parent_process = psutil.Process().parent()
|
||||||
|
|||||||
@@ -16,10 +16,11 @@ The definition of objects transfered between different
|
|||||||
processes (TokenizerManager, DetokenizerManager, Controller).
|
processes (TokenizerManager, DetokenizerManager, Controller).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import copy
|
||||||
import uuid
|
import uuid
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Dict, List, Optional, Union
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
from sglang.srt.managers.schedule_batch import BaseFinishReason
|
from sglang.srt.managers.schedule_batch import BaseFinishReason
|
||||||
from sglang.srt.sampling.sampling_params import SamplingParams
|
from sglang.srt.sampling.sampling_params import SamplingParams
|
||||||
@@ -55,6 +56,8 @@ class GenerateReqInput:
|
|||||||
logprob_start_len: Optional[Union[List[int], int]] = None
|
logprob_start_len: Optional[Union[List[int], int]] = None
|
||||||
# If return logprobs, the number of top logprobs to return at each position.
|
# If return logprobs, the number of top logprobs to return at each position.
|
||||||
top_logprobs_num: Optional[Union[List[int], int]] = None
|
top_logprobs_num: Optional[Union[List[int], int]] = None
|
||||||
|
# If return logprobs, the token ids to return logprob for.
|
||||||
|
token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None
|
||||||
# Whether to detokenize tokens in text in the returned logprobs.
|
# Whether to detokenize tokens in text in the returned logprobs.
|
||||||
return_text_in_logprobs: bool = False
|
return_text_in_logprobs: bool = False
|
||||||
# Whether to stream output.
|
# Whether to stream output.
|
||||||
@@ -146,6 +149,8 @@ class GenerateReqInput:
|
|||||||
self.logprob_start_len = -1
|
self.logprob_start_len = -1
|
||||||
if self.top_logprobs_num is None:
|
if self.top_logprobs_num is None:
|
||||||
self.top_logprobs_num = 0
|
self.top_logprobs_num = 0
|
||||||
|
if not self.token_ids_logprob: # covers both None and []
|
||||||
|
self.token_ids_logprob = None
|
||||||
else:
|
else:
|
||||||
if self.parallel_sample_num == 1:
|
if self.parallel_sample_num == 1:
|
||||||
num = self.batch_size
|
num = self.batch_size
|
||||||
@@ -191,6 +196,17 @@ class GenerateReqInput:
|
|||||||
else:
|
else:
|
||||||
assert self.parallel_sample_num == 1
|
assert self.parallel_sample_num == 1
|
||||||
|
|
||||||
|
if not self.token_ids_logprob: # covers both None and []
|
||||||
|
self.token_ids_logprob = [None] * num
|
||||||
|
elif not isinstance(self.token_ids_logprob, list):
|
||||||
|
self.token_ids_logprob = [[self.token_ids_logprob] for _ in range(num)]
|
||||||
|
elif not isinstance(self.token_ids_logprob[0], list):
|
||||||
|
self.token_ids_logprob = [
|
||||||
|
copy.deepcopy(self.token_ids_logprob) for _ in range(num)
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
assert self.parallel_sample_num == 1
|
||||||
|
|
||||||
if self.custom_logit_processor is None:
|
if self.custom_logit_processor is None:
|
||||||
self.custom_logit_processor = [None] * num
|
self.custom_logit_processor = [None] * num
|
||||||
elif not isinstance(self.custom_logit_processor, list):
|
elif not isinstance(self.custom_logit_processor, list):
|
||||||
@@ -198,6 +214,12 @@ class GenerateReqInput:
|
|||||||
else:
|
else:
|
||||||
assert self.parallel_sample_num == 1
|
assert self.parallel_sample_num == 1
|
||||||
|
|
||||||
|
# Other checks
|
||||||
|
if self.session_params is not None:
|
||||||
|
assert isinstance(self.session_params, dict) or isinstance(
|
||||||
|
self.session_params[0], dict
|
||||||
|
)
|
||||||
|
|
||||||
def regenerate_rid(self):
|
def regenerate_rid(self):
|
||||||
self.rid = uuid.uuid4().hex
|
self.rid = uuid.uuid4().hex
|
||||||
return self.rid
|
return self.rid
|
||||||
@@ -212,6 +234,7 @@ class GenerateReqInput:
|
|||||||
return_logprob=self.return_logprob[i],
|
return_logprob=self.return_logprob[i],
|
||||||
logprob_start_len=self.logprob_start_len[i],
|
logprob_start_len=self.logprob_start_len[i],
|
||||||
top_logprobs_num=self.top_logprobs_num[i],
|
top_logprobs_num=self.top_logprobs_num[i],
|
||||||
|
token_ids_logprob=self.token_ids_logprob[i],
|
||||||
return_text_in_logprobs=self.return_text_in_logprobs,
|
return_text_in_logprobs=self.return_text_in_logprobs,
|
||||||
stream=self.stream,
|
stream=self.stream,
|
||||||
log_metrics=self.log_metrics,
|
log_metrics=self.log_metrics,
|
||||||
@@ -244,6 +267,8 @@ class TokenizedGenerateReqInput:
|
|||||||
logprob_start_len: int
|
logprob_start_len: int
|
||||||
# If return logprobs, the number of top logprobs to return at each position.
|
# If return logprobs, the number of top logprobs to return at each position.
|
||||||
top_logprobs_num: int
|
top_logprobs_num: int
|
||||||
|
# If return logprobs, the token id to return logprob for
|
||||||
|
token_ids_logprob: List[int]
|
||||||
# Whether to stream output
|
# Whether to stream output
|
||||||
stream: bool
|
stream: bool
|
||||||
|
|
||||||
@@ -378,10 +403,21 @@ class BatchTokenIDOut:
|
|||||||
input_top_logprobs_idx: List[List]
|
input_top_logprobs_idx: List[List]
|
||||||
output_top_logprobs_val: List[List]
|
output_top_logprobs_val: List[List]
|
||||||
output_top_logprobs_idx: List[List]
|
output_top_logprobs_idx: List[List]
|
||||||
|
input_token_ids_logprobs_val: List[List]
|
||||||
|
input_token_ids_logprobs_idx: List[List]
|
||||||
|
output_token_ids_logprobs_val: List[List]
|
||||||
|
output_token_ids_logprobs_idx: List[List]
|
||||||
|
|
||||||
|
# Hidden states
|
||||||
output_hidden_states: List[List[float]]
|
output_hidden_states: List[List[float]]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class BatchMultimodalDecodeReq:
|
||||||
|
# The request id
|
||||||
|
rids: List[str]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class BatchStrOut:
|
class BatchStrOut:
|
||||||
# The request id
|
# The request id
|
||||||
@@ -406,10 +442,21 @@ class BatchStrOut:
|
|||||||
input_top_logprobs_idx: List[List]
|
input_top_logprobs_idx: List[List]
|
||||||
output_top_logprobs_val: List[List]
|
output_top_logprobs_val: List[List]
|
||||||
output_top_logprobs_idx: List[List]
|
output_top_logprobs_idx: List[List]
|
||||||
|
input_token_ids_logprobs_val: List[List]
|
||||||
|
input_token_ids_logprobs_idx: List[List]
|
||||||
|
output_token_ids_logprobs_val: List[List]
|
||||||
|
output_token_ids_logprobs_idx: List[List]
|
||||||
|
|
||||||
|
# Hidden states
|
||||||
output_hidden_states: List[List[float]]
|
output_hidden_states: List[List[float]]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class BatchMultimodalOut:
|
||||||
|
# The request id
|
||||||
|
rids: List[str]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class BatchEmbeddingOut:
|
class BatchEmbeddingOut:
|
||||||
# The request id
|
# The request id
|
||||||
@@ -439,6 +486,8 @@ class UpdateWeightFromDiskReqInput:
|
|||||||
class UpdateWeightFromDiskReqOutput:
|
class UpdateWeightFromDiskReqOutput:
|
||||||
success: bool
|
success: bool
|
||||||
message: str
|
message: str
|
||||||
|
# Number of paused requests during weight sync.
|
||||||
|
num_paused_requests: Optional[int] = 0
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -526,11 +575,57 @@ class AbortReq:
|
|||||||
rid: str
|
rid: str
|
||||||
|
|
||||||
|
|
||||||
class ProfileReq(Enum):
|
@dataclass
|
||||||
|
class GetInternalStateReq:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class GetInternalStateReqOutput:
|
||||||
|
internal_state: Dict[Any, Any]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SetInternalStateReq:
|
||||||
|
server_args: Dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SetInternalStateReqOutput:
|
||||||
|
updated: bool
|
||||||
|
server_args: Dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ProfileReqInput:
|
||||||
|
# The output directory
|
||||||
|
output_dir: Optional[str] = None
|
||||||
|
# If set, it profile as many as this number of steps.
|
||||||
|
# If it is set, profiling is automatically stopped after this step, and
|
||||||
|
# the caller doesn't need to run stop_profile.
|
||||||
|
num_steps: Optional[int] = None
|
||||||
|
activities: Optional[List[str]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class ProfileReqType(Enum):
|
||||||
START_PROFILE = 1
|
START_PROFILE = 1
|
||||||
STOP_PROFILE = 2
|
STOP_PROFILE = 2
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ProfileReq:
|
||||||
|
type: ProfileReqType
|
||||||
|
output_dir: Optional[str] = None
|
||||||
|
num_steps: Optional[int] = None
|
||||||
|
activities: Optional[List[str]] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ProfileReqOutput:
|
||||||
|
success: bool
|
||||||
|
message: str
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ConfigureLoggingReq:
|
class ConfigureLoggingReq:
|
||||||
log_requests: Optional[bool] = None
|
log_requests: Optional[bool] = None
|
||||||
@@ -556,6 +651,11 @@ class OpenSessionReqOutput:
|
|||||||
success: bool
|
success: bool
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class HealthCheckOutput:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Function:
|
class Function:
|
||||||
description: Optional[str] = None
|
description: Optional[str] = None
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
|
|||||||
It contains low-level tensor data. Most of the data consists of GPU tensors.
|
It contains low-level tensor data. Most of the data consists of GPU tensors.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import copy
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union
|
from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union
|
||||||
@@ -50,7 +51,10 @@ from sglang.srt.sampling.sampling_params import SamplingParams
|
|||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.speculative.spec_info import SpecInfo, SpeculativeAlgorithm
|
from sglang.srt.server_args import ServerArgs
|
||||||
|
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
||||||
|
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
||||||
|
|
||||||
|
|
||||||
INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
|
INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
|
||||||
|
|
||||||
@@ -65,6 +69,8 @@ global_server_args_dict = {
|
|||||||
"enable_dp_attention": ServerArgs.enable_dp_attention,
|
"enable_dp_attention": ServerArgs.enable_dp_attention,
|
||||||
"enable_ep_moe": ServerArgs.enable_ep_moe,
|
"enable_ep_moe": ServerArgs.enable_ep_moe,
|
||||||
"device": ServerArgs.device,
|
"device": ServerArgs.device,
|
||||||
|
"speculative_accept_threshold_single": ServerArgs.speculative_accept_threshold_single,
|
||||||
|
"speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc,
|
||||||
"enable_flashinfer_mla": ServerArgs.enable_flashinfer_mla,
|
"enable_flashinfer_mla": ServerArgs.enable_flashinfer_mla,
|
||||||
"disable_radix_cache": ServerArgs.disable_radix_cache,
|
"disable_radix_cache": ServerArgs.disable_radix_cache,
|
||||||
"flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
|
"flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
|
||||||
@@ -230,6 +236,7 @@ class Req:
|
|||||||
sampling_params: SamplingParams,
|
sampling_params: SamplingParams,
|
||||||
return_logprob: bool = False,
|
return_logprob: bool = False,
|
||||||
top_logprobs_num: int = 0,
|
top_logprobs_num: int = 0,
|
||||||
|
token_ids_logprob: List[int] = None,
|
||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
origin_input_ids_unpadded: Optional[Tuple[int]] = None,
|
origin_input_ids_unpadded: Optional[Tuple[int]] = None,
|
||||||
lora_path: Optional[str] = None,
|
lora_path: Optional[str] = None,
|
||||||
@@ -256,17 +263,24 @@ class Req:
|
|||||||
self.input_embeds = input_embeds
|
self.input_embeds = input_embeds
|
||||||
|
|
||||||
# Sampling info
|
# Sampling info
|
||||||
|
if isinstance(sampling_params.custom_params, dict):
|
||||||
|
sampling_params = copy.copy(sampling_params)
|
||||||
|
sampling_params.custom_params = sampling_params.custom_params | {
|
||||||
|
"__req__": self
|
||||||
|
}
|
||||||
self.sampling_params = sampling_params
|
self.sampling_params = sampling_params
|
||||||
|
|
||||||
self.custom_logit_processor = custom_logit_processor
|
self.custom_logit_processor = custom_logit_processor
|
||||||
self.return_hidden_states = return_hidden_states
|
self.return_hidden_states = return_hidden_states
|
||||||
|
|
||||||
# Memory pool info
|
# Memory pool info
|
||||||
self.req_pool_idx = None
|
self.req_pool_idx: Optional[int] = None
|
||||||
|
|
||||||
# Check finish
|
# Check finish
|
||||||
self.tokenizer = None
|
self.tokenizer = None
|
||||||
self.finished_reason = None
|
self.finished_reason = None
|
||||||
|
# If we want to abort the request in the middle of the event loop, set this to true
|
||||||
|
# Note: We should never set finished_reason in the middle, the req will get filtered and never respond
|
||||||
self.to_abort = False
|
self.to_abort = False
|
||||||
self.stream = stream
|
self.stream = stream
|
||||||
self.eos_token_ids = eos_token_ids
|
self.eos_token_ids = eos_token_ids
|
||||||
@@ -289,38 +303,56 @@ class Req:
|
|||||||
self.image_inputs: Optional[ImageInputs] = None
|
self.image_inputs: Optional[ImageInputs] = None
|
||||||
|
|
||||||
# Prefix info
|
# Prefix info
|
||||||
|
# The indices to kv cache for the shared prefix.
|
||||||
self.prefix_indices = []
|
self.prefix_indices = []
|
||||||
# Tokens to run prefill. input_tokens - shared_prefix_tokens.
|
# Number of tokens to run prefill.
|
||||||
# Updated if chunked.
|
|
||||||
self.extend_input_len = 0
|
self.extend_input_len = 0
|
||||||
|
# The relative logprob_start_len in an extend batch
|
||||||
|
self.extend_logprob_start_len = 0
|
||||||
self.last_node = None
|
self.last_node = None
|
||||||
|
|
||||||
# Chunked prefill
|
# Whether or not if it is chunked. It increments whenever
|
||||||
self.is_being_chunked = 0
|
# it is chunked, and decrement whenever chunked request is
|
||||||
|
# processed.
|
||||||
|
self.is_chunked = 0
|
||||||
|
|
||||||
# For retraction
|
# For retraction
|
||||||
self.is_retracted = False
|
self.is_retracted = False
|
||||||
|
|
||||||
# Logprobs (arguments)
|
# Logprobs (arguments)
|
||||||
self.return_logprob = return_logprob
|
self.return_logprob = return_logprob
|
||||||
|
# Start index to compute logprob from.
|
||||||
self.logprob_start_len = 0
|
self.logprob_start_len = 0
|
||||||
self.top_logprobs_num = top_logprobs_num
|
self.top_logprobs_num = top_logprobs_num
|
||||||
|
self.token_ids_logprob = token_ids_logprob
|
||||||
|
|
||||||
# Logprobs (return values)
|
# Logprobs (return values)
|
||||||
self.input_token_logprobs_val: Optional[List[float]] = None
|
self.input_token_logprobs_val: Optional[List[float]] = None
|
||||||
self.input_token_logprobs_idx: Optional[List[int]] = None
|
self.input_token_logprobs_idx: Optional[List[int]] = None
|
||||||
self.input_top_logprobs_val: Optional[List[float]] = None
|
self.input_top_logprobs_val: Optional[List[float]] = None
|
||||||
self.input_top_logprobs_idx: Optional[List[int]] = None
|
self.input_top_logprobs_idx: Optional[List[int]] = None
|
||||||
|
self.input_token_ids_logprobs_val: Optional[List[float]] = None
|
||||||
|
self.input_token_ids_logprobs_idx: Optional[List[int]] = None
|
||||||
|
# Temporary holder to store input_token_logprobs.
|
||||||
|
self.input_token_logprobs: Optional[List[Tuple[int]]] = None
|
||||||
|
self.temp_input_top_logprobs_val: Optional[List[torch.Tensor]] = None
|
||||||
|
self.temp_input_top_logprobs_idx: Optional[List[int]] = None
|
||||||
|
self.temp_input_token_ids_logprobs_val: Optional[List[float]] = None
|
||||||
|
self.temp_input_token_ids_logprobs_idx: Optional[List[int]] = None
|
||||||
|
|
||||||
if return_logprob:
|
if return_logprob:
|
||||||
self.output_token_logprobs_val = []
|
self.output_token_logprobs_val = []
|
||||||
self.output_token_logprobs_idx = []
|
self.output_token_logprobs_idx = []
|
||||||
self.output_top_logprobs_val = []
|
self.output_top_logprobs_val = []
|
||||||
self.output_top_logprobs_idx = []
|
self.output_top_logprobs_idx = []
|
||||||
|
self.output_token_ids_logprobs_val = []
|
||||||
|
self.output_token_ids_logprobs_idx = []
|
||||||
else:
|
else:
|
||||||
self.output_token_logprobs_val = self.output_token_logprobs_idx = (
|
self.output_token_logprobs_val = self.output_token_logprobs_idx = (
|
||||||
self.output_top_logprobs_val
|
self.output_top_logprobs_val
|
||||||
) = self.output_top_logprobs_idx = None
|
) = self.output_top_logprobs_idx = self.output_token_ids_logprobs_val = (
|
||||||
|
self.output_token_ids_logprobs_idx
|
||||||
|
) = None
|
||||||
self.hidden_states = []
|
self.hidden_states = []
|
||||||
|
|
||||||
# Logprobs (internal values)
|
# Logprobs (internal values)
|
||||||
@@ -345,6 +377,13 @@ class Req:
|
|||||||
self.spec_verify_ct = 0
|
self.spec_verify_ct = 0
|
||||||
self.lora_path = lora_path
|
self.lora_path = lora_path
|
||||||
|
|
||||||
|
# This carries the error message for `.to_abort` and will be attached to the finished_reason at the end of the event loop
|
||||||
|
self.to_abort_message: str = "Unknown error"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def seqlen(self):
|
||||||
|
return len(self.origin_input_ids) + len(self.output_ids)
|
||||||
|
|
||||||
def extend_image_inputs(self, image_inputs):
|
def extend_image_inputs(self, image_inputs):
|
||||||
if self.image_inputs is None:
|
if self.image_inputs is None:
|
||||||
self.image_inputs = image_inputs
|
self.image_inputs = image_inputs
|
||||||
@@ -422,7 +461,9 @@ class Req:
|
|||||||
return
|
return
|
||||||
|
|
||||||
if self.to_abort:
|
if self.to_abort:
|
||||||
self.finished_reason = FINISH_ABORT()
|
self.finished_reason = FINISH_ABORT(
|
||||||
|
message=self.to_abort_message,
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
if len(self.output_ids) >= self.sampling_params.max_new_tokens:
|
if len(self.output_ids) >= self.sampling_params.max_new_tokens:
|
||||||
@@ -517,6 +558,8 @@ class Req:
|
|||||||
self.output_token_logprobs_idx = self.output_token_logprobs_idx[:k]
|
self.output_token_logprobs_idx = self.output_token_logprobs_idx[:k]
|
||||||
self.output_top_logprobs_val = self.output_top_logprobs_val[:k]
|
self.output_top_logprobs_val = self.output_top_logprobs_val[:k]
|
||||||
self.output_top_logprobs_idx = self.output_top_logprobs_idx[:k]
|
self.output_top_logprobs_idx = self.output_top_logprobs_idx[:k]
|
||||||
|
self.output_token_ids_logprobs_val = self.output_token_ids_logprobs_val[:k]
|
||||||
|
self.output_token_ids_logprobs_idx = self.output_token_ids_logprobs_idx[:k]
|
||||||
self.logprob_start_len = prompt_tokens + k
|
self.logprob_start_len = prompt_tokens + k
|
||||||
self.last_update_decode_tokens = len(self.output_ids) - k
|
self.last_update_decode_tokens = len(self.output_ids) - k
|
||||||
|
|
||||||
@@ -527,16 +570,19 @@ class Req:
|
|||||||
self.last_node = None
|
self.last_node = None
|
||||||
self.extend_input_len = 0
|
self.extend_input_len = 0
|
||||||
self.is_retracted = True
|
self.is_retracted = True
|
||||||
|
self.input_token_logprobs = None
|
||||||
|
self.temp_input_top_logprobs_val = None
|
||||||
|
self.temp_input_top_logprobs_idx = None
|
||||||
|
self.extend_logprob_start_len = 0
|
||||||
|
self.is_chunked = 0
|
||||||
|
self.req_pool_idx = None
|
||||||
|
|
||||||
# For incremental logprobs
|
|
||||||
# TODO: Fix the `logprob_start_len`
|
|
||||||
self.last_update_decode_tokens = 0
|
self.last_update_decode_tokens = 0
|
||||||
self.logprob_start_len = 10**9
|
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return (
|
return (
|
||||||
f"rid(n={self.rid}, "
|
f"Req(rid={self.rid}, "
|
||||||
f"input_ids={self.origin_input_ids}, output_ids={self.output_ids}"
|
f"input_ids={self.origin_input_ids}, output_ids={self.output_ids})"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -576,11 +622,13 @@ class ScheduleBatch:
|
|||||||
|
|
||||||
# For DP attention
|
# For DP attention
|
||||||
global_num_tokens: Optional[List[int]] = None
|
global_num_tokens: Optional[List[int]] = None
|
||||||
|
global_num_tokens_for_logprob: Optional[List[int]] = None
|
||||||
can_run_dp_cuda_graph: bool = False
|
can_run_dp_cuda_graph: bool = False
|
||||||
|
|
||||||
# For processing logprobs
|
# For processing logprobs
|
||||||
return_logprob: bool = False
|
return_logprob: bool = False
|
||||||
top_logprobs_nums: Optional[List[int]] = None
|
top_logprobs_nums: Optional[List[int]] = None
|
||||||
|
token_ids_logprobs: Optional[List[List[int]]] = None
|
||||||
|
|
||||||
# For extend and mixed chunekd prefill
|
# For extend and mixed chunekd prefill
|
||||||
prefix_lens: List[int] = None
|
prefix_lens: List[int] = None
|
||||||
@@ -588,6 +636,8 @@ class ScheduleBatch:
|
|||||||
extend_num_tokens: int = None
|
extend_num_tokens: int = None
|
||||||
decoding_reqs: List[Req] = None
|
decoding_reqs: List[Req] = None
|
||||||
extend_logprob_start_lens: List[int] = None
|
extend_logprob_start_lens: List[int] = None
|
||||||
|
# It comes empty list if logprob is not required.
|
||||||
|
extend_input_logprob_token_ids: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
# For encoder-decoder
|
# For encoder-decoder
|
||||||
encoder_cached: Optional[List[bool]] = None
|
encoder_cached: Optional[List[bool]] = None
|
||||||
@@ -606,7 +656,7 @@ class ScheduleBatch:
|
|||||||
|
|
||||||
# Speculative decoding
|
# Speculative decoding
|
||||||
spec_algorithm: SpeculativeAlgorithm = None
|
spec_algorithm: SpeculativeAlgorithm = None
|
||||||
spec_info: Optional[SpecInfo] = None
|
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None
|
||||||
|
|
||||||
# Enable custom logit processor
|
# Enable custom logit processor
|
||||||
enable_custom_logit_processor: bool = False
|
enable_custom_logit_processor: bool = False
|
||||||
@@ -653,8 +703,10 @@ class ScheduleBatch:
|
|||||||
req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
|
req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
|
||||||
if req_pool_indices is None:
|
if req_pool_indices is None:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"Out of memory. "
|
"alloc_req_slots runs out of memory. "
|
||||||
"Please set a smaller number for `--max-running-requests`."
|
"Please set a smaller number for `--max-running-requests`. "
|
||||||
|
f"{self.req_to_token_pool.available_size()=}, "
|
||||||
|
f"{num_reqs=}, "
|
||||||
)
|
)
|
||||||
return req_pool_indices
|
return req_pool_indices
|
||||||
|
|
||||||
@@ -765,6 +817,7 @@ class ScheduleBatch:
|
|||||||
out_cache_loc = self.alloc_token_slots(extend_num_tokens)
|
out_cache_loc = self.alloc_token_slots(extend_num_tokens)
|
||||||
|
|
||||||
input_embeds = []
|
input_embeds = []
|
||||||
|
extend_input_logprob_token_ids = []
|
||||||
|
|
||||||
pt = 0
|
pt = 0
|
||||||
for i, req in enumerate(reqs):
|
for i, req in enumerate(reqs):
|
||||||
@@ -783,22 +836,64 @@ class ScheduleBatch:
|
|||||||
# If req.input_embeds is already a list, append its content directly
|
# If req.input_embeds is already a list, append its content directly
|
||||||
input_embeds.extend(req.input_embeds) # Use extend to avoid nesting
|
input_embeds.extend(req.input_embeds) # Use extend to avoid nesting
|
||||||
|
|
||||||
if req.return_logprob:
|
|
||||||
# Compute the relative logprob_start_len in an extend batch
|
|
||||||
if req.logprob_start_len >= pre_len:
|
|
||||||
extend_logprob_start_len = min(
|
|
||||||
req.logprob_start_len - pre_len, req.extend_input_len - 1
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"This should never happen. {req.logprob_start_len=}, {pre_len=}"
|
|
||||||
)
|
|
||||||
req.extend_logprob_start_len = extend_logprob_start_len
|
|
||||||
|
|
||||||
req.cached_tokens += pre_len - req.already_computed
|
req.cached_tokens += pre_len - req.already_computed
|
||||||
req.already_computed = seq_len
|
req.already_computed = seq_len
|
||||||
req.is_retracted = False
|
req.is_retracted = False
|
||||||
pre_lens.append(pre_len)
|
pre_lens.append(pre_len)
|
||||||
|
# Compute the relative logprob_start_len in an extend batch
|
||||||
|
if req.logprob_start_len >= pre_len:
|
||||||
|
req.extend_logprob_start_len = min(
|
||||||
|
req.logprob_start_len - pre_len,
|
||||||
|
req.extend_input_len,
|
||||||
|
req.seqlen - 1,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
req.extend_logprob_start_len = 0
|
||||||
|
|
||||||
|
if self.return_logprob:
|
||||||
|
# Find input logprob token ids.
|
||||||
|
# First, find a global index within origin_input_ids and slide it by 1
|
||||||
|
# to compute input logprobs. It is because you need the next token
|
||||||
|
# to compute input logprobs. E.g., (chunk size 2)
|
||||||
|
#
|
||||||
|
# input_logprobs = [1, 2, 3, 4]
|
||||||
|
# fill_ids = [1, 2]
|
||||||
|
# extend_input_logprob_token_id = [2, 3]
|
||||||
|
#
|
||||||
|
# Note that it can also overflow. In this case, we pad it with 0.
|
||||||
|
# input_logprobs = [1, 2, 3, 4]
|
||||||
|
# fill_ids = [3, 4]
|
||||||
|
# extend_input_logprob_token_id = [4, 0]
|
||||||
|
global_start_idx, global_end_idx = (
|
||||||
|
len(req.prefix_indices),
|
||||||
|
len(req.fill_ids),
|
||||||
|
)
|
||||||
|
# Apply logprob_start_len
|
||||||
|
if global_start_idx < req.logprob_start_len:
|
||||||
|
global_start_idx = req.logprob_start_len
|
||||||
|
|
||||||
|
logprob_token_ids = req.origin_input_ids[
|
||||||
|
global_start_idx + 1 : global_end_idx + 1
|
||||||
|
]
|
||||||
|
extend_input_logprob_token_ids.extend(logprob_token_ids)
|
||||||
|
|
||||||
|
# We will need req.extend_input_len - req.extend_logprob_start_len number of
|
||||||
|
# tokens, and logprob_token_ids is for input logprob, so pad the rest of them by 0.
|
||||||
|
extend_input_logprob_token_ids.extend(
|
||||||
|
[0]
|
||||||
|
* (
|
||||||
|
req.extend_input_len
|
||||||
|
- req.extend_logprob_start_len
|
||||||
|
- len(logprob_token_ids)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.return_logprob:
|
||||||
|
extend_input_logprob_token_ids = torch.tensor(
|
||||||
|
extend_input_logprob_token_ids
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
extend_input_logprob_token_ids = None
|
||||||
|
|
||||||
# Set fields
|
# Set fields
|
||||||
self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
|
self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
|
||||||
@@ -821,10 +916,12 @@ class ScheduleBatch:
|
|||||||
self.seq_lens_sum = sum(seq_lens)
|
self.seq_lens_sum = sum(seq_lens)
|
||||||
if self.return_logprob:
|
if self.return_logprob:
|
||||||
self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
|
self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
|
||||||
|
self.token_ids_logprobs = [r.token_ids_logprob for r in reqs]
|
||||||
self.extend_num_tokens = extend_num_tokens
|
self.extend_num_tokens = extend_num_tokens
|
||||||
self.prefix_lens = [len(r.prefix_indices) for r in reqs]
|
self.prefix_lens = [len(r.prefix_indices) for r in reqs]
|
||||||
self.extend_lens = [r.extend_input_len for r in reqs]
|
self.extend_lens = [r.extend_input_len for r in reqs]
|
||||||
self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
|
self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
|
||||||
|
self.extend_input_logprob_token_ids = extend_input_logprob_token_ids
|
||||||
|
|
||||||
# Write to req_to_token_pool
|
# Write to req_to_token_pool
|
||||||
pre_lens = torch.tensor(pre_lens, dtype=torch.int32).to(
|
pre_lens = torch.tensor(pre_lens, dtype=torch.int32).to(
|
||||||
@@ -860,7 +957,6 @@ class ScheduleBatch:
|
|||||||
self.sampling_info = SamplingBatchInfo.from_schedule_batch(
|
self.sampling_info = SamplingBatchInfo.from_schedule_batch(
|
||||||
self,
|
self,
|
||||||
self.model_config.vocab_size,
|
self.model_config.vocab_size,
|
||||||
enable_overlap_schedule=self.enable_overlap,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def mix_with_running(self, running_batch: "ScheduleBatch"):
|
def mix_with_running(self, running_batch: "ScheduleBatch"):
|
||||||
@@ -905,11 +1001,16 @@ class ScheduleBatch:
|
|||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def retract_decode(self):
|
def retract_decode(self, server_args: ServerArgs):
|
||||||
"""Retract the decoding requests when there is not enough memory."""
|
"""Retract the decoding requests when there is not enough memory."""
|
||||||
sorted_indices = [i for i in range(len(self.reqs))]
|
sorted_indices = [i for i in range(len(self.reqs))]
|
||||||
|
|
||||||
# TODO(lsyin): improve retraction policy for radix cache
|
# TODO(lsyin): improve retraction policy for radix cache
|
||||||
|
# For spec decoding, filter_batch API can only filter
|
||||||
|
# requests from the back, so we can only retract from the back.
|
||||||
|
# TODO(sang): Clean up finish path and support better retract
|
||||||
|
# policy.
|
||||||
|
if not server_args.speculative_algorithm:
|
||||||
sorted_indices.sort(
|
sorted_indices.sort(
|
||||||
key=lambda i: (
|
key=lambda i: (
|
||||||
len(self.reqs[i].output_ids),
|
len(self.reqs[i].output_ids),
|
||||||
@@ -918,12 +1019,25 @@ class ScheduleBatch:
|
|||||||
reverse=True,
|
reverse=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def get_required_tokens(num_reqs: int):
|
||||||
|
headroom_for_spec_decode = 0
|
||||||
|
if server_args.speculative_algorithm:
|
||||||
|
headroom_for_spec_decode += (
|
||||||
|
num_reqs
|
||||||
|
* server_args.speculative_eagle_topk
|
||||||
|
* server_args.speculative_num_steps
|
||||||
|
+ num_reqs * server_args.speculative_num_draft_tokens
|
||||||
|
)
|
||||||
|
return (
|
||||||
|
num_reqs * global_config.retract_decode_steps + headroom_for_spec_decode
|
||||||
|
)
|
||||||
|
|
||||||
retracted_reqs = []
|
retracted_reqs = []
|
||||||
seq_lens_cpu = self.seq_lens.cpu().numpy()
|
seq_lens_cpu = self.seq_lens.cpu().numpy()
|
||||||
first_iter = True
|
first_iter = True
|
||||||
while (
|
while (
|
||||||
self.token_to_kv_pool.available_size()
|
self.token_to_kv_pool.available_size()
|
||||||
< len(sorted_indices) * global_config.retract_decode_steps
|
< get_required_tokens(len(sorted_indices))
|
||||||
or first_iter
|
or first_iter
|
||||||
):
|
):
|
||||||
if len(sorted_indices) == 1:
|
if len(sorted_indices) == 1:
|
||||||
@@ -1048,17 +1162,40 @@ class ScheduleBatch:
|
|||||||
self.sampling_info = SamplingBatchInfo.from_schedule_batch(
|
self.sampling_info = SamplingBatchInfo.from_schedule_batch(
|
||||||
self,
|
self,
|
||||||
self.model_config.vocab_size,
|
self.model_config.vocab_size,
|
||||||
enable_overlap_schedule=self.enable_overlap,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def prepare_for_decode(self):
|
def prepare_for_decode(self):
|
||||||
self.forward_mode = ForwardMode.DECODE
|
self.forward_mode = ForwardMode.DECODE
|
||||||
if self.spec_algorithm.is_eagle():
|
if self.spec_algorithm.is_eagle():
|
||||||
|
# if spec decoding is used, the decode batch is prepared inside
|
||||||
|
# `forward_batch_speculative_generation` after running draft models.
|
||||||
return
|
return
|
||||||
|
|
||||||
|
if self.sampling_info.penalizer_orchestrator.is_required:
|
||||||
|
if self.enable_overlap:
|
||||||
|
# TODO: this can be slow, optimize this.
|
||||||
|
delayed_output_ids = torch.tensor(
|
||||||
|
[
|
||||||
|
(
|
||||||
|
req.output_ids[-1]
|
||||||
|
if len(req.output_ids)
|
||||||
|
else req.origin_input_ids[-1]
|
||||||
|
)
|
||||||
|
for req in self.reqs
|
||||||
|
],
|
||||||
|
dtype=torch.int64,
|
||||||
|
device=self.device,
|
||||||
|
)
|
||||||
|
self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
|
||||||
|
delayed_output_ids
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
|
||||||
|
self.output_ids.to(torch.int64)
|
||||||
|
)
|
||||||
|
|
||||||
self.input_ids = self.output_ids
|
self.input_ids = self.output_ids
|
||||||
self.output_ids = None
|
self.output_ids = None
|
||||||
self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(self.input_ids)
|
|
||||||
|
|
||||||
# Alloc mem
|
# Alloc mem
|
||||||
bs = len(self.reqs)
|
bs = len(self.reqs)
|
||||||
@@ -1086,14 +1223,15 @@ class ScheduleBatch:
|
|||||||
|
|
||||||
def filter_batch(
|
def filter_batch(
|
||||||
self,
|
self,
|
||||||
being_chunked_req: Optional[Req] = None,
|
chunked_req_to_exclude: Optional[Req] = None,
|
||||||
keep_indices: Optional[List[int]] = None,
|
keep_indices: Optional[List[int]] = None,
|
||||||
):
|
):
|
||||||
if keep_indices is None:
|
if keep_indices is None:
|
||||||
keep_indices = [
|
keep_indices = [
|
||||||
i
|
i
|
||||||
for i in range(len(self.reqs))
|
for i in range(len(self.reqs))
|
||||||
if not self.reqs[i].finished() and self.reqs[i] is not being_chunked_req
|
if not self.reqs[i].finished()
|
||||||
|
and self.reqs[i] is not chunked_req_to_exclude
|
||||||
]
|
]
|
||||||
|
|
||||||
if keep_indices is None or len(keep_indices) == 0:
|
if keep_indices is None or len(keep_indices) == 0:
|
||||||
@@ -1105,31 +1243,34 @@ class ScheduleBatch:
|
|||||||
# No need to filter
|
# No need to filter
|
||||||
return
|
return
|
||||||
|
|
||||||
|
keep_indices_device = torch.tensor(keep_indices, dtype=torch.int64).to(
|
||||||
|
self.device, non_blocking=True
|
||||||
|
)
|
||||||
|
|
||||||
if self.model_config.is_encoder_decoder:
|
if self.model_config.is_encoder_decoder:
|
||||||
self.encoder_lens = self.encoder_lens[keep_indices]
|
self.encoder_lens = self.encoder_lens[keep_indices_device]
|
||||||
self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices]
|
self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices]
|
||||||
|
|
||||||
self.reqs = [self.reqs[i] for i in keep_indices]
|
self.reqs = [self.reqs[i] for i in keep_indices]
|
||||||
new_indices = torch.tensor(keep_indices, dtype=torch.int64).to(
|
self.req_pool_indices = self.req_pool_indices[keep_indices_device]
|
||||||
self.device, non_blocking=True
|
self.seq_lens = self.seq_lens[keep_indices_device]
|
||||||
)
|
|
||||||
self.req_pool_indices = self.req_pool_indices[new_indices]
|
|
||||||
self.seq_lens = self.seq_lens[new_indices]
|
|
||||||
self.out_cache_loc = None
|
self.out_cache_loc = None
|
||||||
self.seq_lens_sum = self.seq_lens.sum().item()
|
self.seq_lens_sum = self.seq_lens.sum().item()
|
||||||
self.output_ids = self.output_ids[new_indices]
|
self.output_ids = self.output_ids[keep_indices_device]
|
||||||
self.return_logprob = any(req.return_logprob for req in self.reqs)
|
self.return_logprob = any(req.return_logprob for req in self.reqs)
|
||||||
if self.return_logprob:
|
if self.return_logprob:
|
||||||
self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in keep_indices]
|
self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in keep_indices]
|
||||||
|
self.token_ids_logprobs = [self.token_ids_logprobs[i] for i in keep_indices]
|
||||||
else:
|
else:
|
||||||
self.top_logprobs_nums = None
|
self.top_logprobs_nums = None
|
||||||
|
self.token_ids_logprobs = None
|
||||||
|
|
||||||
self.has_stream = any(req.stream for req in self.reqs)
|
self.has_stream = any(req.stream for req in self.reqs)
|
||||||
self.has_grammar = any(req.grammar for req in self.reqs)
|
self.has_grammar = any(req.grammar for req in self.reqs)
|
||||||
|
|
||||||
self.sampling_info.filter_batch(keep_indices, new_indices)
|
self.sampling_info.filter_batch(keep_indices, keep_indices_device)
|
||||||
if self.spec_info:
|
if self.spec_info:
|
||||||
self.spec_info.filter_batch(new_indices)
|
self.spec_info.filter_batch(keep_indices_device)
|
||||||
|
|
||||||
def merge_batch(self, other: "ScheduleBatch"):
|
def merge_batch(self, other: "ScheduleBatch"):
|
||||||
# Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
|
# Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
|
||||||
@@ -1152,10 +1293,13 @@ class ScheduleBatch:
|
|||||||
self.output_ids = torch.concat([self.output_ids, other.output_ids])
|
self.output_ids = torch.concat([self.output_ids, other.output_ids])
|
||||||
if self.return_logprob and other.return_logprob:
|
if self.return_logprob and other.return_logprob:
|
||||||
self.top_logprobs_nums.extend(other.top_logprobs_nums)
|
self.top_logprobs_nums.extend(other.top_logprobs_nums)
|
||||||
|
self.token_ids_logprobs.extend(other.token_ids_logprobs)
|
||||||
elif self.return_logprob:
|
elif self.return_logprob:
|
||||||
self.top_logprobs_nums.extend([0] * len(other.reqs))
|
self.top_logprobs_nums.extend([0] * len(other.reqs))
|
||||||
|
self.token_ids_logprobs.extend([None] * len(other.reqs))
|
||||||
elif other.return_logprob:
|
elif other.return_logprob:
|
||||||
self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums
|
self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums
|
||||||
|
self.token_ids_logprobs = [None] * len(self.reqs) + other.token_ids_logprobs
|
||||||
self.reqs.extend(other.reqs)
|
self.reqs.extend(other.reqs)
|
||||||
|
|
||||||
self.return_logprob |= other.return_logprob
|
self.return_logprob |= other.return_logprob
|
||||||
@@ -1192,7 +1336,9 @@ class ScheduleBatch:
|
|||||||
seq_lens_sum=self.seq_lens_sum,
|
seq_lens_sum=self.seq_lens_sum,
|
||||||
return_logprob=self.return_logprob,
|
return_logprob=self.return_logprob,
|
||||||
top_logprobs_nums=self.top_logprobs_nums,
|
top_logprobs_nums=self.top_logprobs_nums,
|
||||||
|
token_ids_logprobs=self.token_ids_logprobs,
|
||||||
global_num_tokens=self.global_num_tokens,
|
global_num_tokens=self.global_num_tokens,
|
||||||
|
global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
|
||||||
can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
|
can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
|
||||||
extend_num_tokens=self.extend_num_tokens,
|
extend_num_tokens=self.extend_num_tokens,
|
||||||
extend_seq_lens=extend_seq_lens,
|
extend_seq_lens=extend_seq_lens,
|
||||||
@@ -1219,6 +1365,7 @@ class ScheduleBatch:
|
|||||||
else CaptureHiddenMode.NULL
|
else CaptureHiddenMode.NULL
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
|
extend_input_logprob_token_ids=self.extend_input_logprob_token_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
def copy(self):
|
def copy(self):
|
||||||
@@ -1262,9 +1409,11 @@ class ModelWorkerBatch:
|
|||||||
# For logprob
|
# For logprob
|
||||||
return_logprob: bool
|
return_logprob: bool
|
||||||
top_logprobs_nums: Optional[List[int]]
|
top_logprobs_nums: Optional[List[int]]
|
||||||
|
token_ids_logprobs: Optional[List[List[int]]]
|
||||||
|
|
||||||
# For DP attention
|
# For DP attention
|
||||||
global_num_tokens: Optional[List[int]]
|
global_num_tokens: Optional[List[int]]
|
||||||
|
global_num_tokens_for_logprob: Optional[List[int]]
|
||||||
can_run_dp_cuda_graph: bool
|
can_run_dp_cuda_graph: bool
|
||||||
|
|
||||||
# For extend
|
# For extend
|
||||||
@@ -1272,6 +1421,7 @@ class ModelWorkerBatch:
|
|||||||
extend_seq_lens: Optional[List[int]]
|
extend_seq_lens: Optional[List[int]]
|
||||||
extend_prefix_lens: Optional[List[int]]
|
extend_prefix_lens: Optional[List[int]]
|
||||||
extend_logprob_start_lens: Optional[List[int]]
|
extend_logprob_start_lens: Optional[List[int]]
|
||||||
|
extend_input_logprob_token_ids: Optional[torch.Tensor]
|
||||||
|
|
||||||
# For multimodal
|
# For multimodal
|
||||||
image_inputs: Optional[List[ImageInputs]]
|
image_inputs: Optional[List[ImageInputs]]
|
||||||
@@ -1293,7 +1443,8 @@ class ModelWorkerBatch:
|
|||||||
|
|
||||||
# Speculative decoding
|
# Speculative decoding
|
||||||
spec_algorithm: SpeculativeAlgorithm = None
|
spec_algorithm: SpeculativeAlgorithm = None
|
||||||
spec_info: Optional[SpecInfo] = None
|
spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
|
||||||
|
# If set, the output of the batch contains the hidden states of the run.
|
||||||
capture_hidden_mode: CaptureHiddenMode = None
|
capture_hidden_mode: CaptureHiddenMode = None
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -272,7 +272,7 @@ class PrefillAdder:
|
|||||||
|
|
||||||
self.req_states = None
|
self.req_states = None
|
||||||
self.can_run_list = []
|
self.can_run_list = []
|
||||||
self.new_being_chunked_req = None
|
self.new_chunked_req = None
|
||||||
self.log_hit_tokens = 0
|
self.log_hit_tokens = 0
|
||||||
self.log_input_tokens = 0
|
self.log_input_tokens = 0
|
||||||
|
|
||||||
@@ -327,7 +327,7 @@ class PrefillAdder:
|
|||||||
self.log_hit_tokens += prefix_len
|
self.log_hit_tokens += prefix_len
|
||||||
self.log_input_tokens += extend_input_len
|
self.log_input_tokens += extend_input_len
|
||||||
|
|
||||||
def add_being_chunked_req(self, req: Req):
|
def add_chunked_req(self, req: Req):
|
||||||
truncated = req.extend_input_len > self.rem_chunk_tokens
|
truncated = req.extend_input_len > self.rem_chunk_tokens
|
||||||
req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
|
req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
|
||||||
req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
|
req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
|
||||||
@@ -354,7 +354,7 @@ class PrefillAdder:
|
|||||||
finally:
|
finally:
|
||||||
self.tree_cache.dec_lock_ref(last_node)
|
self.tree_cache.dec_lock_ref(last_node)
|
||||||
|
|
||||||
def add_one_req_ignore_eos(self, req: Req):
|
def add_one_req_ignore_eos(self, req: Req, has_chunked_req: bool):
|
||||||
def add_req_state(r, insert_sort=False):
|
def add_req_state(r, insert_sort=False):
|
||||||
new_token_ratio = (
|
new_token_ratio = (
|
||||||
1.0 if r.sampling_params.ignore_eos else self.new_token_ratio
|
1.0 if r.sampling_params.ignore_eos else self.new_token_ratio
|
||||||
@@ -403,6 +403,7 @@ class PrefillAdder:
|
|||||||
self.rem_chunk_tokens is None
|
self.rem_chunk_tokens is None
|
||||||
or req.extend_input_len <= self.rem_chunk_tokens
|
or req.extend_input_len <= self.rem_chunk_tokens
|
||||||
):
|
):
|
||||||
|
# Non-chunked prefill
|
||||||
self.can_run_list.append(req)
|
self.can_run_list.append(req)
|
||||||
self._prefill_one_req(
|
self._prefill_one_req(
|
||||||
0,
|
0,
|
||||||
@@ -418,14 +419,14 @@ class PrefillAdder:
|
|||||||
req.extend_input_len = trunc_len
|
req.extend_input_len = trunc_len
|
||||||
req.fill_ids = req.fill_ids[:trunc_len]
|
req.fill_ids = req.fill_ids[:trunc_len]
|
||||||
self.can_run_list.append(req)
|
self.can_run_list.append(req)
|
||||||
self.new_being_chunked_req = req
|
self.new_chunked_req = req
|
||||||
self._prefill_one_req(0, trunc_len, 0)
|
self._prefill_one_req(0, trunc_len, 0)
|
||||||
|
|
||||||
return self.budget_state()
|
return self.budget_state()
|
||||||
|
|
||||||
def add_one_req(self, req: Req):
|
def add_one_req(self, req: Req, has_chunked_req: bool):
|
||||||
if req.sampling_params.ignore_eos and self.tree_cache.disable:
|
if req.sampling_params.ignore_eos and self.tree_cache.disable:
|
||||||
return self.add_one_req_ignore_eos(req)
|
return self.add_one_req_ignore_eos(req, has_chunked_req)
|
||||||
|
|
||||||
total_tokens = req.extend_input_len + min(
|
total_tokens = req.extend_input_len + min(
|
||||||
req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION
|
req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION
|
||||||
@@ -443,14 +444,7 @@ class PrefillAdder:
|
|||||||
if total_tokens > self.rem_total_tokens:
|
if total_tokens > self.rem_total_tokens:
|
||||||
return AddReqResult.NO_TOKEN
|
return AddReqResult.NO_TOKEN
|
||||||
|
|
||||||
if (
|
if self.rem_chunk_tokens is None or input_tokens <= self.rem_chunk_tokens:
|
||||||
self.rem_chunk_tokens is None
|
|
||||||
or input_tokens <= self.rem_chunk_tokens
|
|
||||||
or (
|
|
||||||
req.return_logprob
|
|
||||||
and req.logprob_start_len != len(req.origin_input_ids) - 1
|
|
||||||
)
|
|
||||||
):
|
|
||||||
# Non-chunked prefill
|
# Non-chunked prefill
|
||||||
self.can_run_list.append(req)
|
self.can_run_list.append(req)
|
||||||
self.tree_cache.inc_lock_ref(req.last_node)
|
self.tree_cache.inc_lock_ref(req.last_node)
|
||||||
@@ -470,8 +464,9 @@ class PrefillAdder:
|
|||||||
|
|
||||||
req.extend_input_len = trunc_len
|
req.extend_input_len = trunc_len
|
||||||
req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len]
|
req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len]
|
||||||
|
|
||||||
self.can_run_list.append(req)
|
self.can_run_list.append(req)
|
||||||
self.new_being_chunked_req = req
|
self.new_chunked_req = req
|
||||||
self.tree_cache.inc_lock_ref(req.last_node)
|
self.tree_cache.inc_lock_ref(req.last_node)
|
||||||
self._prefill_one_req(prefix_len, trunc_len, 0)
|
self._prefill_one_req(prefix_len, trunc_len, 0)
|
||||||
|
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -35,12 +35,12 @@ class SessionReqNode:
|
|||||||
for req_node in self.childs:
|
for req_node in self.childs:
|
||||||
req_node.clear(req_dict)
|
req_node.clear(req_dict)
|
||||||
|
|
||||||
if self.req.finished_reason == None:
|
if self.req.finished_reason is None:
|
||||||
self.req.to_abort = True
|
self.req.to_abort = True
|
||||||
del req_dict[self.req.rid]
|
del req_dict[self.req.rid]
|
||||||
|
|
||||||
def abort(self):
|
def abort(self):
|
||||||
if self.req.finished_reason == None:
|
if self.req.finished_reason is None:
|
||||||
self.req.to_abort = True
|
self.req.to_abort = True
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
@@ -132,6 +132,10 @@ class Session:
|
|||||||
lora_path=req.lora_path,
|
lora_path=req.lora_path,
|
||||||
session_id=self.session_id,
|
session_id=self.session_id,
|
||||||
custom_logit_processor=req.custom_logit_processor,
|
custom_logit_processor=req.custom_logit_processor,
|
||||||
|
stream=req.stream,
|
||||||
|
return_logprob=req.return_logprob,
|
||||||
|
top_logprobs_num=req.top_logprobs_num,
|
||||||
|
token_ids_logprob=req.token_ids_logprob,
|
||||||
)
|
)
|
||||||
if last_req is not None:
|
if last_req is not None:
|
||||||
new_req.image_inputs = last_req.image_inputs
|
new_req.image_inputs = last_req.image_inputs
|
||||||
|
|||||||
@@ -16,6 +16,7 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import copy
|
import copy
|
||||||
import dataclasses
|
import dataclasses
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
@@ -24,9 +25,21 @@ import sys
|
|||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
|
from collections import deque
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from typing import Any, Awaitable, Dict, Generic, List, Optional, Tuple, TypeVar, Union
|
from typing import (
|
||||||
|
Any,
|
||||||
|
Awaitable,
|
||||||
|
Deque,
|
||||||
|
Dict,
|
||||||
|
Generic,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Tuple,
|
||||||
|
TypeVar,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
import fastapi
|
import fastapi
|
||||||
import uvloop
|
import uvloop
|
||||||
@@ -44,6 +57,7 @@ from sglang.srt.managers.image_processor import (
|
|||||||
from sglang.srt.managers.io_struct import (
|
from sglang.srt.managers.io_struct import (
|
||||||
AbortReq,
|
AbortReq,
|
||||||
BatchEmbeddingOut,
|
BatchEmbeddingOut,
|
||||||
|
BatchMultimodalOut,
|
||||||
BatchStrOut,
|
BatchStrOut,
|
||||||
BatchTokenIDOut,
|
BatchTokenIDOut,
|
||||||
CloseSessionReqInput,
|
CloseSessionReqInput,
|
||||||
@@ -51,18 +65,25 @@ from sglang.srt.managers.io_struct import (
|
|||||||
EmbeddingReqInput,
|
EmbeddingReqInput,
|
||||||
FlushCacheReq,
|
FlushCacheReq,
|
||||||
GenerateReqInput,
|
GenerateReqInput,
|
||||||
|
GetInternalStateReq,
|
||||||
|
GetInternalStateReqOutput,
|
||||||
GetWeightsByNameReqInput,
|
GetWeightsByNameReqInput,
|
||||||
GetWeightsByNameReqOutput,
|
GetWeightsByNameReqOutput,
|
||||||
|
HealthCheckOutput,
|
||||||
InitWeightsUpdateGroupReqInput,
|
InitWeightsUpdateGroupReqInput,
|
||||||
InitWeightsUpdateGroupReqOutput,
|
InitWeightsUpdateGroupReqOutput,
|
||||||
OpenSessionReqInput,
|
OpenSessionReqInput,
|
||||||
OpenSessionReqOutput,
|
OpenSessionReqOutput,
|
||||||
ProfileReq,
|
ProfileReq,
|
||||||
|
ProfileReqOutput,
|
||||||
|
ProfileReqType,
|
||||||
ReleaseMemoryOccupationReqInput,
|
ReleaseMemoryOccupationReqInput,
|
||||||
ReleaseMemoryOccupationReqOutput,
|
ReleaseMemoryOccupationReqOutput,
|
||||||
ResumeMemoryOccupationReqInput,
|
ResumeMemoryOccupationReqInput,
|
||||||
ResumeMemoryOccupationReqOutput,
|
ResumeMemoryOccupationReqOutput,
|
||||||
SessionParams,
|
SessionParams,
|
||||||
|
SetInternalStateReq,
|
||||||
|
SetInternalStateReqOutput,
|
||||||
TokenizedEmbeddingReqInput,
|
TokenizedEmbeddingReqInput,
|
||||||
TokenizedGenerateReqInput,
|
TokenizedGenerateReqInput,
|
||||||
UpdateWeightFromDiskReqInput,
|
UpdateWeightFromDiskReqInput,
|
||||||
@@ -98,7 +119,10 @@ class ReqState:
|
|||||||
|
|
||||||
# For metrics
|
# For metrics
|
||||||
created_time: float
|
created_time: float
|
||||||
first_token_time: Optional[float] = None
|
finished_time: float = 0.0
|
||||||
|
first_token_time: float = 0.0
|
||||||
|
last_time: float = 0.0
|
||||||
|
last_completion_tokens: int = 1
|
||||||
|
|
||||||
# For streaming output
|
# For streaming output
|
||||||
last_output_offset: int = 0
|
last_output_offset: int = 0
|
||||||
@@ -113,11 +137,10 @@ class TokenizerManager:
|
|||||||
port_args: PortArgs,
|
port_args: PortArgs,
|
||||||
):
|
):
|
||||||
# Parse args
|
# Parse args
|
||||||
|
|
||||||
self.server_args = server_args
|
self.server_args = server_args
|
||||||
self.enable_metrics = server_args.enable_metrics
|
self.enable_metrics = server_args.enable_metrics
|
||||||
self.log_requests = server_args.log_requests
|
self.log_requests = server_args.log_requests
|
||||||
self.log_requests_level = 0
|
self.log_requests_level = server_args.log_requests_level
|
||||||
|
|
||||||
# Init inter-process communication
|
# Init inter-process communication
|
||||||
context = zmq.asyncio.Context(2)
|
context = zmq.asyncio.Context(2)
|
||||||
@@ -143,6 +166,7 @@ class TokenizerManager:
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.is_generation = self.model_config.is_generation
|
self.is_generation = self.model_config.is_generation
|
||||||
|
self.is_image_gen = self.model_config.is_image_gen
|
||||||
self.context_len = self.model_config.context_len
|
self.context_len = self.model_config.context_len
|
||||||
self.image_token_id = self.model_config.image_token_id
|
self.image_token_id = self.model_config.image_token_id
|
||||||
|
|
||||||
@@ -178,9 +202,12 @@ class TokenizerManager:
|
|||||||
# Store states
|
# Store states
|
||||||
self.no_create_loop = False
|
self.no_create_loop = False
|
||||||
self.rid_to_state: Dict[str, ReqState] = {}
|
self.rid_to_state: Dict[str, ReqState] = {}
|
||||||
|
self.gracefully_exit = False
|
||||||
|
self.last_receive_tstamp = 0
|
||||||
self.dump_requests_folder = "" # By default do not dump
|
self.dump_requests_folder = "" # By default do not dump
|
||||||
self.dump_requests_threshold = 1000
|
self.dump_requests_threshold = 1000
|
||||||
self.dump_request_list: List[Tuple] = []
|
self.dump_request_list: List[Tuple] = []
|
||||||
|
self.log_request_metadata = self.get_log_request_metadata()
|
||||||
|
|
||||||
# The event to notify the weight sync is finished.
|
# The event to notify the weight sync is finished.
|
||||||
self.model_update_lock = RWLock()
|
self.model_update_lock = RWLock()
|
||||||
@@ -192,8 +219,19 @@ class TokenizerManager:
|
|||||||
# For session info
|
# For session info
|
||||||
self.session_futures = {} # session_id -> asyncio event
|
self.session_futures = {} # session_id -> asyncio event
|
||||||
|
|
||||||
# Others
|
# Set after scheduler is initialized
|
||||||
self.gracefully_exit = False
|
self.max_req_input_len = None
|
||||||
|
|
||||||
|
# Metrics
|
||||||
|
if self.enable_metrics:
|
||||||
|
self.metrics_collector = TokenizerMetricsCollector(
|
||||||
|
labels={
|
||||||
|
"model_name": self.server_args.served_model_name,
|
||||||
|
# TODO: Add lora name/path in the future,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Communicators
|
||||||
self.init_weights_update_group_communicator = _Communicator(
|
self.init_weights_update_group_communicator = _Communicator(
|
||||||
self.send_to_scheduler, server_args.dp_size
|
self.send_to_scheduler, server_args.dp_size
|
||||||
)
|
)
|
||||||
@@ -212,22 +250,26 @@ class TokenizerManager:
|
|||||||
self.resume_memory_occupation_communicator = _Communicator(
|
self.resume_memory_occupation_communicator = _Communicator(
|
||||||
self.send_to_scheduler, server_args.dp_size
|
self.send_to_scheduler, server_args.dp_size
|
||||||
)
|
)
|
||||||
# Set after scheduler is initialized
|
self.start_profile_communicator = _Communicator(
|
||||||
self.max_req_input_len = None
|
self.send_to_scheduler, server_args.dp_size
|
||||||
|
)
|
||||||
# Metrics
|
self.health_check_communitcator = _Communicator(self.send_to_scheduler, 1)
|
||||||
if self.enable_metrics:
|
self.get_internal_state_communicator = _Communicator(
|
||||||
self.metrics_collector = TokenizerMetricsCollector(
|
self.send_to_scheduler, server_args.dp_size
|
||||||
labels={
|
)
|
||||||
"model_name": self.server_args.served_model_name,
|
self.set_internal_state_communicator = _Communicator(
|
||||||
# TODO: Add lora name/path in the future,
|
self.send_to_scheduler, server_args.dp_size
|
||||||
},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self._result_dispatcher = TypeBasedDispatcher(
|
self._result_dispatcher = TypeBasedDispatcher(
|
||||||
[
|
[
|
||||||
(
|
(
|
||||||
(BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut),
|
(
|
||||||
|
BatchStrOut,
|
||||||
|
BatchEmbeddingOut,
|
||||||
|
BatchTokenIDOut,
|
||||||
|
BatchMultimodalOut,
|
||||||
|
),
|
||||||
self._handle_batch_output,
|
self._handle_batch_output,
|
||||||
),
|
),
|
||||||
(OpenSessionReqOutput, self._handle_open_session_req_output),
|
(OpenSessionReqOutput, self._handle_open_session_req_output),
|
||||||
@@ -259,6 +301,19 @@ class TokenizerManager:
|
|||||||
ResumeMemoryOccupationReqOutput,
|
ResumeMemoryOccupationReqOutput,
|
||||||
self.resume_memory_occupation_communicator.handle_recv,
|
self.resume_memory_occupation_communicator.handle_recv,
|
||||||
),
|
),
|
||||||
|
(
|
||||||
|
ProfileReqOutput,
|
||||||
|
self.start_profile_communicator.handle_recv,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
GetInternalStateReqOutput,
|
||||||
|
self.get_internal_state_communicator.handle_recv,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
SetInternalStateReqOutput,
|
||||||
|
self.set_internal_state_communicator.handle_recv,
|
||||||
|
),
|
||||||
|
(HealthCheckOutput, lambda x: None),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -280,9 +335,9 @@ class TokenizerManager:
|
|||||||
obj.normalize_batch_and_arguments()
|
obj.normalize_batch_and_arguments()
|
||||||
|
|
||||||
if self.log_requests:
|
if self.log_requests:
|
||||||
max_length = 2048 if self.log_requests_level == 0 else 1 << 30
|
max_length, skip_names, _ = self.log_request_metadata
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Receive: obj={dataclass_to_string_truncated(obj, max_length)}"
|
f"Receive: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
async with self.model_update_lock.reader_lock:
|
async with self.model_update_lock.reader_lock:
|
||||||
@@ -336,6 +391,7 @@ class TokenizerManager:
|
|||||||
return_logprob = obj.return_logprob
|
return_logprob = obj.return_logprob
|
||||||
logprob_start_len = obj.logprob_start_len
|
logprob_start_len = obj.logprob_start_len
|
||||||
top_logprobs_num = obj.top_logprobs_num
|
top_logprobs_num = obj.top_logprobs_num
|
||||||
|
token_ids_logprob = obj.token_ids_logprob
|
||||||
session_params = (
|
session_params = (
|
||||||
SessionParams(**obj.session_params) if obj.session_params else None
|
SessionParams(**obj.session_params) if obj.session_params else None
|
||||||
)
|
)
|
||||||
@@ -378,6 +434,7 @@ class TokenizerManager:
|
|||||||
return_logprob,
|
return_logprob,
|
||||||
logprob_start_len,
|
logprob_start_len,
|
||||||
top_logprobs_num,
|
top_logprobs_num,
|
||||||
|
token_ids_logprob,
|
||||||
obj.stream,
|
obj.stream,
|
||||||
lora_path=obj.lora_path,
|
lora_path=obj.lora_path,
|
||||||
input_embeds=input_embeds,
|
input_embeds=input_embeds,
|
||||||
@@ -401,8 +458,7 @@ class TokenizerManager:
|
|||||||
tokenized_obj: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput],
|
tokenized_obj: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput],
|
||||||
created_time: Optional[float] = None,
|
created_time: Optional[float] = None,
|
||||||
):
|
):
|
||||||
event = asyncio.Event()
|
state = ReqState([], False, asyncio.Event(), obj, created_time=created_time)
|
||||||
state = ReqState([], False, event, obj, created_time=created_time)
|
|
||||||
self.rid_to_state[obj.rid] = state
|
self.rid_to_state[obj.rid] = state
|
||||||
self.send_to_scheduler.send_pyobj(tokenized_obj)
|
self.send_to_scheduler.send_pyobj(tokenized_obj)
|
||||||
|
|
||||||
@@ -420,7 +476,10 @@ class TokenizerManager:
|
|||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
if request is not None and await request.is_disconnected():
|
if request is not None and await request.is_disconnected():
|
||||||
self.abort_request(obj.rid)
|
self.abort_request(obj.rid)
|
||||||
raise ValueError(f"Abort request {obj.rid}")
|
raise ValueError(
|
||||||
|
"Request is disconnected from the client side. "
|
||||||
|
f"Abort request {obj.rid}"
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
out = state.out_list[-1]
|
out = state.out_list[-1]
|
||||||
@@ -428,8 +487,11 @@ class TokenizerManager:
|
|||||||
state.out_list = []
|
state.out_list = []
|
||||||
if state.finished:
|
if state.finished:
|
||||||
if self.log_requests:
|
if self.log_requests:
|
||||||
max_length = 2048 if self.log_requests_level == 0 else 1 << 30
|
max_length, skip_names, out_skip_names = self.log_request_metadata
|
||||||
msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length)}, out={dataclass_to_string_truncated(out, max_length)}"
|
if self.model_config.is_multimodal_gen:
|
||||||
|
msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}"
|
||||||
|
else:
|
||||||
|
msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}, out={dataclass_to_string_truncated(out, max_length, skip_names=out_skip_names)}"
|
||||||
logger.info(msg)
|
logger.info(msg)
|
||||||
del self.rid_to_state[obj.rid]
|
del self.rid_to_state[obj.rid]
|
||||||
|
|
||||||
@@ -452,7 +514,10 @@ class TokenizerManager:
|
|||||||
else:
|
else:
|
||||||
if request is not None and await request.is_disconnected():
|
if request is not None and await request.is_disconnected():
|
||||||
self.abort_request(obj.rid)
|
self.abort_request(obj.rid)
|
||||||
raise ValueError(f"Abort request {obj.rid}")
|
raise ValueError(
|
||||||
|
"Request is disconnected from the client side. "
|
||||||
|
f"Abort request {obj.rid}"
|
||||||
|
)
|
||||||
|
|
||||||
async def _handle_batch_request(
|
async def _handle_batch_request(
|
||||||
self,
|
self,
|
||||||
@@ -543,12 +608,25 @@ class TokenizerManager:
|
|||||||
req = AbortReq(rid)
|
req = AbortReq(rid)
|
||||||
self.send_to_scheduler.send_pyobj(req)
|
self.send_to_scheduler.send_pyobj(req)
|
||||||
|
|
||||||
def start_profile(self):
|
async def start_profile(
|
||||||
req = ProfileReq.START_PROFILE
|
self,
|
||||||
self.send_to_scheduler.send_pyobj(req)
|
output_dir: Optional[str] = None,
|
||||||
|
num_steps: Optional[int] = None,
|
||||||
|
activities: Optional[List[str]] = None,
|
||||||
|
):
|
||||||
|
req = ProfileReq(
|
||||||
|
type=ProfileReqType.START_PROFILE,
|
||||||
|
output_dir=output_dir,
|
||||||
|
num_steps=num_steps,
|
||||||
|
activities=activities,
|
||||||
|
)
|
||||||
|
result = (await self.start_profile_communicator(req))[0]
|
||||||
|
if not result.success:
|
||||||
|
raise RuntimeError(result.message)
|
||||||
|
return result
|
||||||
|
|
||||||
def stop_profile(self):
|
def stop_profile(self):
|
||||||
req = ProfileReq.STOP_PROFILE
|
req = ProfileReq(type=ProfileReqType.STOP_PROFILE)
|
||||||
self.send_to_scheduler.send_pyobj(req)
|
self.send_to_scheduler.send_pyobj(req)
|
||||||
|
|
||||||
async def update_weights_from_disk(
|
async def update_weights_from_disk(
|
||||||
@@ -581,7 +659,7 @@ class TokenizerManager:
|
|||||||
self.server_args.model_path = obj.model_path
|
self.server_args.model_path = obj.model_path
|
||||||
self.server_args.load_format = obj.load_format
|
self.server_args.load_format = obj.load_format
|
||||||
self.model_path = obj.model_path
|
self.model_path = obj.model_path
|
||||||
return result.success, result.message
|
return result.success, result.message, result.num_paused_requests
|
||||||
else: # self.server_args.dp_size > 1
|
else: # self.server_args.dp_size > 1
|
||||||
self.model_update_tmp = []
|
self.model_update_tmp = []
|
||||||
result = await self.model_update_result
|
result = await self.model_update_result
|
||||||
@@ -593,7 +671,8 @@ class TokenizerManager:
|
|||||||
self.model_path = obj.model_path
|
self.model_path = obj.model_path
|
||||||
all_message = [r.message for r in result]
|
all_message = [r.message for r in result]
|
||||||
all_message = " | ".join(all_message)
|
all_message = " | ".join(all_message)
|
||||||
return all_success, all_message
|
all_paused_requests = [r.num_paused_requests for r in result]
|
||||||
|
return all_success, all_message, all_paused_requests
|
||||||
|
|
||||||
async def init_weights_update_group(
|
async def init_weights_update_group(
|
||||||
self,
|
self,
|
||||||
@@ -688,6 +767,54 @@ class TokenizerManager:
|
|||||||
):
|
):
|
||||||
await self.send_to_scheduler.send_pyobj(obj)
|
await self.send_to_scheduler.send_pyobj(obj)
|
||||||
|
|
||||||
|
async def get_internal_state(self) -> Dict[Any, Any]:
|
||||||
|
req = GetInternalStateReq()
|
||||||
|
res: List[GetInternalStateReqOutput] = (
|
||||||
|
await self.get_internal_state_communicator(req)
|
||||||
|
)
|
||||||
|
return res[0].internal_state
|
||||||
|
|
||||||
|
async def set_internal_state(
|
||||||
|
self, obj: SetInternalStateReq
|
||||||
|
) -> SetInternalStateReqOutput:
|
||||||
|
res: List[SetInternalStateReqOutput] = (
|
||||||
|
await self.set_internal_state_communicator(obj)
|
||||||
|
)
|
||||||
|
return res[0]
|
||||||
|
|
||||||
|
def get_log_request_metadata(self):
|
||||||
|
max_length = None
|
||||||
|
skip_names = None
|
||||||
|
out_skip_names = None
|
||||||
|
if self.log_requests:
|
||||||
|
if self.log_requests_level == 0:
|
||||||
|
max_length = 1 << 30
|
||||||
|
skip_names = set(
|
||||||
|
[
|
||||||
|
"text",
|
||||||
|
"input_ids",
|
||||||
|
"input_embeds",
|
||||||
|
"image_data",
|
||||||
|
"audio_data",
|
||||||
|
"lora_path",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
out_skip_names = set(
|
||||||
|
[
|
||||||
|
"text",
|
||||||
|
"output_ids",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
elif self.log_requests_level == 1:
|
||||||
|
max_length = 2048
|
||||||
|
elif self.log_requests_level == 2:
|
||||||
|
max_length = 1 << 30
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid --log-requests-level: {self.log_requests_level=}"
|
||||||
|
)
|
||||||
|
return max_length, skip_names, out_skip_names
|
||||||
|
|
||||||
def configure_logging(self, obj: ConfigureLoggingReq):
|
def configure_logging(self, obj: ConfigureLoggingReq):
|
||||||
if obj.log_requests is not None:
|
if obj.log_requests is not None:
|
||||||
self.log_requests = obj.log_requests
|
self.log_requests = obj.log_requests
|
||||||
@@ -698,6 +825,7 @@ class TokenizerManager:
|
|||||||
if obj.dump_requests_threshold is not None:
|
if obj.dump_requests_threshold is not None:
|
||||||
self.dump_requests_threshold = obj.dump_requests_threshold
|
self.dump_requests_threshold = obj.dump_requests_threshold
|
||||||
logging.info(f"Config logging: {obj=}")
|
logging.info(f"Config logging: {obj=}")
|
||||||
|
self.log_request_metadata = self.get_log_request_metadata()
|
||||||
|
|
||||||
def create_abort_task(self, obj: GenerateReqInput):
|
def create_abort_task(self, obj: GenerateReqInput):
|
||||||
# Abort the request if the client is disconnected.
|
# Abort the request if the client is disconnected.
|
||||||
@@ -762,15 +890,20 @@ class TokenizerManager:
|
|||||||
while True:
|
while True:
|
||||||
recv_obj = await self.recv_from_detokenizer.recv_pyobj()
|
recv_obj = await self.recv_from_detokenizer.recv_pyobj()
|
||||||
self._result_dispatcher(recv_obj)
|
self._result_dispatcher(recv_obj)
|
||||||
|
self.last_receive_tstamp = time.time()
|
||||||
|
|
||||||
def _handle_batch_output(
|
def _handle_batch_output(
|
||||||
self, recv_obj: Union[BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut]
|
self,
|
||||||
|
recv_obj: Union[
|
||||||
|
BatchStrOut, BatchEmbeddingOut, BatchMultimodalOut, BatchTokenIDOut
|
||||||
|
],
|
||||||
):
|
):
|
||||||
for i, rid in enumerate(recv_obj.rids):
|
for i, rid in enumerate(recv_obj.rids):
|
||||||
state = self.rid_to_state.get(rid, None)
|
state = self.rid_to_state.get(rid, None)
|
||||||
if state is None:
|
if state is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
# Build meta_info and return value
|
||||||
meta_info = {
|
meta_info = {
|
||||||
"id": rid,
|
"id": rid,
|
||||||
"finish_reason": recv_obj.finished_reasons[i],
|
"finish_reason": recv_obj.finished_reasons[i],
|
||||||
@@ -781,14 +914,12 @@ class TokenizerManager:
|
|||||||
self.convert_logprob_style(
|
self.convert_logprob_style(
|
||||||
meta_info,
|
meta_info,
|
||||||
state.obj.top_logprobs_num,
|
state.obj.top_logprobs_num,
|
||||||
|
state.obj.token_ids_logprob,
|
||||||
state.obj.return_text_in_logprobs,
|
state.obj.return_text_in_logprobs,
|
||||||
recv_obj,
|
recv_obj,
|
||||||
i,
|
i,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.server_args.speculative_algorithm:
|
|
||||||
meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[i]
|
|
||||||
|
|
||||||
if not isinstance(recv_obj, BatchEmbeddingOut):
|
if not isinstance(recv_obj, BatchEmbeddingOut):
|
||||||
meta_info.update(
|
meta_info.update(
|
||||||
{
|
{
|
||||||
@@ -806,10 +937,20 @@ class TokenizerManager:
|
|||||||
"meta_info": meta_info,
|
"meta_info": meta_info,
|
||||||
}
|
}
|
||||||
elif isinstance(recv_obj, BatchTokenIDOut):
|
elif isinstance(recv_obj, BatchTokenIDOut):
|
||||||
|
if self.server_args.stream_output and state.obj.stream:
|
||||||
|
output_token_ids = recv_obj.output_ids[i][
|
||||||
|
state.last_output_offset :
|
||||||
|
]
|
||||||
|
state.last_output_offset = len(recv_obj.output_ids[i])
|
||||||
|
else:
|
||||||
|
output_token_ids = recv_obj.output_ids[i]
|
||||||
|
|
||||||
out_dict = {
|
out_dict = {
|
||||||
"token_ids": recv_obj.output_ids[i],
|
"output_ids": output_token_ids,
|
||||||
"meta_info": meta_info,
|
"meta_info": meta_info,
|
||||||
}
|
}
|
||||||
|
elif isinstance(recv_obj, BatchMultimodalOut):
|
||||||
|
raise NotImplementedError()
|
||||||
else:
|
else:
|
||||||
assert isinstance(recv_obj, BatchEmbeddingOut)
|
assert isinstance(recv_obj, BatchEmbeddingOut)
|
||||||
out_dict = {
|
out_dict = {
|
||||||
@@ -817,10 +958,17 @@ class TokenizerManager:
|
|||||||
"meta_info": meta_info,
|
"meta_info": meta_info,
|
||||||
}
|
}
|
||||||
|
|
||||||
state.out_list.append(out_dict)
|
|
||||||
state.finished = recv_obj.finished_reasons[i] is not None
|
state.finished = recv_obj.finished_reasons[i] is not None
|
||||||
|
if state.finished:
|
||||||
|
if self.server_args.speculative_algorithm:
|
||||||
|
meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[i]
|
||||||
|
state.finished_time = time.time()
|
||||||
|
meta_info["e2e_latency"] = state.finished_time - state.created_time
|
||||||
|
|
||||||
|
state.out_list.append(out_dict)
|
||||||
state.event.set()
|
state.event.set()
|
||||||
|
|
||||||
|
# Log metrics and dump
|
||||||
if self.enable_metrics and state.obj.log_metrics:
|
if self.enable_metrics and state.obj.log_metrics:
|
||||||
self.collect_metrics(state, recv_obj, i)
|
self.collect_metrics(state, recv_obj, i)
|
||||||
if self.dump_requests_folder and state.finished and state.obj.log_metrics:
|
if self.dump_requests_folder and state.finished and state.obj.log_metrics:
|
||||||
@@ -830,6 +978,7 @@ class TokenizerManager:
|
|||||||
self,
|
self,
|
||||||
meta_info: dict,
|
meta_info: dict,
|
||||||
top_logprobs_num: int,
|
top_logprobs_num: int,
|
||||||
|
token_ids_logprob: List[int],
|
||||||
return_text_in_logprobs: bool,
|
return_text_in_logprobs: bool,
|
||||||
recv_obj: BatchStrOut,
|
recv_obj: BatchStrOut,
|
||||||
recv_obj_index: int,
|
recv_obj_index: int,
|
||||||
@@ -857,6 +1006,20 @@ class TokenizerManager:
|
|||||||
return_text_in_logprobs,
|
return_text_in_logprobs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if token_ids_logprob is not None:
|
||||||
|
meta_info["input_token_ids_logprobs"] = self.detokenize_top_logprobs_tokens(
|
||||||
|
recv_obj.input_token_ids_logprobs_val[recv_obj_index],
|
||||||
|
recv_obj.input_token_ids_logprobs_idx[recv_obj_index],
|
||||||
|
return_text_in_logprobs,
|
||||||
|
)
|
||||||
|
meta_info["output_token_ids_logprobs"] = (
|
||||||
|
self.detokenize_top_logprobs_tokens(
|
||||||
|
recv_obj.output_token_ids_logprobs_val[recv_obj_index],
|
||||||
|
recv_obj.output_token_ids_logprobs_idx[recv_obj_index],
|
||||||
|
return_text_in_logprobs,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
def detokenize_logprob_tokens(
|
def detokenize_logprob_tokens(
|
||||||
self,
|
self,
|
||||||
token_logprobs_val: List[float],
|
token_logprobs_val: List[float],
|
||||||
@@ -900,33 +1063,29 @@ class TokenizerManager:
|
|||||||
else 0
|
else 0
|
||||||
)
|
)
|
||||||
|
|
||||||
if state.first_token_time is None:
|
if state.first_token_time == 0.0:
|
||||||
state.first_token_time = time.time()
|
state.first_token_time = state.last_time = time.time()
|
||||||
|
state.last_completion_tokens = completion_tokens
|
||||||
self.metrics_collector.observe_time_to_first_token(
|
self.metrics_collector.observe_time_to_first_token(
|
||||||
state.first_token_time - state.created_time
|
state.first_token_time - state.created_time
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if completion_tokens >= 2:
|
num_new_tokens = completion_tokens - state.last_completion_tokens
|
||||||
# Compute time_per_output_token for the streaming case
|
if num_new_tokens:
|
||||||
self.metrics_collector.observe_time_per_output_token(
|
new_time = time.time()
|
||||||
(time.time() - state.first_token_time) / (completion_tokens - 1)
|
interval = new_time - state.last_time
|
||||||
|
self.metrics_collector.observe_inter_token_latency(
|
||||||
|
interval,
|
||||||
|
num_new_tokens,
|
||||||
)
|
)
|
||||||
|
state.last_time = new_time
|
||||||
|
state.last_completion_tokens = completion_tokens
|
||||||
|
|
||||||
if state.finished:
|
if state.finished:
|
||||||
self.metrics_collector.observe_one_finished_request(
|
self.metrics_collector.observe_one_finished_request(
|
||||||
recv_obj.prompt_tokens[i], completion_tokens
|
recv_obj.prompt_tokens[i],
|
||||||
)
|
completion_tokens,
|
||||||
self.metrics_collector.observe_e2e_request_latency(
|
state.finished_time - state.created_time,
|
||||||
time.time() - state.created_time
|
|
||||||
)
|
|
||||||
# Compute time_per_output_token for the non-streaming case
|
|
||||||
if (
|
|
||||||
hasattr(state.obj, "stream")
|
|
||||||
and not state.obj.stream
|
|
||||||
and completion_tokens >= 1
|
|
||||||
):
|
|
||||||
self.metrics_collector.observe_time_per_output_token(
|
|
||||||
(time.time() - state.created_time) / completion_tokens
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def dump_requests(self, state: ReqState, out_dict: dict):
|
def dump_requests(self, state: ReqState, out_dict: dict):
|
||||||
@@ -996,22 +1155,38 @@ T = TypeVar("T")
|
|||||||
|
|
||||||
|
|
||||||
class _Communicator(Generic[T]):
|
class _Communicator(Generic[T]):
|
||||||
|
"""Note: The communicator now only run up to 1 in-flight request at any time."""
|
||||||
|
|
||||||
def __init__(self, sender, fan_out: int):
|
def __init__(self, sender, fan_out: int):
|
||||||
self._sender = sender
|
self._sender = sender
|
||||||
self._fan_out = fan_out
|
self._fan_out = fan_out
|
||||||
self._result_future: Optional[asyncio.Future] = None
|
self._result_event: Optional[asyncio.Event] = None
|
||||||
self._result_values: Optional[List[T]] = None
|
self._result_values: Optional[List[T]] = None
|
||||||
|
self._ready_queue: Deque[asyncio.Future] = deque()
|
||||||
|
|
||||||
async def __call__(self, obj):
|
async def __call__(self, obj):
|
||||||
|
ready_event = asyncio.Event()
|
||||||
|
if self._result_event is not None or len(self._ready_queue) > 0:
|
||||||
|
self._ready_queue.append(ready_event)
|
||||||
|
await ready_event.wait()
|
||||||
|
assert self._result_event is None
|
||||||
|
assert self._result_values is None
|
||||||
|
|
||||||
|
if obj:
|
||||||
self._sender.send_pyobj(obj)
|
self._sender.send_pyobj(obj)
|
||||||
self._result_future = asyncio.Future()
|
|
||||||
|
self._result_event = asyncio.Event()
|
||||||
self._result_values = []
|
self._result_values = []
|
||||||
await self._result_future
|
await self._result_event.wait()
|
||||||
result_values = self._result_values
|
result_values = self._result_values
|
||||||
self._result_future = self._result_values = None
|
self._result_event = self._result_values = None
|
||||||
|
|
||||||
|
if len(self._ready_queue) > 0:
|
||||||
|
self._ready_queue.popleft().set()
|
||||||
|
|
||||||
return result_values
|
return result_values
|
||||||
|
|
||||||
def handle_recv(self, recv_obj: T):
|
def handle_recv(self, recv_obj: T):
|
||||||
self._result_values.append(recv_obj)
|
self._result_values.append(recv_obj)
|
||||||
if len(self._result_values) == self._fan_out:
|
if len(self._result_values) == self._fan_out:
|
||||||
self._result_future.set_result(None)
|
self._result_event.set()
|
||||||
|
|||||||
@@ -15,10 +15,13 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import threading
|
import threading
|
||||||
from typing import Optional
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
from sglang.srt.configs.model_config import ModelConfig
|
from sglang.srt.configs.model_config import ModelConfig
|
||||||
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
|
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
|
||||||
|
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||||
from sglang.srt.managers.io_struct import (
|
from sglang.srt.managers.io_struct import (
|
||||||
GetWeightsByNameReqInput,
|
GetWeightsByNameReqInput,
|
||||||
InitWeightsUpdateGroupReqInput,
|
InitWeightsUpdateGroupReqInput,
|
||||||
@@ -159,7 +162,7 @@ class TpModelWorker:
|
|||||||
model_worker_batch: ModelWorkerBatch,
|
model_worker_batch: ModelWorkerBatch,
|
||||||
launch_done: Optional[threading.Event] = None,
|
launch_done: Optional[threading.Event] = None,
|
||||||
skip_sample: bool = False,
|
skip_sample: bool = False,
|
||||||
):
|
) -> Tuple[LogitsProcessorOutput, Optional[torch.Tensor]]:
|
||||||
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
|
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
|
||||||
logits_output = self.model_runner.forward(forward_batch)
|
logits_output = self.model_runner.forward(forward_batch)
|
||||||
if launch_done:
|
if launch_done:
|
||||||
|
|||||||
@@ -175,7 +175,7 @@ class TpModelWorkerClient:
|
|||||||
logits_output.next_token_logprobs.tolist()
|
logits_output.next_token_logprobs.tolist()
|
||||||
)
|
)
|
||||||
if logits_output.input_token_logprobs is not None:
|
if logits_output.input_token_logprobs is not None:
|
||||||
logits_output.input_token_logprobs = (
|
logits_output.input_token_logprobs = tuple(
|
||||||
logits_output.input_token_logprobs.tolist()
|
logits_output.input_token_logprobs.tolist()
|
||||||
)
|
)
|
||||||
next_token_ids = next_token_ids.tolist()
|
next_token_ids = next_token_ids.tolist()
|
||||||
@@ -188,8 +188,7 @@ class TpModelWorkerClient:
|
|||||||
model_worker_batch.sampling_info = self.cur_sampling_info = dataclasses.replace(
|
model_worker_batch.sampling_info = self.cur_sampling_info = dataclasses.replace(
|
||||||
sampling_info,
|
sampling_info,
|
||||||
sampling_info_done=threading.Event(),
|
sampling_info_done=threading.Event(),
|
||||||
scaling_penalties=sampling_info.scaling_penalties,
|
penalizer_orchestrator=None,
|
||||||
linear_penalties=sampling_info.linear_penalties,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# A cuda stream sync here to avoid the cuda illegal memory access error.
|
# A cuda stream sync here to avoid the cuda illegal memory access error.
|
||||||
|
|||||||
@@ -2,7 +2,9 @@ from __future__ import annotations
|
|||||||
|
|
||||||
"""Cache for chunked prefill, used when RadixCache is disabled."""
|
"""Cache for chunked prefill, used when RadixCache is disabled."""
|
||||||
|
|
||||||
from typing import TYPE_CHECKING, Callable, List, Optional, Tuple
|
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||||
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
|
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
|
||||||
@@ -12,7 +14,7 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
|
|
||||||
class ChunkCacheEntry:
|
class ChunkCacheEntry:
|
||||||
def __init__(self, rid, value):
|
def __init__(self, rid: str, value: torch.Tensor):
|
||||||
self.rid = rid
|
self.rid = rid
|
||||||
self.value = value
|
self.value = value
|
||||||
|
|
||||||
@@ -24,6 +26,7 @@ class ChunkCache(BasePrefixCache):
|
|||||||
self.disable = True
|
self.disable = True
|
||||||
self.req_to_token_pool = req_to_token_pool
|
self.req_to_token_pool = req_to_token_pool
|
||||||
self.token_to_kv_pool = token_to_kv_pool
|
self.token_to_kv_pool = token_to_kv_pool
|
||||||
|
self.entries: Dict[str, ChunkCacheEntry] = {}
|
||||||
|
|
||||||
self.reset()
|
self.reset()
|
||||||
|
|
||||||
@@ -53,11 +56,8 @@ class ChunkCache(BasePrefixCache):
|
|||||||
if req.rid in self.entries:
|
if req.rid in self.entries:
|
||||||
del self.entries[req.rid]
|
del self.entries[req.rid]
|
||||||
|
|
||||||
def cache_unfinished_req(self, req: Req, token_ids: Optional[List[int]] = None):
|
def cache_unfinished_req(self, req: Req):
|
||||||
if token_ids is None:
|
|
||||||
token_id_len = len(req.fill_ids)
|
token_id_len = len(req.fill_ids)
|
||||||
else:
|
|
||||||
token_id_len = len(token_ids)
|
|
||||||
|
|
||||||
kv_indices = self.req_to_token_pool.req_to_token[
|
kv_indices = self.req_to_token_pool.req_to_token[
|
||||||
req.req_pool_idx, :token_id_len
|
req.req_pool_idx, :token_id_len
|
||||||
@@ -86,5 +86,8 @@ class ChunkCache(BasePrefixCache):
|
|||||||
def evictable_size(self):
|
def evictable_size(self):
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
def pretty_print(self):
|
||||||
|
return ""
|
||||||
|
|
||||||
def protected_size(self):
|
def protected_size(self):
|
||||||
return 0
|
return 0
|
||||||
|
|||||||
@@ -13,6 +13,7 @@
|
|||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
"""Utilities for Prometheus Metrics Collection."""
|
"""Utilities for Prometheus Metrics Collection."""
|
||||||
|
|
||||||
|
import time
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Dict, Union
|
from typing import Dict, Union
|
||||||
|
|
||||||
@@ -35,19 +36,20 @@ class SchedulerMetricsCollector:
|
|||||||
from prometheus_client import Gauge
|
from prometheus_client import Gauge
|
||||||
|
|
||||||
self.labels = labels
|
self.labels = labels
|
||||||
|
self.last_log_time = time.time()
|
||||||
|
|
||||||
self.num_running_reqs = Gauge(
|
self.num_running_reqs = Gauge(
|
||||||
name="sglang:num_running_reqs",
|
name="sglang:num_running_reqs",
|
||||||
documentation="The number of running requests.",
|
documentation="The number of running requests.",
|
||||||
labelnames=labels.keys(),
|
labelnames=labels.keys(),
|
||||||
multiprocess_mode="sum",
|
multiprocess_mode="mostrecent",
|
||||||
)
|
)
|
||||||
|
|
||||||
self.num_used_tokens = Gauge(
|
self.num_used_tokens = Gauge(
|
||||||
name="sglang:num_used_tokens",
|
name="sglang:num_used_tokens",
|
||||||
documentation="The number of used tokens.",
|
documentation="The number of used tokens.",
|
||||||
labelnames=labels.keys(),
|
labelnames=labels.keys(),
|
||||||
multiprocess_mode="sum",
|
multiprocess_mode="mostrecent",
|
||||||
)
|
)
|
||||||
|
|
||||||
self.token_usage = Gauge(
|
self.token_usage = Gauge(
|
||||||
@@ -61,14 +63,14 @@ class SchedulerMetricsCollector:
|
|||||||
name="sglang:gen_throughput",
|
name="sglang:gen_throughput",
|
||||||
documentation="The generation throughput (token/s).",
|
documentation="The generation throughput (token/s).",
|
||||||
labelnames=labels.keys(),
|
labelnames=labels.keys(),
|
||||||
multiprocess_mode="sum",
|
multiprocess_mode="mostrecent",
|
||||||
)
|
)
|
||||||
|
|
||||||
self.num_queue_reqs = Gauge(
|
self.num_queue_reqs = Gauge(
|
||||||
name="sglang:num_queue_reqs",
|
name="sglang:num_queue_reqs",
|
||||||
documentation="The number of requests in the waiting queue.",
|
documentation="The number of requests in the waiting queue.",
|
||||||
labelnames=labels.keys(),
|
labelnames=labels.keys(),
|
||||||
multiprocess_mode="sum",
|
multiprocess_mode="mostrecent",
|
||||||
)
|
)
|
||||||
|
|
||||||
self.cache_hit_rate = Gauge(
|
self.cache_hit_rate = Gauge(
|
||||||
@@ -97,6 +99,7 @@ class SchedulerMetricsCollector:
|
|||||||
self._log_gauge(self.num_queue_reqs, stats.num_queue_reqs)
|
self._log_gauge(self.num_queue_reqs, stats.num_queue_reqs)
|
||||||
self._log_gauge(self.cache_hit_rate, stats.cache_hit_rate)
|
self._log_gauge(self.cache_hit_rate, stats.cache_hit_rate)
|
||||||
self._log_gauge(self.spec_accept_length, stats.spec_accept_length)
|
self._log_gauge(self.spec_accept_length, stats.spec_accept_length)
|
||||||
|
self.last_log_time = time.time()
|
||||||
|
|
||||||
|
|
||||||
class TokenizerMetricsCollector:
|
class TokenizerMetricsCollector:
|
||||||
@@ -130,12 +133,15 @@ class TokenizerMetricsCollector:
|
|||||||
labelnames=labels.keys(),
|
labelnames=labels.keys(),
|
||||||
buckets=[
|
buckets=[
|
||||||
0.1,
|
0.1,
|
||||||
0.25,
|
0.3,
|
||||||
0.5,
|
0.5,
|
||||||
0.75,
|
0.7,
|
||||||
|
0.9,
|
||||||
1,
|
1,
|
||||||
2,
|
2,
|
||||||
5,
|
4,
|
||||||
|
6,
|
||||||
|
8,
|
||||||
10,
|
10,
|
||||||
20,
|
20,
|
||||||
40,
|
40,
|
||||||
@@ -151,24 +157,56 @@ class TokenizerMetricsCollector:
|
|||||||
documentation="Histogram of time per output token in seconds.",
|
documentation="Histogram of time per output token in seconds.",
|
||||||
labelnames=labels.keys(),
|
labelnames=labels.keys(),
|
||||||
buckets=[
|
buckets=[
|
||||||
|
0.002,
|
||||||
0.005,
|
0.005,
|
||||||
0.01,
|
0.010,
|
||||||
|
0.020,
|
||||||
|
0.030,
|
||||||
|
0.040,
|
||||||
|
0.050,
|
||||||
|
0.060,
|
||||||
|
0.070,
|
||||||
|
0.080,
|
||||||
|
0.090,
|
||||||
|
0.100,
|
||||||
|
0.150,
|
||||||
|
0.200,
|
||||||
|
0.300,
|
||||||
|
0.400,
|
||||||
|
0.600,
|
||||||
|
0.800,
|
||||||
|
1.000,
|
||||||
|
2.000,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
self.histogram_inter_token_latency_seconds = Histogram(
|
||||||
|
name="sglang:inter_token_latency_seconds",
|
||||||
|
documentation="Histogram of inter-token latency in seconds.",
|
||||||
|
labelnames=labels.keys(),
|
||||||
|
buckets=[
|
||||||
|
0.002,
|
||||||
|
0.004,
|
||||||
|
0.006,
|
||||||
|
0.008,
|
||||||
|
0.010,
|
||||||
0.015,
|
0.015,
|
||||||
0.02,
|
0.020,
|
||||||
0.025,
|
0.025,
|
||||||
0.03,
|
0.030,
|
||||||
0.04,
|
0.035,
|
||||||
0.05,
|
0.040,
|
||||||
|
0.050,
|
||||||
0.075,
|
0.075,
|
||||||
0.1,
|
0.100,
|
||||||
0.15,
|
0.150,
|
||||||
0.2,
|
0.200,
|
||||||
0.3,
|
0.300,
|
||||||
0.4,
|
0.400,
|
||||||
0.5,
|
0.500,
|
||||||
0.75,
|
0.750,
|
||||||
1.0,
|
1.000,
|
||||||
2.5,
|
2.000,
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -178,8 +216,9 @@ class TokenizerMetricsCollector:
|
|||||||
labelnames=labels.keys(),
|
labelnames=labels.keys(),
|
||||||
buckets=[
|
buckets=[
|
||||||
0.1,
|
0.1,
|
||||||
0.25,
|
0.2,
|
||||||
0.5,
|
0.4,
|
||||||
|
0.8,
|
||||||
1,
|
1,
|
||||||
2,
|
2,
|
||||||
5,
|
5,
|
||||||
@@ -188,28 +227,161 @@ class TokenizerMetricsCollector:
|
|||||||
40,
|
40,
|
||||||
60,
|
60,
|
||||||
80,
|
80,
|
||||||
|
100,
|
||||||
|
150,
|
||||||
|
200,
|
||||||
|
250,
|
||||||
|
300,
|
||||||
|
350,
|
||||||
|
500,
|
||||||
|
1000,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
self.histogram_prefill_prealloc_duration = Histogram(
|
||||||
|
name="sglang:prefill_prealloc_duration_seconds",
|
||||||
|
documentation="Histogram of prefill prealloc duration in seconds.",
|
||||||
|
labelnames=labels.keys(),
|
||||||
|
buckets=[
|
||||||
|
0.1,
|
||||||
|
0.3,
|
||||||
|
0.5,
|
||||||
|
0.7,
|
||||||
|
0.9,
|
||||||
|
1,
|
||||||
|
2,
|
||||||
|
4,
|
||||||
|
6,
|
||||||
|
8,
|
||||||
|
10,
|
||||||
|
20,
|
||||||
|
40,
|
||||||
|
60,
|
||||||
|
80,
|
||||||
120,
|
120,
|
||||||
160,
|
160,
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.histogram_prefill_queue_duration = Histogram(
|
||||||
|
name="sglang:prefill_queue_duration_seconds",
|
||||||
|
documentation="Histogram of prefill queue duration in seconds.",
|
||||||
|
labelnames=labels.keys(),
|
||||||
|
buckets=[
|
||||||
|
0.1,
|
||||||
|
0.3,
|
||||||
|
0.5,
|
||||||
|
0.7,
|
||||||
|
0.9,
|
||||||
|
2,
|
||||||
|
4,
|
||||||
|
8,
|
||||||
|
16,
|
||||||
|
64,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
self.histogram_prefill_forward_duration = Histogram(
|
||||||
|
name="sglang:prefill_forward_duration_seconds",
|
||||||
|
documentation="Histogram of prefill forward duration in seconds.",
|
||||||
|
labelnames=labels.keys(),
|
||||||
|
buckets=[
|
||||||
|
0.1,
|
||||||
|
0.3,
|
||||||
|
0.5,
|
||||||
|
0.7,
|
||||||
|
0.9,
|
||||||
|
2,
|
||||||
|
4,
|
||||||
|
8,
|
||||||
|
16,
|
||||||
|
64,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
self.histogram_prefill_transfer_duration = Histogram(
|
||||||
|
name="sglang:prefill_transfer_duration_seconds",
|
||||||
|
documentation="Histogram of prefill transfer duration in seconds.",
|
||||||
|
labelnames=labels.keys(),
|
||||||
|
buckets=[
|
||||||
|
0.050,
|
||||||
|
0.100,
|
||||||
|
0.150,
|
||||||
|
0.200,
|
||||||
|
0.300,
|
||||||
|
0.400,
|
||||||
|
0.500,
|
||||||
|
1.000,
|
||||||
|
2.000,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
self.histogram_decode_prealloc_duration = Histogram(
|
||||||
|
name="sglang:decode_prealloc_duration_seconds",
|
||||||
|
documentation="Histogram of decode prealloc duration in seconds.",
|
||||||
|
labelnames=labels.keys(),
|
||||||
|
buckets=[
|
||||||
|
0.1,
|
||||||
|
0.3,
|
||||||
|
0.5,
|
||||||
|
0.7,
|
||||||
|
0.9,
|
||||||
|
2,
|
||||||
|
4,
|
||||||
|
8,
|
||||||
|
16,
|
||||||
|
64,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
self.histogram_decode_queue_duration = Histogram(
|
||||||
|
name="sglang:decode_queue_duration_seconds",
|
||||||
|
documentation="Histogram of decode queue duration in seconds.",
|
||||||
|
labelnames=labels.keys(),
|
||||||
|
buckets=[
|
||||||
|
0.1,
|
||||||
|
0.3,
|
||||||
|
0.5,
|
||||||
|
0.7,
|
||||||
|
0.9,
|
||||||
|
2,
|
||||||
|
4,
|
||||||
|
8,
|
||||||
|
16,
|
||||||
|
64,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
def _log_histogram(self, histogram, data: Union[int, float]) -> None:
|
def _log_histogram(self, histogram, data: Union[int, float]) -> None:
|
||||||
histogram.labels(**self.labels).observe(data)
|
histogram.labels(**self.labels).observe(data)
|
||||||
|
|
||||||
def _log_counter(self, counter, data: Union[int, float]) -> None:
|
def observe_one_finished_request(
|
||||||
# Convenience function for logging to counter.
|
self,
|
||||||
counter.labels(**self.labels).inc(data)
|
prompt_tokens: int,
|
||||||
|
generation_tokens: int,
|
||||||
def observe_one_finished_request(self, prompt_tokens: int, generation_tokens: int):
|
e2e_latency: float,
|
||||||
|
):
|
||||||
self.prompt_tokens_total.labels(**self.labels).inc(prompt_tokens)
|
self.prompt_tokens_total.labels(**self.labels).inc(prompt_tokens)
|
||||||
self.generation_tokens_total.labels(**self.labels).inc(generation_tokens)
|
self.generation_tokens_total.labels(**self.labels).inc(generation_tokens)
|
||||||
self.num_requests_total.labels(**self.labels).inc(1)
|
self.num_requests_total.labels(**self.labels).inc(1)
|
||||||
|
self._log_histogram(self.histogram_e2e_request_latency, e2e_latency)
|
||||||
|
if generation_tokens >= 1:
|
||||||
|
self.histogram_time_per_output_token.labels(**self.labels).observe(
|
||||||
|
e2e_latency / generation_tokens
|
||||||
|
)
|
||||||
|
|
||||||
def observe_time_to_first_token(self, value: Union[float, int]):
|
def observe_time_to_first_token(self, value: float):
|
||||||
self._log_histogram(self.histogram_time_to_first_token, value)
|
self.histogram_time_to_first_token.labels(**self.labels).observe(value)
|
||||||
|
|
||||||
def observe_time_per_output_token(self, value: Union[float, int]):
|
def observe_inter_token_latency(self, internval: float, num_new_tokens: int):
|
||||||
self._log_histogram(self.histogram_time_per_output_token, value)
|
adjusted_interval = internval / num_new_tokens
|
||||||
|
|
||||||
def observe_e2e_request_latency(self, value: Union[float, int]):
|
# A faster version of the Histogram::observe which observes multiple values at the same time.
|
||||||
self._log_histogram(self.histogram_e2e_request_latency, value)
|
# reference: https://github.com/prometheus/client_python/blob/v0.21.1/prometheus_client/metrics.py#L639
|
||||||
|
his = self.histogram_inter_token_latency_seconds.labels(**self.labels)
|
||||||
|
his._sum.inc(internval)
|
||||||
|
|
||||||
|
for i, bound in enumerate(his._upper_bounds):
|
||||||
|
if adjusted_interval <= bound:
|
||||||
|
his._buckets[i].inc(num_new_tokens)
|
||||||
|
break
|
||||||
|
|||||||
@@ -109,11 +109,15 @@ def set_torch_compile_config():
|
|||||||
def get_batch_sizes_to_capture(model_runner: ModelRunner):
|
def get_batch_sizes_to_capture(model_runner: ModelRunner):
|
||||||
server_args = model_runner.server_args
|
server_args = model_runner.server_args
|
||||||
capture_bs = server_args.cuda_graph_bs
|
capture_bs = server_args.cuda_graph_bs
|
||||||
|
|
||||||
if capture_bs is None:
|
if capture_bs is None:
|
||||||
|
if server_args.speculative_algorithm is None:
|
||||||
if server_args.disable_cuda_graph_padding:
|
if server_args.disable_cuda_graph_padding:
|
||||||
capture_bs = list(range(1, 33)) + [64, 128]
|
capture_bs = list(range(1, 33)) + [64, 96, 128, 160]
|
||||||
else:
|
else:
|
||||||
capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
|
capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
|
||||||
|
else:
|
||||||
|
capture_bs = list(range(1, 33))
|
||||||
|
|
||||||
if is_hip_:
|
if is_hip_:
|
||||||
capture_bs += [i * 8 for i in range(21, 33)]
|
capture_bs += [i * 8 for i in range(21, 33)]
|
||||||
@@ -130,6 +134,7 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
capture_bs = [
|
capture_bs = [
|
||||||
bs
|
bs
|
||||||
for bs in capture_bs
|
for bs in capture_bs
|
||||||
@@ -388,9 +393,6 @@ class CudaGraphRunner:
|
|||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
self.model_runner.tp_group.barrier()
|
self.model_runner.tp_group.barrier()
|
||||||
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
self.model_runner.tp_group.barrier()
|
|
||||||
|
|
||||||
global global_graph_memory_pool
|
global global_graph_memory_pool
|
||||||
with torch.cuda.graph(graph, pool=global_graph_memory_pool, stream=stream):
|
with torch.cuda.graph(graph, pool=global_graph_memory_pool, stream=stream):
|
||||||
out = run_once()
|
out = run_once()
|
||||||
@@ -401,12 +403,11 @@ class CudaGraphRunner:
|
|||||||
global_graph_memory_pool = graph.pool()
|
global_graph_memory_pool = graph.pool()
|
||||||
return graph, out
|
return graph, out
|
||||||
|
|
||||||
def replay(self, forward_batch: ForwardBatch):
|
def recapture_if_needed(self, forward_batch: ForwardBatch):
|
||||||
assert forward_batch.out_cache_loc is not None
|
# If the capture_hidden_mode changes, we need to recapture the graph
|
||||||
hidden_mode_from_spec_info = getattr(
|
hidden_mode_from_spec_info = getattr(
|
||||||
forward_batch.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL
|
forward_batch.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL
|
||||||
)
|
)
|
||||||
# If the capture_hidden_mode changes, we need to recapture the graph
|
|
||||||
if (
|
if (
|
||||||
forward_batch.capture_hidden_mode == CaptureHiddenMode.FULL
|
forward_batch.capture_hidden_mode == CaptureHiddenMode.FULL
|
||||||
and self.capture_hidden_mode != CaptureHiddenMode.FULL
|
and self.capture_hidden_mode != CaptureHiddenMode.FULL
|
||||||
@@ -420,6 +421,9 @@ class CudaGraphRunner:
|
|||||||
self.capture_hidden_mode = hidden_mode_from_spec_info
|
self.capture_hidden_mode = hidden_mode_from_spec_info
|
||||||
self.capture()
|
self.capture()
|
||||||
|
|
||||||
|
def replay(self, forward_batch: ForwardBatch):
|
||||||
|
self.recapture_if_needed(forward_batch)
|
||||||
|
|
||||||
raw_bs = forward_batch.batch_size
|
raw_bs = forward_batch.batch_size
|
||||||
raw_num_token = raw_bs * self.num_tokens_per_bs
|
raw_num_token = raw_bs * self.num_tokens_per_bs
|
||||||
|
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from enum import IntEnum, auto
|
from enum import IntEnum, auto
|
||||||
from typing import TYPE_CHECKING, List, Optional
|
from typing import TYPE_CHECKING, List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import triton
|
import triton
|
||||||
@@ -46,7 +46,8 @@ if TYPE_CHECKING:
|
|||||||
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
|
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
|
||||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||||
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
||||||
from sglang.srt.speculative.spec_info import SpecInfo, SpeculativeAlgorithm
|
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
||||||
|
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
||||||
|
|
||||||
|
|
||||||
class ForwardMode(IntEnum):
|
class ForwardMode(IntEnum):
|
||||||
@@ -112,7 +113,9 @@ class ForwardMode(IntEnum):
|
|||||||
|
|
||||||
class CaptureHiddenMode(IntEnum):
|
class CaptureHiddenMode(IntEnum):
|
||||||
NULL = auto()
|
NULL = auto()
|
||||||
|
# Capture hidden states of all tokens.
|
||||||
FULL = auto()
|
FULL = auto()
|
||||||
|
# Capture a hidden state of the last token.
|
||||||
LAST = auto()
|
LAST = auto()
|
||||||
|
|
||||||
def need_capture(self):
|
def need_capture(self):
|
||||||
@@ -148,6 +151,7 @@ class ForwardBatch:
|
|||||||
# For logprob
|
# For logprob
|
||||||
return_logprob: bool = False
|
return_logprob: bool = False
|
||||||
top_logprobs_nums: Optional[List[int]] = None
|
top_logprobs_nums: Optional[List[int]] = None
|
||||||
|
token_ids_logprobs: Optional[List[List[int]]] = None
|
||||||
|
|
||||||
# Position information
|
# Position information
|
||||||
positions: torch.Tensor = None
|
positions: torch.Tensor = None
|
||||||
@@ -160,6 +164,7 @@ class ForwardBatch:
|
|||||||
extend_prefix_lens_cpu: Optional[List[int]] = None
|
extend_prefix_lens_cpu: Optional[List[int]] = None
|
||||||
extend_seq_lens_cpu: Optional[List[int]] = None
|
extend_seq_lens_cpu: Optional[List[int]] = None
|
||||||
extend_logprob_start_lens_cpu: Optional[List[int]] = None
|
extend_logprob_start_lens_cpu: Optional[List[int]] = None
|
||||||
|
extend_input_logprob_token_ids_gpu: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
# For multimodal
|
# For multimodal
|
||||||
image_inputs: Optional[List[ImageInputs]] = None
|
image_inputs: Optional[List[ImageInputs]] = None
|
||||||
@@ -190,10 +195,13 @@ class ForwardBatch:
|
|||||||
can_run_dp_cuda_graph: bool = False
|
can_run_dp_cuda_graph: bool = False
|
||||||
|
|
||||||
# Speculative decoding
|
# Speculative decoding
|
||||||
spec_info: SpecInfo = None
|
spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
|
||||||
spec_algorithm: SpeculativeAlgorithm = None
|
spec_algorithm: SpeculativeAlgorithm = None
|
||||||
capture_hidden_mode: CaptureHiddenMode = None
|
capture_hidden_mode: CaptureHiddenMode = None
|
||||||
|
|
||||||
|
# For padding
|
||||||
|
padded_static_len: int = -1 # -1 if not padded
|
||||||
|
|
||||||
# For Qwen2-VL
|
# For Qwen2-VL
|
||||||
mrope_positions: torch.Tensor = None
|
mrope_positions: torch.Tensor = None
|
||||||
|
|
||||||
@@ -203,8 +211,13 @@ class ForwardBatch:
|
|||||||
batch: ModelWorkerBatch,
|
batch: ModelWorkerBatch,
|
||||||
model_runner: ModelRunner,
|
model_runner: ModelRunner,
|
||||||
):
|
):
|
||||||
|
|
||||||
device = model_runner.device
|
device = model_runner.device
|
||||||
|
extend_input_logprob_token_ids_gpu = None
|
||||||
|
if batch.extend_input_logprob_token_ids is not None:
|
||||||
|
extend_input_logprob_token_ids_gpu = (
|
||||||
|
batch.extend_input_logprob_token_ids.to(device, non_blocking=True)
|
||||||
|
)
|
||||||
|
|
||||||
ret = cls(
|
ret = cls(
|
||||||
forward_mode=batch.forward_mode,
|
forward_mode=batch.forward_mode,
|
||||||
batch_size=len(batch.seq_lens),
|
batch_size=len(batch.seq_lens),
|
||||||
@@ -220,6 +233,7 @@ class ForwardBatch:
|
|||||||
seq_lens_sum=batch.seq_lens_sum,
|
seq_lens_sum=batch.seq_lens_sum,
|
||||||
return_logprob=batch.return_logprob,
|
return_logprob=batch.return_logprob,
|
||||||
top_logprobs_nums=batch.top_logprobs_nums,
|
top_logprobs_nums=batch.top_logprobs_nums,
|
||||||
|
token_ids_logprobs=batch.token_ids_logprobs,
|
||||||
global_num_tokens=batch.global_num_tokens,
|
global_num_tokens=batch.global_num_tokens,
|
||||||
can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
|
can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
|
||||||
lora_paths=batch.lora_paths,
|
lora_paths=batch.lora_paths,
|
||||||
@@ -231,6 +245,7 @@ class ForwardBatch:
|
|||||||
spec_info=batch.spec_info,
|
spec_info=batch.spec_info,
|
||||||
capture_hidden_mode=batch.capture_hidden_mode,
|
capture_hidden_mode=batch.capture_hidden_mode,
|
||||||
input_embeds=batch.input_embeds,
|
input_embeds=batch.input_embeds,
|
||||||
|
extend_input_logprob_token_ids_gpu=extend_input_logprob_token_ids_gpu,
|
||||||
)
|
)
|
||||||
|
|
||||||
if ret.global_num_tokens is not None:
|
if ret.global_num_tokens is not None:
|
||||||
@@ -341,6 +356,7 @@ class ForwardBatch:
|
|||||||
)
|
)
|
||||||
batch.image_inputs[i].mrope_position_delta = mrope_position_delta
|
batch.image_inputs[i].mrope_position_delta = mrope_position_delta
|
||||||
mrope_positions_list[i] = mrope_positions
|
mrope_positions_list[i] = mrope_positions
|
||||||
|
|
||||||
self.mrope_positions = torch.concat(
|
self.mrope_positions = torch.concat(
|
||||||
[torch.tensor(pos, device=device) for pos in mrope_positions_list],
|
[torch.tensor(pos, device=device) for pos in mrope_positions_list],
|
||||||
axis=1,
|
axis=1,
|
||||||
@@ -379,7 +395,7 @@ def compute_position_kernel(
|
|||||||
extend_seq_lens,
|
extend_seq_lens,
|
||||||
):
|
):
|
||||||
BLOCK_SIZE: tl.constexpr = 512
|
BLOCK_SIZE: tl.constexpr = 512
|
||||||
pid = tl.program_id(0)
|
pid = tl.program_id(0).to(tl.int64)
|
||||||
|
|
||||||
prefix_len = tl.load(extend_prefix_lens + pid)
|
prefix_len = tl.load(extend_prefix_lens + pid)
|
||||||
seq_len = tl.load(extend_seq_lens + pid)
|
seq_len = tl.load(extend_seq_lens + pid)
|
||||||
|
|||||||
@@ -13,9 +13,12 @@
|
|||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
"""ModelRunner runs the forward passes of the models."""
|
"""ModelRunner runs the forward passes of the models."""
|
||||||
|
|
||||||
|
import collections
|
||||||
|
import datetime
|
||||||
import gc
|
import gc
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
import time
|
import time
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import List, Optional, Tuple, Union
|
from typing import List, Optional, Tuple, Union
|
||||||
@@ -58,6 +61,7 @@ from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
|
|||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
from sglang.srt.model_loader import get_model
|
from sglang.srt.model_loader import get_model
|
||||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||||
|
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
||||||
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
||||||
@@ -73,10 +77,15 @@ from sglang.srt.utils import (
|
|||||||
set_cpu_offload_max_bytes,
|
set_cpu_offload_max_bytes,
|
||||||
set_cuda_arch,
|
set_cuda_arch,
|
||||||
)
|
)
|
||||||
|
from sglang.utils import get_exception_traceback
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)
|
||||||
|
UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300
|
||||||
|
|
||||||
|
|
||||||
class ModelRunner:
|
class ModelRunner:
|
||||||
"""ModelRunner runs the forward passes of the models."""
|
"""ModelRunner runs the forward passes of the models."""
|
||||||
|
|
||||||
@@ -180,9 +189,13 @@ class ModelRunner:
|
|||||||
"enable_dp_attention": server_args.enable_dp_attention,
|
"enable_dp_attention": server_args.enable_dp_attention,
|
||||||
"enable_ep_moe": server_args.enable_ep_moe,
|
"enable_ep_moe": server_args.enable_ep_moe,
|
||||||
"device": server_args.device,
|
"device": server_args.device,
|
||||||
|
"speculative_accept_threshold_single": server_args.speculative_accept_threshold_single,
|
||||||
|
"speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
|
||||||
"enable_flashinfer_mla": server_args.enable_flashinfer_mla,
|
"enable_flashinfer_mla": server_args.enable_flashinfer_mla,
|
||||||
"disable_radix_cache": server_args.disable_radix_cache,
|
"disable_radix_cache": server_args.disable_radix_cache,
|
||||||
"flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
|
"flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
|
||||||
|
"debug_tensor_dump_output_folder": server_args.debug_tensor_dump_output_folder,
|
||||||
|
"debug_tensor_dump_inject": server_args.debug_tensor_dump_inject,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -199,6 +212,18 @@ class ModelRunner:
|
|||||||
self.sampler = Sampler()
|
self.sampler = Sampler()
|
||||||
self.load_model()
|
self.load_model()
|
||||||
|
|
||||||
|
# Handle the case where some of models don't finish loading.
|
||||||
|
try:
|
||||||
|
dist.monitored_barrier(
|
||||||
|
group=get_tp_group().cpu_group,
|
||||||
|
timeout=datetime.timedelta(seconds=UNBALANCED_MODEL_LOADING_TIMEOUT_S),
|
||||||
|
wait_all_ranks=True,
|
||||||
|
)
|
||||||
|
except RuntimeError:
|
||||||
|
raise ValueError(
|
||||||
|
f"TP rank {self.tp_rank} could finish the model loading, but there are other ranks that didn't finish loading. It is likely due to unexpected failures (e.g., OOM) or a slow node."
|
||||||
|
) from None
|
||||||
|
|
||||||
# Apply torchao quantization
|
# Apply torchao quantization
|
||||||
torchao_applied = getattr(self.model, "torchao_applied", False)
|
torchao_applied = getattr(self.model, "torchao_applied", False)
|
||||||
# In layered loading, torchao may have been applied
|
# In layered loading, torchao may have been applied
|
||||||
@@ -625,6 +650,9 @@ class ModelRunner:
|
|||||||
4096,
|
4096,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if SGLANG_CI_SMALL_KV_SIZE:
|
||||||
|
self.max_total_num_tokens = int(SGLANG_CI_SMALL_KV_SIZE)
|
||||||
|
|
||||||
if not self.spec_algorithm.is_none():
|
if not self.spec_algorithm.is_none():
|
||||||
if self.is_draft_worker:
|
if self.is_draft_worker:
|
||||||
self.max_total_num_tokens = self.server_args.draft_runner_cache_size
|
self.max_total_num_tokens = self.server_args.draft_runner_cache_size
|
||||||
@@ -655,6 +683,7 @@ class ModelRunner:
|
|||||||
device=self.device,
|
device=self.device,
|
||||||
enable_memory_saver=self.server_args.enable_memory_saver,
|
enable_memory_saver=self.server_args.enable_memory_saver,
|
||||||
)
|
)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
self.model_config.attention_arch == AttentionArch.MLA
|
self.model_config.attention_arch == AttentionArch.MLA
|
||||||
and not self.server_args.disable_mla
|
and not self.server_args.disable_mla
|
||||||
@@ -758,9 +787,13 @@ class ModelRunner:
|
|||||||
return
|
return
|
||||||
|
|
||||||
tic = time.time()
|
tic = time.time()
|
||||||
logger.info("Capture cuda graph begin. This can take up to several minutes.")
|
logger.info(
|
||||||
|
f"Capture cuda graph begin. This can take up to several minutes. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
|
||||||
|
)
|
||||||
self.cuda_graph_runner = CudaGraphRunner(self)
|
self.cuda_graph_runner = CudaGraphRunner(self)
|
||||||
logger.info(f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s")
|
logger.info(
|
||||||
|
f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
|
||||||
|
)
|
||||||
|
|
||||||
def apply_torch_tp(self):
|
def apply_torch_tp(self):
|
||||||
logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
|
logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
|
||||||
@@ -820,11 +853,10 @@ class ModelRunner:
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")
|
raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")
|
||||||
|
|
||||||
def sample(
|
def _preprocess_logits(
|
||||||
self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
|
self, logits_output: LogitsProcessorOutput, sampling_info: SamplingBatchInfo
|
||||||
) -> torch.Tensor:
|
):
|
||||||
# Apply logit bias
|
# Apply logit bias
|
||||||
sampling_info = forward_batch.sampling_info
|
|
||||||
if sampling_info.sampling_info_done:
|
if sampling_info.sampling_info_done:
|
||||||
# Overlap mode: the function update_regex_vocab_mask was executed
|
# Overlap mode: the function update_regex_vocab_mask was executed
|
||||||
# in process_batch_result of the last batch.
|
# in process_batch_result of the last batch.
|
||||||
@@ -833,15 +865,77 @@ class ModelRunner:
|
|||||||
else:
|
else:
|
||||||
# Normal mode: Put CPU-heavy tasks here. They will be overlapped with the forward pass.
|
# Normal mode: Put CPU-heavy tasks here. They will be overlapped with the forward pass.
|
||||||
sampling_info.update_regex_vocab_mask()
|
sampling_info.update_regex_vocab_mask()
|
||||||
sampling_info.update_penalties()
|
|
||||||
sampling_info.apply_logits_bias(logits_output.next_token_logits)
|
sampling_info.apply_logits_bias(logits_output.next_token_logits)
|
||||||
|
|
||||||
|
def update_output_logprobs(
|
||||||
|
self,
|
||||||
|
logits_output: LogitsProcessorOutput,
|
||||||
|
sampling_info: SamplingBatchInfo,
|
||||||
|
top_logprobs_nums: List[int],
|
||||||
|
token_ids_logprobs: List[int],
|
||||||
|
next_token_ids: torch.Tensor,
|
||||||
|
*,
|
||||||
|
num_tokens_per_req: List[int],
|
||||||
|
):
|
||||||
|
"""Update the logits_output's output logprob based on next_token_ids
|
||||||
|
|
||||||
|
Args:
|
||||||
|
logits_output: The logits output from the model forward
|
||||||
|
sampling_info: Sampling info for logprob calculation
|
||||||
|
top_logprobs_nums: Number of logprobs per request.
|
||||||
|
next_token_ids: Next token ids.
|
||||||
|
num_tokens_per_req: The number of tokens per request.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of next_token_ids
|
||||||
|
"""
|
||||||
|
self._preprocess_logits(logits_output, sampling_info)
|
||||||
|
# We should repeat top_logprobs_nums to match num_tokens_per_req.
|
||||||
|
top_logprobs_nums_repeat_interleaved = []
|
||||||
|
token_ids_logprobs_repeat_interleaved = []
|
||||||
|
for num, num_tokens in zip(top_logprobs_nums, num_tokens_per_req):
|
||||||
|
top_logprobs_nums_repeat_interleaved.extend([num] * num_tokens)
|
||||||
|
for token_ids, num_tokens in zip(token_ids_logprobs, num_tokens_per_req):
|
||||||
|
token_ids_logprobs_repeat_interleaved.extend([token_ids] * num_tokens)
|
||||||
|
self.sampler(
|
||||||
|
logits_output,
|
||||||
|
sampling_info,
|
||||||
|
True,
|
||||||
|
top_logprobs_nums_repeat_interleaved,
|
||||||
|
token_ids_logprobs_repeat_interleaved,
|
||||||
|
batch_next_token_ids=next_token_ids,
|
||||||
|
)
|
||||||
|
|
||||||
|
def sample(
|
||||||
|
self,
|
||||||
|
logits_output: LogitsProcessorOutput,
|
||||||
|
forward_batch: ForwardBatch,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Sample and compute logprobs and update logits_output.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
logits_output: The logits output from the model forward
|
||||||
|
forward_batch: The forward batch that generates logits_output
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of next_token_ids
|
||||||
|
"""
|
||||||
|
# For duplex models with multiple output streams.
|
||||||
|
if isinstance(logits_output, tuple):
|
||||||
|
return torch.stack(
|
||||||
|
[self.sample(values, forward_batch) for values in logits_output],
|
||||||
|
axis=-1,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._preprocess_logits(logits_output, forward_batch.sampling_info)
|
||||||
|
|
||||||
# Sample the next tokens
|
# Sample the next tokens
|
||||||
next_token_ids = self.sampler(
|
next_token_ids = self.sampler(
|
||||||
logits_output,
|
logits_output,
|
||||||
sampling_info,
|
forward_batch.sampling_info,
|
||||||
forward_batch.return_logprob,
|
forward_batch.return_logprob,
|
||||||
forward_batch.top_logprobs_nums,
|
forward_batch.top_logprobs_nums,
|
||||||
|
forward_batch.token_ids_logprobs,
|
||||||
)
|
)
|
||||||
return next_token_ids
|
return next_token_ids
|
||||||
|
|
||||||
|
|||||||
@@ -25,10 +25,10 @@ import filelock
|
|||||||
import gguf
|
import gguf
|
||||||
import huggingface_hub.constants
|
import huggingface_hub.constants
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import safetensors.torch
|
||||||
import torch
|
import torch
|
||||||
from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download
|
from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download
|
||||||
from pydantic import BaseModel, ConfigDict, ValidationInfo, model_validator
|
from pydantic import BaseModel, ConfigDict, ValidationInfo, model_validator
|
||||||
from safetensors.torch import load_file, safe_open, save_file
|
|
||||||
from tqdm.auto import tqdm
|
from tqdm.auto import tqdm
|
||||||
|
|
||||||
from sglang.srt.configs.load_config import LoadConfig
|
from sglang.srt.configs.load_config import LoadConfig
|
||||||
@@ -62,7 +62,6 @@ enable_hf_transfer()
|
|||||||
|
|
||||||
|
|
||||||
class DisabledTqdm(tqdm):
|
class DisabledTqdm(tqdm):
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs, disable=True)
|
super().__init__(*args, **kwargs, disable=True)
|
||||||
|
|
||||||
@@ -121,7 +120,7 @@ def convert_bin_to_safetensor_file(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# check if the tensors are the same
|
# check if the tensors are the same
|
||||||
reloaded = load_file(sf_filename)
|
reloaded = safetensors.torch.load_file(sf_filename)
|
||||||
for k in loaded:
|
for k in loaded:
|
||||||
pt_tensor = loaded[k]
|
pt_tensor = loaded[k]
|
||||||
sf_tensor = reloaded[k]
|
sf_tensor = reloaded[k]
|
||||||
@@ -133,7 +132,6 @@ def convert_bin_to_safetensor_file(
|
|||||||
def get_quant_config(
|
def get_quant_config(
|
||||||
model_config: ModelConfig, load_config: LoadConfig
|
model_config: ModelConfig, load_config: LoadConfig
|
||||||
) -> QuantizationConfig:
|
) -> QuantizationConfig:
|
||||||
|
|
||||||
quant_cls = get_quantization_config(model_config.quantization)
|
quant_cls = get_quantization_config(model_config.quantization)
|
||||||
|
|
||||||
# GGUF doesn't have config file
|
# GGUF doesn't have config file
|
||||||
@@ -402,15 +400,34 @@ def np_cache_weights_iterator(
|
|||||||
yield name, torch.from_numpy(param)
|
yield name, torch.from_numpy(param)
|
||||||
|
|
||||||
|
|
||||||
|
def decrypt(fn, key):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
|
def safetensors_encrypted_weights_iterator(
|
||||||
|
hf_weights_files: List[str],
|
||||||
|
is_all_weights_sharded: bool = False,
|
||||||
|
decryption_key: Optional[str] = None,
|
||||||
|
):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
def safetensors_weights_iterator(
|
def safetensors_weights_iterator(
|
||||||
hf_weights_files: List[str],
|
hf_weights_files: List[str],
|
||||||
is_all_weights_sharded: bool = False,
|
is_all_weights_sharded: bool = False,
|
||||||
|
decryption_key: Optional[str] = None,
|
||||||
) -> Generator[Tuple[str, torch.Tensor], None, None]:
|
) -> Generator[Tuple[str, torch.Tensor], None, None]:
|
||||||
"""Iterate over the weights in the model safetensor files.
|
"""Iterate over the weights in the model safetensor files.
|
||||||
|
|
||||||
If is_all_weights_sharded is True, it uses more optimize read by reading an
|
If is_all_weights_sharded is True, it uses more optimize read by reading an
|
||||||
entire file instead of reading each tensor one by one.
|
entire file instead of reading each tensor one by one.
|
||||||
"""
|
"""
|
||||||
|
if decryption_key:
|
||||||
|
yield from safetensors_encrypted_weights_iterator(
|
||||||
|
hf_weights_files, is_all_weights_sharded, decryption_key
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
enable_tqdm = (
|
enable_tqdm = (
|
||||||
not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
|
not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
|
||||||
)
|
)
|
||||||
@@ -420,13 +437,7 @@ def safetensors_weights_iterator(
|
|||||||
disable=not enable_tqdm,
|
disable=not enable_tqdm,
|
||||||
bar_format=_BAR_FORMAT,
|
bar_format=_BAR_FORMAT,
|
||||||
):
|
):
|
||||||
if not is_all_weights_sharded:
|
result = safetensors.torch.load_file(st_file, device="cpu")
|
||||||
with safe_open(st_file, framework="pt") as f:
|
|
||||||
for name in f.keys(): # noqa: SIM118
|
|
||||||
param = f.get_tensor(name)
|
|
||||||
yield name, param
|
|
||||||
else:
|
|
||||||
result = load_file(st_file, device="cpu")
|
|
||||||
for name, param in result.items():
|
for name, param in result.items():
|
||||||
yield name, param
|
yield name, param
|
||||||
|
|
||||||
|
|||||||
@@ -1,13 +1,11 @@
|
|||||||
from .orchestrator import BatchedPenalizerOrchestrator
|
from sglang.srt.sampling.penaltylib.frequency_penalty import BatchedFrequencyPenalizer
|
||||||
from .penalizers.frequency_penalty import BatchedFrequencyPenalizer
|
from sglang.srt.sampling.penaltylib.min_new_tokens import BatchedMinNewTokensPenalizer
|
||||||
from .penalizers.min_new_tokens import BatchedMinNewTokensPenalizer
|
from sglang.srt.sampling.penaltylib.orchestrator import BatchedPenalizerOrchestrator
|
||||||
from .penalizers.presence_penalty import BatchedPresencePenalizer
|
from sglang.srt.sampling.penaltylib.presence_penalty import BatchedPresencePenalizer
|
||||||
from .penalizers.repetition_penalty import BatchedRepetitionPenalizer
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"BatchedFrequencyPenalizer",
|
"BatchedFrequencyPenalizer",
|
||||||
"BatchedMinNewTokensPenalizer",
|
"BatchedMinNewTokensPenalizer",
|
||||||
"BatchedPresencePenalizer",
|
"BatchedPresencePenalizer",
|
||||||
"BatchedRepetitionPenalizer",
|
|
||||||
"BatchedPenalizerOrchestrator",
|
"BatchedPenalizerOrchestrator",
|
||||||
]
|
]
|
||||||
|
|||||||
66
python/sglang/srt/sampling/penaltylib/frequency_penalty.py
Normal file
66
python/sglang/srt/sampling/penaltylib/frequency_penalty.py
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
from sglang.srt.sampling.penaltylib.orchestrator import (
|
||||||
|
BatchedPenalizerOrchestrator,
|
||||||
|
_BatchedPenalizer,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BatchedFrequencyPenalizer(_BatchedPenalizer):
|
||||||
|
"""
|
||||||
|
Frequency penalizer penalizes tokens based on their frequency in the output.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, orchestrator: BatchedPenalizerOrchestrator):
|
||||||
|
self.orchestrator = orchestrator
|
||||||
|
self._is_prepared = False
|
||||||
|
|
||||||
|
def _is_required(self) -> bool:
|
||||||
|
return any(
|
||||||
|
req.sampling_params.frequency_penalty != 0.0
|
||||||
|
for req in self.orchestrator.reqs()
|
||||||
|
)
|
||||||
|
|
||||||
|
def _prepare(self):
|
||||||
|
self.cumulated_frequency_penalties = torch.zeros(
|
||||||
|
(len(self.orchestrator.reqs()), self.orchestrator.vocab_size),
|
||||||
|
dtype=torch.float32,
|
||||||
|
device=self.orchestrator.device,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.frequency_penalties = (
|
||||||
|
torch.tensor(
|
||||||
|
data=[
|
||||||
|
req.sampling_params.frequency_penalty
|
||||||
|
for req in self.orchestrator.reqs()
|
||||||
|
],
|
||||||
|
dtype=torch.float32,
|
||||||
|
device=self.orchestrator.device,
|
||||||
|
)
|
||||||
|
).unsqueeze_(1)
|
||||||
|
|
||||||
|
def _cumulate_output_tokens(self, output_ids: torch.Tensor):
|
||||||
|
self.cumulated_frequency_penalties.scatter_add_(
|
||||||
|
dim=1,
|
||||||
|
index=output_ids.unsqueeze(1),
|
||||||
|
src=self.frequency_penalties,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _apply(self, logits: torch.Tensor) -> torch.Tensor:
|
||||||
|
logits.sub_(self.cumulated_frequency_penalties)
|
||||||
|
|
||||||
|
def _filter(self, keep_indices: torch.Tensor):
|
||||||
|
self.frequency_penalties = self.frequency_penalties[keep_indices]
|
||||||
|
self.cumulated_frequency_penalties = self.cumulated_frequency_penalties[
|
||||||
|
keep_indices
|
||||||
|
]
|
||||||
|
|
||||||
|
def _merge(self, their: "BatchedFrequencyPenalizer"):
|
||||||
|
print(f"{self.frequency_penalties.shape=}, {their.frequency_penalties.shape=}")
|
||||||
|
self.frequency_penalties = torch.cat(
|
||||||
|
[self.frequency_penalties, their.frequency_penalties], dim=0
|
||||||
|
)
|
||||||
|
self.cumulated_frequency_penalties = torch.cat(
|
||||||
|
[self.cumulated_frequency_penalties, their.cumulated_frequency_penalties],
|
||||||
|
dim=0,
|
||||||
|
)
|
||||||
@@ -1,8 +1,9 @@
|
|||||||
from typing import List
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs
|
from sglang.srt.sampling.penaltylib.orchestrator import (
|
||||||
|
BatchedPenalizerOrchestrator,
|
||||||
|
_BatchedPenalizer,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
|
class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
|
||||||
@@ -10,9 +11,9 @@ class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
|
|||||||
Min new tokens penalizer penalizes tokens based on the length of the output.
|
Min new tokens penalizer penalizes tokens based on the length of the output.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
min_new_tokens: torch.Tensor = None
|
def __init__(self, orchestrator: BatchedPenalizerOrchestrator):
|
||||||
stop_token_penalties: torch.Tensor = None
|
self.orchestrator = orchestrator
|
||||||
len_output_tokens: torch.Tensor = None
|
self._is_prepared = False
|
||||||
|
|
||||||
def _is_required(self) -> bool:
|
def _is_required(self) -> bool:
|
||||||
return any(
|
return any(
|
||||||
@@ -47,7 +48,7 @@ class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
|
|||||||
padding_value=self.orchestrator.vocab_size,
|
padding_value=self.orchestrator.vocab_size,
|
||||||
)
|
)
|
||||||
self.stop_token_penalties = torch.zeros(
|
self.stop_token_penalties = torch.zeros(
|
||||||
size=(self.orchestrator.batch_size(), self.orchestrator.vocab_size + 1),
|
size=(len(self.orchestrator.reqs()), self.orchestrator.vocab_size + 1),
|
||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
device=self.orchestrator.device,
|
device=self.orchestrator.device,
|
||||||
).scatter_add_(
|
).scatter_add_(
|
||||||
@@ -64,31 +65,22 @@ class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
|
|||||||
]
|
]
|
||||||
|
|
||||||
self.len_output_tokens = torch.zeros(
|
self.len_output_tokens = torch.zeros(
|
||||||
size=(self.orchestrator.batch_size(), 1),
|
size=(len(self.orchestrator.reqs()), 1),
|
||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
device=self.orchestrator.device,
|
device=self.orchestrator.device,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _teardown(self):
|
def _cumulate_output_tokens(self, output_ids: torch.Tensor):
|
||||||
self.min_new_tokens = None
|
|
||||||
self.stop_token_penalties = None
|
|
||||||
self.len_output_tokens = None
|
|
||||||
|
|
||||||
def _cumulate_input_tokens(self, input_ids: _TokenIDs):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def _cumulate_output_tokens(self, output_ids: _TokenIDs):
|
|
||||||
self.len_output_tokens += 1
|
self.len_output_tokens += 1
|
||||||
|
|
||||||
def _apply(self, logits: torch.Tensor) -> torch.Tensor:
|
def _apply(self, logits: torch.Tensor):
|
||||||
mask = (self.len_output_tokens < self.min_new_tokens).expand_as(logits)
|
mask = (self.len_output_tokens < self.min_new_tokens).expand_as(logits)
|
||||||
logits[mask] += self.stop_token_penalties[mask]
|
logits[mask] += self.stop_token_penalties[mask]
|
||||||
return logits
|
|
||||||
|
|
||||||
def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
|
def _filter(self, keep_indices: torch.Tensor):
|
||||||
self.min_new_tokens = self.min_new_tokens[indices_tensor_to_keep]
|
self.min_new_tokens = self.min_new_tokens[keep_indices]
|
||||||
self.stop_token_penalties = self.stop_token_penalties[indices_tensor_to_keep]
|
self.stop_token_penalties = self.stop_token_penalties[keep_indices]
|
||||||
self.len_output_tokens = self.len_output_tokens[indices_tensor_to_keep]
|
self.len_output_tokens = self.len_output_tokens[keep_indices]
|
||||||
|
|
||||||
def _merge(self, their: "BatchedMinNewTokensPenalizer"):
|
def _merge(self, their: "BatchedMinNewTokensPenalizer"):
|
||||||
self.min_new_tokens = torch.cat(
|
self.min_new_tokens = torch.cat(
|
||||||
@@ -1,35 +1,25 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import abc
|
import abc
|
||||||
import dataclasses
|
from typing import TYPE_CHECKING, Set, Type
|
||||||
from typing import List, Set, Type, Union
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
@dataclasses.dataclass
|
from sglang.srt.managers.schedule_batch import ScheduleBatch
|
||||||
class _ReqLike:
|
|
||||||
origin_input_ids: List[int]
|
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
|
||||||
class _BatchLike:
|
|
||||||
reqs: List[_ReqLike]
|
|
||||||
|
|
||||||
def batch_size(self):
|
|
||||||
return len(self.reqs)
|
|
||||||
|
|
||||||
|
|
||||||
class BatchedPenalizerOrchestrator:
|
class BatchedPenalizerOrchestrator:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
vocab_size: int,
|
vocab_size: int,
|
||||||
batch: _BatchLike,
|
batch: ScheduleBatch,
|
||||||
device: str,
|
penalizers: Set[Type["_BatchedPenalizer"]],
|
||||||
Penalizers: Set[Type["_BatchedPenalizer"]],
|
|
||||||
):
|
):
|
||||||
self.vocab_size = vocab_size
|
self.vocab_size = vocab_size
|
||||||
self.batch = batch
|
self.batch = batch
|
||||||
self.device = device
|
self.device = batch.device
|
||||||
self.penalizers = {Penalizer: Penalizer(self) for Penalizer in Penalizers}
|
self.penalizers = {Penalizer: Penalizer(self) for Penalizer in penalizers}
|
||||||
|
|
||||||
is_required = False
|
is_required = False
|
||||||
for penalizer in self.penalizers.values():
|
for penalizer in self.penalizers.values():
|
||||||
@@ -37,31 +27,9 @@ class BatchedPenalizerOrchestrator:
|
|||||||
is_required |= pen_is_required
|
is_required |= pen_is_required
|
||||||
self.is_required = is_required
|
self.is_required = is_required
|
||||||
|
|
||||||
input_ids = [
|
|
||||||
torch.tensor(req.origin_input_ids, dtype=torch.int64, device=self.device)
|
|
||||||
for req in self.reqs()
|
|
||||||
]
|
|
||||||
if self.is_required:
|
|
||||||
self.cumulate_input_tokens(input_ids=input_ids)
|
|
||||||
|
|
||||||
def reqs(self):
|
def reqs(self):
|
||||||
return self.batch.reqs
|
return self.batch.reqs
|
||||||
|
|
||||||
def batch_size(self):
|
|
||||||
return self.batch.batch_size()
|
|
||||||
|
|
||||||
def cumulate_input_tokens(self, input_ids: List[torch.Tensor]):
|
|
||||||
"""
|
|
||||||
Feed the input tokens to the penalizers.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
input_ids (List[torch.Tensor]): The input tokens.
|
|
||||||
"""
|
|
||||||
token_ids = _TokenIDs(orchestrator=self, token_ids=input_ids)
|
|
||||||
|
|
||||||
for penalizer in self.penalizers.values():
|
|
||||||
penalizer.cumulate_input_tokens(input_ids=token_ids)
|
|
||||||
|
|
||||||
def cumulate_output_tokens(self, output_ids: torch.Tensor):
|
def cumulate_output_tokens(self, output_ids: torch.Tensor):
|
||||||
"""
|
"""
|
||||||
Feed the output tokens to the penalizers.
|
Feed the output tokens to the penalizers.
|
||||||
@@ -69,13 +37,8 @@ class BatchedPenalizerOrchestrator:
|
|||||||
Args:
|
Args:
|
||||||
output_ids (torch.Tensor): The output tokens.
|
output_ids (torch.Tensor): The output tokens.
|
||||||
"""
|
"""
|
||||||
if not self.is_required:
|
|
||||||
return
|
|
||||||
|
|
||||||
token_ids = _TokenIDs(orchestrator=self, token_ids=output_ids)
|
|
||||||
|
|
||||||
for penalizer in self.penalizers.values():
|
for penalizer in self.penalizers.values():
|
||||||
penalizer.cumulate_output_tokens(output_ids=token_ids)
|
penalizer.cumulate_output_tokens(output_ids=output_ids)
|
||||||
|
|
||||||
def apply(self, logits: torch.Tensor) -> torch.Tensor:
|
def apply(self, logits: torch.Tensor) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
@@ -88,48 +51,33 @@ class BatchedPenalizerOrchestrator:
|
|||||||
Returns:
|
Returns:
|
||||||
torch.Tensor: The logits after applying the penalizers.
|
torch.Tensor: The logits after applying the penalizers.
|
||||||
"""
|
"""
|
||||||
if not self.is_required:
|
|
||||||
return
|
|
||||||
|
|
||||||
for penalizer in self.penalizers.values():
|
for penalizer in self.penalizers.values():
|
||||||
logits = penalizer.apply(logits)
|
penalizer.apply(logits)
|
||||||
|
|
||||||
return logits
|
def filter(self, keep_indices: torch.Tensor):
|
||||||
|
|
||||||
def filter(
|
|
||||||
self,
|
|
||||||
indices_to_keep: List[int],
|
|
||||||
indices_tensor_to_keep: torch.Tensor = None,
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
Filter the penalizers based on the indices to keep in the batch.
|
Filter the penalizers based on the indices to keep in the batch.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
indices_to_keep (List[int]): List of indices to keep in the batch.
|
keep_indices (torch.Tensor): Tensor of indices to keep in the batch.
|
||||||
indices_tensor_to_keep (torch.Tensor = None): Tensor of indices to keep in the batch. If not None, it will be used instead of converting indices_to_keep to a tensor.
|
|
||||||
"""
|
"""
|
||||||
if not self.is_required:
|
if not self.is_required:
|
||||||
return
|
return
|
||||||
|
|
||||||
empty_indices = len(indices_to_keep) == 0
|
if len(keep_indices) == 0:
|
||||||
|
self.is_required = False
|
||||||
|
for penalizer in self.penalizers.values():
|
||||||
|
penalizer.teardown()
|
||||||
|
return
|
||||||
|
|
||||||
is_required = False
|
is_required = False
|
||||||
for penalizer in self.penalizers.values():
|
for penalizer in self.penalizers.values():
|
||||||
tmp_is_required = penalizer.is_required()
|
tmp_is_required = penalizer.is_required()
|
||||||
is_required = is_required or tmp_is_required
|
is_required |= tmp_is_required
|
||||||
if not tmp_is_required or empty_indices:
|
if tmp_is_required:
|
||||||
penalizer.teardown()
|
penalizer.filter(keep_indices=keep_indices)
|
||||||
else:
|
else:
|
||||||
# create tensor index only when it's needed
|
penalizer.teardown()
|
||||||
if indices_tensor_to_keep is None:
|
|
||||||
indices_tensor_to_keep = torch.tensor(
|
|
||||||
indices_to_keep, dtype=torch.int32, device=self.device
|
|
||||||
)
|
|
||||||
|
|
||||||
penalizer.filter(
|
|
||||||
indices_to_keep=indices_to_keep,
|
|
||||||
indices_tensor_to_keep=indices_tensor_to_keep,
|
|
||||||
)
|
|
||||||
self.is_required = is_required
|
self.is_required = is_required
|
||||||
|
|
||||||
def merge(self, their: "BatchedPenalizerOrchestrator"):
|
def merge(self, their: "BatchedPenalizerOrchestrator"):
|
||||||
@@ -146,75 +94,9 @@ class BatchedPenalizerOrchestrator:
|
|||||||
if not self.is_required and not their.is_required:
|
if not self.is_required and not their.is_required:
|
||||||
return
|
return
|
||||||
|
|
||||||
self.is_required |= their.is_required
|
self.is_required = True
|
||||||
for Penalizer, their_penalizer in their.penalizers.items():
|
for penalizer, their_penalizer in their.penalizers.items():
|
||||||
if Penalizer not in self.penalizers:
|
self.penalizers[penalizer].merge(their_penalizer)
|
||||||
raise ValueError(f"Penalizer {Penalizer} not found in self.penalizers")
|
|
||||||
|
|
||||||
self.penalizers[Penalizer].merge(their_penalizer)
|
|
||||||
|
|
||||||
|
|
||||||
class _TokenIDs:
|
|
||||||
"""
|
|
||||||
A class that wraps token IDs to provide additional utility functions to penalizers.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
orchestrator (BatchedPenalizerOrchestrator): The orchestrator that this token IDs belong to.
|
|
||||||
token_ids (Union[torch.Tensor, List[torch.Tensor]]): The token IDs.
|
|
||||||
cached_counts (torch.Tensor): The cached occurrence count tensor.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
orchestrator: BatchedPenalizerOrchestrator,
|
|
||||||
token_ids: Union[torch.Tensor, List[torch.Tensor]],
|
|
||||||
):
|
|
||||||
self.orchestrator = orchestrator
|
|
||||||
self.token_ids = token_ids
|
|
||||||
self.cached_counts = None
|
|
||||||
|
|
||||||
def occurrence_count(self) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
Returns a tensor of shape (batch_size, vocab_size) where each element is the number of times the corresponding token appears in the batch.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
torch.Tensor: The occurrence count tensor.
|
|
||||||
"""
|
|
||||||
if self.cached_counts is not None:
|
|
||||||
return self.cached_counts
|
|
||||||
|
|
||||||
token_ids = self.token_ids
|
|
||||||
|
|
||||||
if isinstance(token_ids, list):
|
|
||||||
# TODO: optimize this part
|
|
||||||
padded_token_ids = torch.nn.utils.rnn.pad_sequence(
|
|
||||||
sequences=token_ids,
|
|
||||||
batch_first=True,
|
|
||||||
padding_value=self.orchestrator.vocab_size,
|
|
||||||
)
|
|
||||||
self.cached_counts = torch.zeros(
|
|
||||||
size=(self.orchestrator.batch_size(), self.orchestrator.vocab_size + 1),
|
|
||||||
dtype=torch.int64,
|
|
||||||
device=self.orchestrator.device,
|
|
||||||
).scatter_add_(
|
|
||||||
dim=1,
|
|
||||||
index=padded_token_ids,
|
|
||||||
src=torch.ones_like(padded_token_ids),
|
|
||||||
)[
|
|
||||||
:, : self.orchestrator.vocab_size
|
|
||||||
]
|
|
||||||
else:
|
|
||||||
# TODO: optimize this part. We do not need to create this big tensor every time.
|
|
||||||
# We can directly apply the results on the logits.
|
|
||||||
self.cached_counts = torch.zeros(
|
|
||||||
size=(self.orchestrator.batch_size(), self.orchestrator.vocab_size),
|
|
||||||
device=self.orchestrator.device,
|
|
||||||
)
|
|
||||||
self.cached_counts[
|
|
||||||
torch.arange(len(token_ids), device=self.orchestrator.device), token_ids
|
|
||||||
] = 1
|
|
||||||
|
|
||||||
return self.cached_counts
|
|
||||||
|
|
||||||
|
|
||||||
class _BatchedPenalizer(abc.ABC):
|
class _BatchedPenalizer(abc.ABC):
|
||||||
@@ -222,10 +104,6 @@ class _BatchedPenalizer(abc.ABC):
|
|||||||
An abstract class for a batched penalizer.
|
An abstract class for a batched penalizer.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, orchestrator: BatchedPenalizerOrchestrator):
|
|
||||||
self.orchestrator = orchestrator
|
|
||||||
self._is_prepared = False
|
|
||||||
|
|
||||||
def is_prepared(self) -> bool:
|
def is_prepared(self) -> bool:
|
||||||
return self._is_prepared
|
return self._is_prepared
|
||||||
|
|
||||||
@@ -233,51 +111,40 @@ class _BatchedPenalizer(abc.ABC):
|
|||||||
return self._is_required()
|
return self._is_required()
|
||||||
|
|
||||||
def prepare(self):
|
def prepare(self):
|
||||||
if not self.is_prepared():
|
if not self._is_prepared:
|
||||||
self._prepare()
|
self._prepare()
|
||||||
self._is_prepared = True
|
self._is_prepared = True
|
||||||
|
|
||||||
def prepare_if_required(self):
|
def prepare_if_required(self):
|
||||||
if self.is_required():
|
if self._is_required():
|
||||||
self.prepare()
|
self.prepare()
|
||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def teardown(self):
|
def teardown(self):
|
||||||
if self.is_prepared():
|
|
||||||
self._teardown()
|
|
||||||
self._is_prepared = False
|
self._is_prepared = False
|
||||||
|
|
||||||
def cumulate_input_tokens(self, input_ids: _TokenIDs):
|
def cumulate_output_tokens(self, output_ids: torch.Tensor):
|
||||||
if not self.is_prepared():
|
if not self._is_prepared:
|
||||||
return
|
|
||||||
|
|
||||||
self._cumulate_input_tokens(input_ids=input_ids)
|
|
||||||
|
|
||||||
def cumulate_output_tokens(self, output_ids: _TokenIDs):
|
|
||||||
if not self.is_prepared():
|
|
||||||
return
|
return
|
||||||
|
|
||||||
self._cumulate_output_tokens(output_ids=output_ids)
|
self._cumulate_output_tokens(output_ids=output_ids)
|
||||||
|
|
||||||
def apply(self, logits: torch.Tensor) -> torch.Tensor:
|
def apply(self, logits: torch.Tensor) -> torch.Tensor:
|
||||||
if not self.is_prepared():
|
if not self._is_prepared:
|
||||||
return logits
|
|
||||||
|
|
||||||
return self._apply(logits=logits)
|
|
||||||
|
|
||||||
def filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
|
|
||||||
if not self.is_prepared():
|
|
||||||
return
|
return
|
||||||
|
|
||||||
self._filter(
|
self._apply(logits=logits)
|
||||||
indices_to_keep=indices_to_keep,
|
|
||||||
indices_tensor_to_keep=indices_tensor_to_keep,
|
def filter(self, keep_indices: torch.Tensor):
|
||||||
)
|
if not self._is_prepared:
|
||||||
|
return
|
||||||
|
|
||||||
|
self._filter(keep_indices=keep_indices)
|
||||||
|
|
||||||
def merge(self, their: "_BatchedPenalizer"):
|
def merge(self, their: "_BatchedPenalizer"):
|
||||||
if not self.is_prepared() and not their.is_prepared():
|
if not self._is_prepared and not their._is_prepared:
|
||||||
return
|
return
|
||||||
|
|
||||||
self.prepare()
|
self.prepare()
|
||||||
@@ -300,23 +167,7 @@ class _BatchedPenalizer(abc.ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def _teardown(self):
|
def _cumulate_output_tokens(self, output_ids: torch.Tensor):
|
||||||
"""
|
|
||||||
Tear down the penalizer.
|
|
||||||
Usually, this is where the penalizer frees its tensors.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def _cumulate_input_tokens(self, input_ids: _TokenIDs):
|
|
||||||
"""
|
|
||||||
Cumulate the input tokens.
|
|
||||||
Orchestrator will call this function to feed the input tokens to the penalizer.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def _cumulate_output_tokens(self, output_ids: _TokenIDs):
|
|
||||||
"""
|
"""
|
||||||
Cumulate the output tokens.
|
Cumulate the output tokens.
|
||||||
Orchestrator will call this function to feed the output tokens to the penalizer.
|
Orchestrator will call this function to feed the output tokens to the penalizer.
|
||||||
@@ -332,7 +183,7 @@ class _BatchedPenalizer(abc.ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
|
def _filter(self, keep_indices: torch.Tensor):
|
||||||
"""
|
"""
|
||||||
Filter the penalizer (tensors or underlying data) based on the indices to keep in the batch.
|
Filter the penalizer (tensors or underlying data) based on the indices to keep in the batch.
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -1,75 +0,0 @@
|
|||||||
from typing import List
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs
|
|
||||||
|
|
||||||
|
|
||||||
class BatchedFrequencyPenalizer(_BatchedPenalizer):
|
|
||||||
"""
|
|
||||||
Frequency penalizer penalizes tokens based on their frequency in the output.
|
|
||||||
"""
|
|
||||||
|
|
||||||
frequency_penalties: torch.Tensor = None
|
|
||||||
cumulated_frequency_penalties: torch.Tensor = None
|
|
||||||
|
|
||||||
def _is_required(self) -> bool:
|
|
||||||
return any(
|
|
||||||
req.sampling_params.frequency_penalty != 0.0
|
|
||||||
for req in self.orchestrator.reqs()
|
|
||||||
)
|
|
||||||
|
|
||||||
def _prepare(self):
|
|
||||||
self.cumulated_frequency_penalties = (
|
|
||||||
torch.tensor(
|
|
||||||
data=[0.0 for _ in self.orchestrator.reqs()],
|
|
||||||
dtype=torch.float32,
|
|
||||||
device=self.orchestrator.device,
|
|
||||||
)
|
|
||||||
.unsqueeze_(1)
|
|
||||||
.repeat(1, self.orchestrator.vocab_size)
|
|
||||||
)
|
|
||||||
|
|
||||||
self.frequency_penalties = (
|
|
||||||
torch.tensor(
|
|
||||||
data=[
|
|
||||||
req.sampling_params.frequency_penalty
|
|
||||||
for req in self.orchestrator.reqs()
|
|
||||||
],
|
|
||||||
dtype=torch.float32,
|
|
||||||
device=self.orchestrator.device,
|
|
||||||
)
|
|
||||||
.unsqueeze_(1)
|
|
||||||
.expand_as(self.cumulated_frequency_penalties)
|
|
||||||
)
|
|
||||||
|
|
||||||
def _teardown(self):
|
|
||||||
self.frequency_penalties = None
|
|
||||||
self.cumulated_frequency_penalties = None
|
|
||||||
|
|
||||||
def _cumulate_input_tokens(self, input_ids: _TokenIDs):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def _cumulate_output_tokens(self, output_ids: _TokenIDs):
|
|
||||||
self.cumulated_frequency_penalties += (
|
|
||||||
self.frequency_penalties * output_ids.occurrence_count()
|
|
||||||
)
|
|
||||||
|
|
||||||
def _apply(self, logits: torch.Tensor) -> torch.Tensor:
|
|
||||||
logits -= self.cumulated_frequency_penalties
|
|
||||||
return logits
|
|
||||||
|
|
||||||
def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
|
|
||||||
self.frequency_penalties = self.frequency_penalties[indices_tensor_to_keep]
|
|
||||||
self.cumulated_frequency_penalties = self.cumulated_frequency_penalties[
|
|
||||||
indices_tensor_to_keep
|
|
||||||
]
|
|
||||||
|
|
||||||
def _merge(self, their: "BatchedFrequencyPenalizer"):
|
|
||||||
self.frequency_penalties = torch.cat(
|
|
||||||
[self.frequency_penalties, their.frequency_penalties], dim=0
|
|
||||||
)
|
|
||||||
self.cumulated_frequency_penalties = torch.cat(
|
|
||||||
[self.cumulated_frequency_penalties, their.cumulated_frequency_penalties],
|
|
||||||
dim=0,
|
|
||||||
)
|
|
||||||
@@ -1,74 +0,0 @@
|
|||||||
from typing import List
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs
|
|
||||||
|
|
||||||
|
|
||||||
class BatchedPresencePenalizer(_BatchedPenalizer):
|
|
||||||
"""
|
|
||||||
Presence penalizer penalizes tokens based on their presence in the output.
|
|
||||||
"""
|
|
||||||
|
|
||||||
presence_penalties: torch.Tensor = None
|
|
||||||
cumulated_presence_penalties: torch.Tensor = None
|
|
||||||
|
|
||||||
def _is_required(self) -> bool:
|
|
||||||
return any(
|
|
||||||
req.sampling_params.presence_penalty != 0.0
|
|
||||||
for req in self.orchestrator.reqs()
|
|
||||||
)
|
|
||||||
|
|
||||||
def _prepare(self):
|
|
||||||
self.cumulated_presence_penalties = (
|
|
||||||
torch.tensor(
|
|
||||||
data=[0.0 for _ in self.orchestrator.reqs()],
|
|
||||||
dtype=torch.float32,
|
|
||||||
device=self.orchestrator.device,
|
|
||||||
)
|
|
||||||
.unsqueeze_(1)
|
|
||||||
.repeat(1, self.orchestrator.vocab_size)
|
|
||||||
)
|
|
||||||
|
|
||||||
self.presence_penalties = (
|
|
||||||
torch.tensor(
|
|
||||||
data=[
|
|
||||||
req.sampling_params.presence_penalty
|
|
||||||
for req in self.orchestrator.reqs()
|
|
||||||
],
|
|
||||||
dtype=torch.float32,
|
|
||||||
device=self.orchestrator.device,
|
|
||||||
)
|
|
||||||
.unsqueeze_(1)
|
|
||||||
.expand_as(self.cumulated_presence_penalties)
|
|
||||||
)
|
|
||||||
|
|
||||||
def _teardown(self):
|
|
||||||
self.presence_penalties = None
|
|
||||||
self.cumulated_presence_penalties = None
|
|
||||||
|
|
||||||
def _cumulate_input_tokens(self, input_ids: _TokenIDs):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def _cumulate_output_tokens(self, output_ids: _TokenIDs):
|
|
||||||
mask = output_ids.occurrence_count() > 0
|
|
||||||
self.cumulated_presence_penalties[mask] = self.presence_penalties[mask]
|
|
||||||
|
|
||||||
def _apply(self, logits: torch.Tensor) -> torch.Tensor:
|
|
||||||
logits -= self.cumulated_presence_penalties
|
|
||||||
return logits
|
|
||||||
|
|
||||||
def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
|
|
||||||
self.presence_penalties = self.presence_penalties[indices_tensor_to_keep]
|
|
||||||
self.cumulated_presence_penalties = self.cumulated_presence_penalties[
|
|
||||||
indices_tensor_to_keep
|
|
||||||
]
|
|
||||||
|
|
||||||
def _merge(self, their: "BatchedPresencePenalizer"):
|
|
||||||
self.presence_penalties = torch.cat(
|
|
||||||
[self.presence_penalties, their.presence_penalties], dim=0
|
|
||||||
)
|
|
||||||
self.cumulated_presence_penalties = torch.cat(
|
|
||||||
[self.cumulated_presence_penalties, their.cumulated_presence_penalties],
|
|
||||||
dim=0,
|
|
||||||
)
|
|
||||||
@@ -1,85 +0,0 @@
|
|||||||
from typing import List
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs
|
|
||||||
from sglang.srt.utils import get_compiler_backend
|
|
||||||
|
|
||||||
|
|
||||||
@torch.compile(dynamic=True, backend=get_compiler_backend())
|
|
||||||
def apply_scaling_penalties(logits, scaling_penalties):
|
|
||||||
logits[:] = torch.where(
|
|
||||||
logits > 0,
|
|
||||||
logits / scaling_penalties,
|
|
||||||
logits * scaling_penalties,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class BatchedRepetitionPenalizer(_BatchedPenalizer):
|
|
||||||
"""
|
|
||||||
Repetition penalizer penalizes tokens based on their repetition in the input and output.
|
|
||||||
"""
|
|
||||||
|
|
||||||
repetition_penalties: torch.Tensor = None
|
|
||||||
cumulated_repetition_penalties: torch.Tensor = None
|
|
||||||
|
|
||||||
def _is_required(self) -> bool:
|
|
||||||
return any(
|
|
||||||
req.sampling_params.repetition_penalty != 1.0
|
|
||||||
for req in self.orchestrator.reqs()
|
|
||||||
)
|
|
||||||
|
|
||||||
def _prepare(self):
|
|
||||||
self.cumulated_repetition_penalties = (
|
|
||||||
torch.tensor(
|
|
||||||
data=[1.0 for _ in self.orchestrator.reqs()],
|
|
||||||
dtype=torch.float32,
|
|
||||||
device=self.orchestrator.device,
|
|
||||||
)
|
|
||||||
.unsqueeze_(1)
|
|
||||||
.repeat(1, self.orchestrator.vocab_size)
|
|
||||||
)
|
|
||||||
|
|
||||||
self.repetition_penalties = (
|
|
||||||
torch.tensor(
|
|
||||||
data=[
|
|
||||||
req.sampling_params.repetition_penalty
|
|
||||||
for req in self.orchestrator.reqs()
|
|
||||||
],
|
|
||||||
dtype=torch.float32,
|
|
||||||
device=self.orchestrator.device,
|
|
||||||
)
|
|
||||||
.unsqueeze_(1)
|
|
||||||
.expand_as(self.cumulated_repetition_penalties)
|
|
||||||
)
|
|
||||||
|
|
||||||
def _teardown(self):
|
|
||||||
self.repetition_penalties = None
|
|
||||||
self.cumulated_repetition_penalties = None
|
|
||||||
|
|
||||||
def _cumulate_input_tokens(self, input_ids: _TokenIDs):
|
|
||||||
mask = input_ids.occurrence_count() > 0
|
|
||||||
self.cumulated_repetition_penalties[mask] = self.repetition_penalties[mask]
|
|
||||||
|
|
||||||
def _cumulate_output_tokens(self, output_ids: _TokenIDs):
|
|
||||||
mask = output_ids.occurrence_count() > 0
|
|
||||||
self.cumulated_repetition_penalties[mask] = self.repetition_penalties[mask]
|
|
||||||
|
|
||||||
def _apply(self, logits: torch.Tensor) -> torch.Tensor:
|
|
||||||
apply_scaling_penalties(logits, self.cumulated_repetition_penalties)
|
|
||||||
return logits
|
|
||||||
|
|
||||||
def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
|
|
||||||
self.repetition_penalties = self.repetition_penalties[indices_tensor_to_keep]
|
|
||||||
self.cumulated_repetition_penalties = self.cumulated_repetition_penalties[
|
|
||||||
indices_tensor_to_keep
|
|
||||||
]
|
|
||||||
|
|
||||||
def _merge(self, their: "BatchedRepetitionPenalizer"):
|
|
||||||
self.repetition_penalties = torch.cat(
|
|
||||||
[self.repetition_penalties, their.repetition_penalties], dim=0
|
|
||||||
)
|
|
||||||
self.cumulated_repetition_penalties = torch.cat(
|
|
||||||
[self.cumulated_repetition_penalties, their.cumulated_repetition_penalties],
|
|
||||||
dim=0,
|
|
||||||
)
|
|
||||||
66
python/sglang/srt/sampling/penaltylib/presence_penalty.py
Normal file
66
python/sglang/srt/sampling/penaltylib/presence_penalty.py
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
from sglang.srt.sampling.penaltylib.orchestrator import (
|
||||||
|
BatchedPenalizerOrchestrator,
|
||||||
|
_BatchedPenalizer,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BatchedPresencePenalizer(_BatchedPenalizer):
|
||||||
|
"""
|
||||||
|
Presence penalizer penalizes tokens based on their presence in the output.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, orchestrator: BatchedPenalizerOrchestrator):
|
||||||
|
self.orchestrator = orchestrator
|
||||||
|
self._is_prepared = False
|
||||||
|
|
||||||
|
def _is_required(self) -> bool:
|
||||||
|
return any(
|
||||||
|
req.sampling_params.presence_penalty != 0.0
|
||||||
|
for req in self.orchestrator.reqs()
|
||||||
|
)
|
||||||
|
|
||||||
|
def _prepare(self):
|
||||||
|
self.cumulated_presence_penalties = torch.zeros(
|
||||||
|
(len(self.orchestrator.reqs()), self.orchestrator.vocab_size),
|
||||||
|
dtype=torch.float32,
|
||||||
|
device=self.orchestrator.device,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.presence_penalties = (
|
||||||
|
torch.tensor(
|
||||||
|
data=[
|
||||||
|
req.sampling_params.presence_penalty
|
||||||
|
for req in self.orchestrator.reqs()
|
||||||
|
],
|
||||||
|
dtype=torch.float32,
|
||||||
|
device=self.orchestrator.device,
|
||||||
|
)
|
||||||
|
).unsqueeze_(1)
|
||||||
|
|
||||||
|
def _cumulate_output_tokens(self, output_ids: torch.Tensor):
|
||||||
|
self.cumulated_presence_penalties.scatter_(
|
||||||
|
dim=1,
|
||||||
|
index=output_ids.unsqueeze(1),
|
||||||
|
src=self.presence_penalties,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _apply(self, logits: torch.Tensor) -> torch.Tensor:
|
||||||
|
logits.sub_(self.cumulated_presence_penalties)
|
||||||
|
|
||||||
|
def _filter(self, keep_indices: torch.Tensor):
|
||||||
|
self.presence_penalties = self.presence_penalties[keep_indices]
|
||||||
|
self.cumulated_presence_penalties = self.cumulated_presence_penalties[
|
||||||
|
keep_indices
|
||||||
|
]
|
||||||
|
|
||||||
|
def _merge(self, their: "BatchedPresencePenalizer"):
|
||||||
|
print(f"{self.presence_penalties.shape=}, {their.presence_penalties.shape=}")
|
||||||
|
self.presence_penalties = torch.cat(
|
||||||
|
[self.presence_penalties, their.presence_penalties], dim=0
|
||||||
|
)
|
||||||
|
self.cumulated_presence_penalties = torch.cat(
|
||||||
|
[self.cumulated_presence_penalties, their.cumulated_presence_penalties],
|
||||||
|
dim=0,
|
||||||
|
)
|
||||||
@@ -9,9 +9,6 @@ import torch
|
|||||||
|
|
||||||
import sglang.srt.sampling.penaltylib as penaltylib
|
import sglang.srt.sampling.penaltylib as penaltylib
|
||||||
from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
|
from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
|
||||||
from sglang.srt.sampling.penaltylib.penalizers.repetition_penalty import (
|
|
||||||
apply_scaling_penalties,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -22,49 +19,45 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class SamplingBatchInfo:
|
class SamplingBatchInfo:
|
||||||
# Batched sampling params
|
# Basic batched sampling params
|
||||||
temperatures: torch.Tensor
|
temperatures: torch.Tensor
|
||||||
top_ps: torch.Tensor
|
top_ps: torch.Tensor
|
||||||
top_ks: torch.Tensor
|
top_ks: torch.Tensor
|
||||||
min_ps: torch.Tensor
|
min_ps: torch.Tensor
|
||||||
|
|
||||||
# All requests use greedy sampling
|
# Whether all requests use greedy sampling
|
||||||
is_all_greedy: bool
|
is_all_greedy: bool
|
||||||
|
|
||||||
# Dispatch in CUDA graph
|
# Whether any request needs min_p sampling
|
||||||
need_min_p_sampling: bool
|
need_min_p_sampling: bool
|
||||||
|
|
||||||
# Whether any request has custom logit processor
|
# Masking tensors for grammar-guided structured outputs
|
||||||
has_custom_logit_processor: bool
|
|
||||||
|
|
||||||
# Bias Tensors
|
|
||||||
vocab_size: int
|
vocab_size: int
|
||||||
grammars: Optional[List] = None
|
grammars: Optional[List] = None
|
||||||
sampling_info_done: Optional[threading.Event] = None
|
|
||||||
logit_bias: torch.Tensor = None
|
|
||||||
vocab_mask: Optional[torch.Tensor] = None
|
vocab_mask: Optional[torch.Tensor] = None
|
||||||
apply_mask: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None
|
apply_mask_func: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None
|
||||||
|
|
||||||
|
# An event used for overlap schedule
|
||||||
|
sampling_info_done: Optional[threading.Event] = None
|
||||||
|
|
||||||
# Penalizer
|
# Penalizer
|
||||||
penalizer_orchestrator: Optional[penaltylib.BatchedPenalizerOrchestrator] = None
|
penalizer_orchestrator: Optional[penaltylib.BatchedPenalizerOrchestrator] = None
|
||||||
linear_penalties: Optional[torch.Tensor] = None
|
linear_penalty: torch.Tensor = None
|
||||||
scaling_penalties: Optional[torch.Tensor] = None
|
|
||||||
|
|
||||||
# Device
|
# Whether any request has custom logit processor
|
||||||
device: str = "cuda"
|
has_custom_logit_processor: bool = False
|
||||||
|
# Custom parameters
|
||||||
# Custom Parameters
|
|
||||||
custom_params: Optional[List[Optional[Dict[str, Any]]]] = None
|
custom_params: Optional[List[Optional[Dict[str, Any]]]] = None
|
||||||
|
# Custom logit processor
|
||||||
# Custom Logit Processor
|
|
||||||
custom_logit_processor: Optional[
|
custom_logit_processor: Optional[
|
||||||
Dict[int, Tuple[CustomLogitProcessor, torch.Tensor]]
|
Dict[int, Tuple[CustomLogitProcessor, torch.Tensor]]
|
||||||
] = None
|
] = None
|
||||||
|
|
||||||
|
# Device
|
||||||
|
device: str = "cuda"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_schedule_batch(
|
def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
|
||||||
cls, batch: ScheduleBatch, vocab_size: int, enable_overlap_schedule: bool
|
|
||||||
):
|
|
||||||
reqs = batch.reqs
|
reqs = batch.reqs
|
||||||
device = batch.device
|
device = batch.device
|
||||||
temperatures = (
|
temperatures = (
|
||||||
@@ -118,106 +111,60 @@ class SamplingBatchInfo:
|
|||||||
merged_custom_logit_processor = None
|
merged_custom_logit_processor = None
|
||||||
custom_params = None
|
custom_params = None
|
||||||
|
|
||||||
|
# Each penalizers will do nothing if they evaluate themselves as not required by looking at
|
||||||
|
# the sampling_params of the requests (See {_is_required()} of each penalizers). So this
|
||||||
|
# should not add hefty computation overhead other than simple checks.
|
||||||
|
#
|
||||||
|
# While we can choose not to even create the class instances if they are not required, this
|
||||||
|
# could add additional complexity to the {ScheduleBatch} class, especially we need to
|
||||||
|
# handle {filter_batch()} and {merge_batch()} cases as well.
|
||||||
|
penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
|
||||||
|
vocab_size=vocab_size,
|
||||||
|
batch=batch,
|
||||||
|
penalizers={
|
||||||
|
penaltylib.BatchedFrequencyPenalizer,
|
||||||
|
penaltylib.BatchedMinNewTokensPenalizer,
|
||||||
|
penaltylib.BatchedPresencePenalizer,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
ret = cls(
|
ret = cls(
|
||||||
temperatures=temperatures,
|
temperatures=temperatures,
|
||||||
top_ps=top_ps,
|
top_ps=top_ps,
|
||||||
top_ks=top_ks,
|
top_ks=top_ks,
|
||||||
min_ps=min_ps,
|
min_ps=min_ps,
|
||||||
need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
|
|
||||||
is_all_greedy=all(r.sampling_params.top_k <= 1 for r in reqs),
|
is_all_greedy=all(r.sampling_params.top_k <= 1 for r in reqs),
|
||||||
has_custom_logit_processor=has_custom_logit_processor,
|
need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
|
||||||
vocab_size=vocab_size,
|
vocab_size=vocab_size,
|
||||||
device=device,
|
penalizer_orchestrator=penalizer_orchestrator,
|
||||||
|
has_custom_logit_processor=has_custom_logit_processor,
|
||||||
custom_params=custom_params,
|
custom_params=custom_params,
|
||||||
custom_logit_processor=merged_custom_logit_processor,
|
custom_logit_processor=merged_custom_logit_processor,
|
||||||
|
device=device,
|
||||||
)
|
)
|
||||||
# TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge.
|
|
||||||
|
|
||||||
if enable_overlap_schedule:
|
|
||||||
# TODO (lianmin): Some penalizers such as frequency and presence depend on model outputs,
|
|
||||||
# so it is kind of tricky to make it work with overlap scheduler.
|
|
||||||
# It requires correcly updating the penalty logits before the sampling and syncing the events.
|
|
||||||
# We will support them later.
|
|
||||||
penalizers = {
|
|
||||||
penaltylib.BatchedMinNewTokensPenalizer,
|
|
||||||
}
|
|
||||||
if (
|
|
||||||
any(req.sampling_params.frequency_penalty != 0.0 for req in reqs)
|
|
||||||
or any(req.sampling_params.presence_penalty != 0.0 for req in reqs)
|
|
||||||
or any(req.sampling_params.repetition_penalty != 1.0 for req in reqs)
|
|
||||||
):
|
|
||||||
logger.warning(
|
|
||||||
"frequency_penalty, presence_penalty, and repetition_penalty are not supported "
|
|
||||||
"when using the default overlap scheduler. They will be ignored. "
|
|
||||||
"Please add `--disable-overlap` when launching the server if you need these features. "
|
|
||||||
"The speed will be slower in that case."
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
penalizers = {
|
|
||||||
penaltylib.BatchedFrequencyPenalizer,
|
|
||||||
penaltylib.BatchedMinNewTokensPenalizer,
|
|
||||||
penaltylib.BatchedPresencePenalizer,
|
|
||||||
penaltylib.BatchedRepetitionPenalizer,
|
|
||||||
}
|
|
||||||
|
|
||||||
# Each penalizers will do nothing if they evaluate themselves as not required by looking at
|
|
||||||
# the sampling_params of the requests (See {_is_required()} of each penalizers). So this
|
|
||||||
# should not add hefty computation overhead other than simple checks.
|
|
||||||
#
|
|
||||||
# While we choose not to even create the class instances if they are not required, this
|
|
||||||
# could add additional complexity to the {ScheduleBatch} class, especially we need to
|
|
||||||
# handle {filter_batch()} and {merge_batch()} cases as well.
|
|
||||||
ret.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
|
|
||||||
vocab_size=vocab_size,
|
|
||||||
batch=batch,
|
|
||||||
device=batch.device,
|
|
||||||
Penalizers=penalizers,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Handle logit bias but only allocate when needed
|
|
||||||
ret.logit_bias = None
|
|
||||||
|
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.temperatures)
|
return len(self.temperatures)
|
||||||
|
|
||||||
def update_penalties(self):
|
|
||||||
self.scaling_penalties = None
|
|
||||||
self.linear_penalties = None
|
|
||||||
|
|
||||||
for penalizer in self.penalizer_orchestrator.penalizers.values():
|
|
||||||
if not penalizer.is_prepared():
|
|
||||||
continue
|
|
||||||
|
|
||||||
if isinstance(penalizer, penaltylib.BatchedRepetitionPenalizer):
|
|
||||||
self.scaling_penalties = penalizer.cumulated_repetition_penalties
|
|
||||||
else:
|
|
||||||
if self.linear_penalties is None:
|
|
||||||
bs = self.penalizer_orchestrator.batch.batch_size()
|
|
||||||
self.linear_penalties = torch.zeros(
|
|
||||||
(bs, self.vocab_size),
|
|
||||||
dtype=torch.float32,
|
|
||||||
device=self.device,
|
|
||||||
)
|
|
||||||
self.linear_penalties = penalizer.apply(self.linear_penalties)
|
|
||||||
|
|
||||||
def update_regex_vocab_mask(self):
|
def update_regex_vocab_mask(self):
|
||||||
if not self.grammars:
|
if not self.grammars:
|
||||||
self.vocab_mask = None
|
self.vocab_mask = None
|
||||||
self.apply_mask = None
|
self.apply_mask_func = None
|
||||||
return
|
return
|
||||||
|
|
||||||
# find a grammar from the list
|
# Find a grammar from the list
|
||||||
first_grammar = next(grammar for grammar in self.grammars if grammar)
|
first_grammar = next(grammar for grammar in self.grammars if grammar)
|
||||||
|
|
||||||
# maybe we can reuse the existing mask?
|
# TODO(lianmin): Maybe we can reuse the existing mask?
|
||||||
self.vocab_mask = first_grammar.allocate_vocab_mask(
|
self.vocab_mask = first_grammar.allocate_vocab_mask(
|
||||||
vocab_size=self.vocab_size,
|
vocab_size=self.vocab_size,
|
||||||
batch_size=len(self.temperatures),
|
batch_size=len(self.temperatures),
|
||||||
device=self.device,
|
device=self.device,
|
||||||
)
|
)
|
||||||
self.apply_mask = first_grammar.apply_vocab_mask # force to use static method
|
self.apply_mask_func = (
|
||||||
|
first_grammar.apply_vocab_mask
|
||||||
|
) # force to use static method
|
||||||
|
|
||||||
# Apply the mask
|
# Apply the mask
|
||||||
for i, grammar in enumerate(self.grammars):
|
for i, grammar in enumerate(self.grammars):
|
||||||
@@ -227,35 +174,56 @@ class SamplingBatchInfo:
|
|||||||
# Move the mask to the device if needed
|
# Move the mask to the device if needed
|
||||||
self.vocab_mask = first_grammar.move_vocab_mask(self.vocab_mask, self.device)
|
self.vocab_mask = first_grammar.move_vocab_mask(self.vocab_mask, self.device)
|
||||||
|
|
||||||
def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor):
|
def update_penalties(self):
|
||||||
self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
|
if self.penalizer_orchestrator.is_required:
|
||||||
|
self.linear_penalty = torch.zeros(
|
||||||
|
(len(self.temperatures), self.vocab_size),
|
||||||
|
dtype=torch.float32,
|
||||||
|
device=self.temperatures.device,
|
||||||
|
)
|
||||||
|
self.penalizer_orchestrator.apply(self.linear_penalty)
|
||||||
|
else:
|
||||||
|
self.linear_penalty = None
|
||||||
|
|
||||||
|
def apply_logits_bias(self, logits: torch.Tensor):
|
||||||
|
if self.linear_penalty is not None:
|
||||||
|
# Used in the overlap mode
|
||||||
|
logits.add_(self.linear_penalty)
|
||||||
|
|
||||||
|
if self.penalizer_orchestrator and self.penalizer_orchestrator.is_required:
|
||||||
|
# Used in the non-overlap mode
|
||||||
|
self.penalizer_orchestrator.apply(logits)
|
||||||
|
|
||||||
|
if self.vocab_mask is not None:
|
||||||
|
self.apply_mask_func(logits=logits, vocab_mask=self.vocab_mask)
|
||||||
|
|
||||||
|
def filter_batch(self, keep_indices: List[int], keep_indices_device: torch.Tensor):
|
||||||
|
self.penalizer_orchestrator.filter(keep_indices_device)
|
||||||
|
|
||||||
if self.has_custom_logit_processor:
|
if self.has_custom_logit_processor:
|
||||||
self._filter_batch_custom_logit_processor(unfinished_indices, new_indices)
|
self._filter_batch_custom_logit_processor(keep_indices, keep_indices_device)
|
||||||
|
|
||||||
for item in [
|
for item in [
|
||||||
"temperatures",
|
"temperatures",
|
||||||
"top_ps",
|
"top_ps",
|
||||||
"top_ks",
|
"top_ks",
|
||||||
"min_ps",
|
"min_ps",
|
||||||
"logit_bias",
|
|
||||||
]:
|
]:
|
||||||
value = getattr(self, item, None)
|
value = getattr(self, item, None)
|
||||||
if value is not None: # logit_bias can be None
|
setattr(self, item, value[keep_indices_device])
|
||||||
setattr(self, item, value[new_indices])
|
|
||||||
|
|
||||||
def _filter_batch_custom_logit_processor(
|
def _filter_batch_custom_logit_processor(
|
||||||
self, unfinished_indices: List[int], new_indices: torch.Tensor
|
self, keep_indices: List[int], keep_indices_device: torch.Tensor
|
||||||
):
|
):
|
||||||
"""Filter the custom logit processor and custom params"""
|
"""Filter the custom logit processor and custom params"""
|
||||||
|
|
||||||
self.custom_logit_processor = {
|
self.custom_logit_processor = {
|
||||||
k: (p, mask[new_indices])
|
k: (p, mask[keep_indices_device])
|
||||||
for k, (p, mask) in self.custom_logit_processor.items()
|
for k, (p, mask) in self.custom_logit_processor.items()
|
||||||
if any(
|
if torch.any(
|
||||||
mask[new_indices]
|
mask[keep_indices_device]
|
||||||
) # ignore the custom logit processor whose mask is all False
|
) # ignore the custom logit processor whose mask is all False
|
||||||
}
|
}
|
||||||
self.custom_params = [self.custom_params[i] for i in unfinished_indices]
|
self.custom_params = [self.custom_params[i] for i in keep_indices]
|
||||||
|
|
||||||
# If the custom logit processor is an empty dict, set the flag to False,
|
# If the custom logit processor is an empty dict, set the flag to False,
|
||||||
# and set the custom logit processor and custom params to None.
|
# and set the custom logit processor and custom params to None.
|
||||||
@@ -264,31 +232,6 @@ class SamplingBatchInfo:
|
|||||||
self.custom_params = None
|
self.custom_params = None
|
||||||
self.has_custom_logit_processor = False
|
self.has_custom_logit_processor = False
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def merge_bias_tensor(
|
|
||||||
lhs: torch.Tensor,
|
|
||||||
rhs: torch.Tensor,
|
|
||||||
bs1: int,
|
|
||||||
bs2: int,
|
|
||||||
device: str,
|
|
||||||
default: int = 0,
|
|
||||||
):
|
|
||||||
# bias tensor can be None
|
|
||||||
if lhs is not None or rhs is not None:
|
|
||||||
shape, dtype = None, None
|
|
||||||
if lhs is not None:
|
|
||||||
shape, dtype = lhs.shape[1:], lhs.dtype
|
|
||||||
else:
|
|
||||||
shape, dtype = rhs.shape[1:], rhs.dtype
|
|
||||||
with torch.dtype(dtype):
|
|
||||||
if lhs is None:
|
|
||||||
lhs = torch.empty((bs1, *shape), device=device).fill_(default)
|
|
||||||
if rhs is None:
|
|
||||||
rhs = torch.empty((bs2, *shape), device=device).fill_(default)
|
|
||||||
return torch.cat([lhs, rhs])
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def merge_custom_logit_processor(
|
def merge_custom_logit_processor(
|
||||||
lhs: Optional[Dict[int, Tuple[CustomLogitProcessor, torch.Tensor]]],
|
lhs: Optional[Dict[int, Tuple[CustomLogitProcessor, torch.Tensor]]],
|
||||||
@@ -332,11 +275,6 @@ class SamplingBatchInfo:
|
|||||||
def merge_batch(self, other: "SamplingBatchInfo"):
|
def merge_batch(self, other: "SamplingBatchInfo"):
|
||||||
self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
|
self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
|
||||||
|
|
||||||
# Merge the logit bias tensor
|
|
||||||
self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
|
|
||||||
self.logit_bias, other.logit_bias, len(self), len(other), self.device
|
|
||||||
)
|
|
||||||
|
|
||||||
# Merge the custom logit processors and custom params lists
|
# Merge the custom logit processors and custom params lists
|
||||||
if self.has_custom_logit_processor or other.has_custom_logit_processor:
|
if self.has_custom_logit_processor or other.has_custom_logit_processor:
|
||||||
# Merge the custom logit processors
|
# Merge the custom logit processors
|
||||||
@@ -370,22 +308,5 @@ class SamplingBatchInfo:
|
|||||||
other_val = getattr(other, item, None)
|
other_val = getattr(other, item, None)
|
||||||
setattr(self, item, torch.concat([self_val, other_val]))
|
setattr(self, item, torch.concat([self_val, other_val]))
|
||||||
|
|
||||||
self.is_all_greedy = self.is_all_greedy and other.is_all_greedy
|
self.is_all_greedy |= other.is_all_greedy
|
||||||
self.need_min_p_sampling = self.need_min_p_sampling or other.need_min_p_sampling
|
self.need_min_p_sampling |= other.need_min_p_sampling
|
||||||
|
|
||||||
def apply_logits_bias(self, logits: torch.Tensor):
|
|
||||||
# Apply logit_bias
|
|
||||||
if self.logit_bias is not None:
|
|
||||||
logits.add_(self.logit_bias)
|
|
||||||
|
|
||||||
# min-token, presence, frequency
|
|
||||||
if self.linear_penalties is not None:
|
|
||||||
logits.add_(self.linear_penalties)
|
|
||||||
|
|
||||||
# repetition
|
|
||||||
if self.scaling_penalties is not None:
|
|
||||||
apply_scaling_penalties(logits, self.scaling_penalties)
|
|
||||||
|
|
||||||
# Apply regex vocab_mask
|
|
||||||
if self.vocab_mask is not None:
|
|
||||||
self.apply_mask(logits=logits, vocab_mask=self.vocab_mask)
|
|
||||||
|
|||||||
@@ -15,15 +15,21 @@
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import dataclasses
|
import dataclasses
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
import random
|
import random
|
||||||
|
import subprocess
|
||||||
import tempfile
|
import tempfile
|
||||||
|
import uuid
|
||||||
|
from pathlib import Path
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sglang.srt.hf_transformers_utils import check_gguf_file
|
from sglang.srt.hf_transformers_utils import check_gguf_file
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
|
create_checksum,
|
||||||
get_amdgpu_memory_capacity,
|
get_amdgpu_memory_capacity,
|
||||||
get_hpu_memory_capacity,
|
get_hpu_memory_capacity,
|
||||||
get_nvgpu_memory_capacity,
|
get_nvgpu_memory_capacity,
|
||||||
@@ -43,12 +49,13 @@ class ServerArgs:
|
|||||||
model_path: str
|
model_path: str
|
||||||
tokenizer_path: Optional[str] = None
|
tokenizer_path: Optional[str] = None
|
||||||
tokenizer_mode: str = "auto"
|
tokenizer_mode: str = "auto"
|
||||||
|
skip_tokenizer_init: bool = False
|
||||||
load_format: str = "auto"
|
load_format: str = "auto"
|
||||||
trust_remote_code: bool = True
|
trust_remote_code: bool = False
|
||||||
dtype: str = "auto"
|
dtype: str = "auto"
|
||||||
kv_cache_dtype: str = "auto"
|
kv_cache_dtype: str = "auto"
|
||||||
quantization_param_path: nullable_str = None
|
|
||||||
quantization: Optional[str] = None
|
quantization: Optional[str] = None
|
||||||
|
quantization_param_path: nullable_str = None
|
||||||
context_length: Optional[int] = None
|
context_length: Optional[int] = None
|
||||||
device: str = "cuda"
|
device: str = "cuda"
|
||||||
served_model_name: Optional[str] = None
|
served_model_name: Optional[str] = None
|
||||||
@@ -67,7 +74,7 @@ class ServerArgs:
|
|||||||
max_total_tokens: Optional[int] = None
|
max_total_tokens: Optional[int] = None
|
||||||
chunked_prefill_size: Optional[int] = None
|
chunked_prefill_size: Optional[int] = None
|
||||||
max_prefill_tokens: int = 16384
|
max_prefill_tokens: int = 16384
|
||||||
schedule_policy: str = "lpm"
|
schedule_policy: str = "fcfs"
|
||||||
schedule_conservativeness: float = 1.0
|
schedule_conservativeness: float = 1.0
|
||||||
cpu_offload_gb: int = 0
|
cpu_offload_gb: int = 0
|
||||||
prefill_only_one_req: bool = False
|
prefill_only_one_req: bool = False
|
||||||
@@ -88,6 +95,7 @@ class ServerArgs:
|
|||||||
log_level: str = "info"
|
log_level: str = "info"
|
||||||
log_level_http: Optional[str] = None
|
log_level_http: Optional[str] = None
|
||||||
log_requests: bool = False
|
log_requests: bool = False
|
||||||
|
log_requests_level: int = 0
|
||||||
show_time_cost: bool = False
|
show_time_cost: bool = False
|
||||||
enable_metrics: bool = False
|
enable_metrics: bool = False
|
||||||
decode_log_interval: int = 40
|
decode_log_interval: int = 40
|
||||||
@@ -123,11 +131,13 @@ class ServerArgs:
|
|||||||
grammar_backend: Optional[str] = "outlines"
|
grammar_backend: Optional[str] = "outlines"
|
||||||
|
|
||||||
# Speculative decoding
|
# Speculative decoding
|
||||||
speculative_draft_model_path: Optional[str] = None
|
|
||||||
speculative_algorithm: Optional[str] = None
|
speculative_algorithm: Optional[str] = None
|
||||||
|
speculative_draft_model_path: Optional[str] = None
|
||||||
speculative_num_steps: int = 5
|
speculative_num_steps: int = 5
|
||||||
speculative_eagle_topk: int = 8
|
speculative_eagle_topk: int = 4
|
||||||
speculative_num_draft_tokens: int = 64
|
speculative_num_draft_tokens: int = 8
|
||||||
|
speculative_accept_threshold_single: float = 1.0
|
||||||
|
speculative_accept_threshold_acc: float = 1.0
|
||||||
speculative_token_map: Optional[str] = None
|
speculative_token_map: Optional[str] = None
|
||||||
|
|
||||||
# Double Sparsity
|
# Double Sparsity
|
||||||
@@ -169,6 +179,12 @@ class ServerArgs:
|
|||||||
enable_hierarchical_cache: bool = False
|
enable_hierarchical_cache: bool = False
|
||||||
enable_flashinfer_mla: bool = False
|
enable_flashinfer_mla: bool = False
|
||||||
flashinfer_mla_disable_ragged: bool = False
|
flashinfer_mla_disable_ragged: bool = False
|
||||||
|
warmups: Optional[str] = None
|
||||||
|
|
||||||
|
# Debug tensor dumps
|
||||||
|
debug_tensor_dump_output_folder: Optional[str] = None
|
||||||
|
debug_tensor_dump_input_file: Optional[str] = None
|
||||||
|
debug_tensor_dump_inject: bool = False
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
# Set missing default values
|
# Set missing default values
|
||||||
@@ -266,10 +282,10 @@ class ServerArgs:
|
|||||||
self.speculative_algorithm == "EAGLE"
|
self.speculative_algorithm == "EAGLE"
|
||||||
or self.speculative_algorithm == "NEXTN"
|
or self.speculative_algorithm == "NEXTN"
|
||||||
):
|
):
|
||||||
|
self.disable_overlap_schedule = True
|
||||||
self.prefill_only_one_req = True
|
self.prefill_only_one_req = True
|
||||||
self.disable_cuda_graph_padding = True
|
self.disable_cuda_graph_padding = True
|
||||||
self.disable_radix_cache = True
|
self.disable_radix_cache = True
|
||||||
self.disable_overlap_schedule = True
|
|
||||||
self.chunked_prefill_size = -1
|
self.chunked_prefill_size = -1
|
||||||
logger.info(
|
logger.info(
|
||||||
f"The radix cache, chunked prefill, and overlap scheduler are disabled because of using {self.speculative_algorithm} speculative decoding."
|
f"The radix cache, chunked prefill, and overlap scheduler are disabled because of using {self.speculative_algorithm} speculative decoding."
|
||||||
@@ -377,15 +393,6 @@ class ServerArgs:
|
|||||||
choices=["auto", "fp8_e5m2", "fp8_e4m3"],
|
choices=["auto", "fp8_e5m2", "fp8_e4m3"],
|
||||||
help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" and "fp8_e4m3" is supported for CUDA 11.8+.',
|
help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" and "fp8_e4m3" is supported for CUDA 11.8+.',
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
|
||||||
"--quantization-param-path",
|
|
||||||
type=nullable_str,
|
|
||||||
default=None,
|
|
||||||
help="Path to the JSON file containing the KV cache "
|
|
||||||
"scaling factors. This should generally be supplied, when "
|
|
||||||
"KV cache dtype is FP8. Otherwise, KV cache scaling factors "
|
|
||||||
"default to 1.0, which may cause accuracy issues. ",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--quantization",
|
"--quantization",
|
||||||
type=str,
|
type=str,
|
||||||
@@ -404,6 +411,15 @@ class ServerArgs:
|
|||||||
],
|
],
|
||||||
help="The quantization method.",
|
help="The quantization method.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--quantization-param-path",
|
||||||
|
type=nullable_str,
|
||||||
|
default=None,
|
||||||
|
help="Path to the JSON file containing the KV cache "
|
||||||
|
"scaling factors. This should generally be supplied, when "
|
||||||
|
"KV cache dtype is FP8. Otherwise, KV cache scaling factors "
|
||||||
|
"default to 1.0, which may cause accuracy issues. ",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--context-length",
|
"--context-length",
|
||||||
type=int,
|
type=int,
|
||||||
@@ -578,7 +594,14 @@ class ServerArgs:
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--log-requests",
|
"--log-requests",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help="Log the inputs and outputs of all requests.",
|
help="Log metadata, inputs, outputs of all requests. The verbosity is decided by --log-requests-level",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--log-requests-level",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="0: Log metadata. 1. Log metadata and partial input/output. 2. Log every input/output.",
|
||||||
|
choices=[0, 1, 2],
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--show-time-cost",
|
"--show-time-cost",
|
||||||
@@ -742,16 +765,28 @@ class ServerArgs:
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--speculative-eagle-topk",
|
"--speculative-eagle-topk",
|
||||||
type=int,
|
type=int,
|
||||||
help="The number of token sampled from draft model in eagle2 each step.",
|
help="The number of tokens sampled from the draft model in eagle2 each step.",
|
||||||
choices=[1, 2, 4, 8],
|
choices=[1, 2, 4, 8],
|
||||||
default=ServerArgs.speculative_eagle_topk,
|
default=ServerArgs.speculative_eagle_topk,
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--speculative-num-draft-tokens",
|
"--speculative-num-draft-tokens",
|
||||||
type=int,
|
type=int,
|
||||||
help="The number of token sampled from draft model in Speculative Decoding.",
|
help="The number of tokens sampled from the draft model in Speculative Decoding.",
|
||||||
default=ServerArgs.speculative_num_draft_tokens,
|
default=ServerArgs.speculative_num_draft_tokens,
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--speculative-accept-threshold-single",
|
||||||
|
type=float,
|
||||||
|
help="Accept a draft token if its probability in the target model is greater than this threshold.",
|
||||||
|
default=ServerArgs.speculative_accept_threshold_single,
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--speculative-accept-threshold-acc",
|
||||||
|
type=float,
|
||||||
|
help="The accept probability of a draft token is raised from its target probability p to min(1, p / threshold_acc).",
|
||||||
|
default=ServerArgs.speculative_accept_threshold_acc,
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--speculative-token-map",
|
"--speculative-token-map",
|
||||||
type=str,
|
type=str,
|
||||||
@@ -949,6 +984,35 @@ class ServerArgs:
|
|||||||
help="Enable hierarchical cache",
|
help="Enable hierarchical cache",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Server warmups
|
||||||
|
parser.add_argument(
|
||||||
|
"--warmups",
|
||||||
|
type=str,
|
||||||
|
required=False,
|
||||||
|
help="Specify custom warmup functions (csv) to run before server starts eg. --warmups=warmup_name1,warmup_name2 "
|
||||||
|
"will run the functions `warmup_name1` and `warmup_name2` specified in warmup.py before the server starts listening for requests",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Debug tensor dumps
|
||||||
|
parser.add_argument(
|
||||||
|
"--debug-tensor-dump-output-folder",
|
||||||
|
type=str,
|
||||||
|
default=ServerArgs.debug_tensor_dump_output_folder,
|
||||||
|
help="The output folder for dumping tensors.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--debug-tensor-dump-input-file",
|
||||||
|
type=str,
|
||||||
|
default=ServerArgs.debug_tensor_dump_input_file,
|
||||||
|
help="The input filename for dumping tensors",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--debug-tensor-dump-inject",
|
||||||
|
type=str,
|
||||||
|
default=ServerArgs.debug_tensor_dump_inject,
|
||||||
|
help="Inject the outputs from jax as the input of every layer.",
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_cli_args(cls, args: argparse.Namespace):
|
def from_cli_args(cls, args: argparse.Namespace):
|
||||||
args.tp_size = args.tensor_parallel_size
|
args.tp_size = args.tensor_parallel_size
|
||||||
|
|||||||
@@ -32,13 +32,15 @@ import socket
|
|||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
import tempfile
|
import tempfile
|
||||||
|
import threading
|
||||||
import time
|
import time
|
||||||
import warnings
|
import warnings
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from importlib.metadata import PackageNotFoundError, version
|
from importlib.metadata import PackageNotFoundError, version
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
|
from multiprocessing import Pool
|
||||||
from multiprocessing.reduction import ForkingPickler
|
from multiprocessing.reduction import ForkingPickler
|
||||||
from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, Union
|
from typing import Any, Callable, Dict, List, Optional, Protocol, Set, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import psutil
|
import psutil
|
||||||
@@ -480,6 +482,10 @@ def assert_pkg_version(pkg: str, min_version: str, message: str):
|
|||||||
|
|
||||||
def kill_process_tree(parent_pid, include_parent: bool = True, skip_pid: int = None):
|
def kill_process_tree(parent_pid, include_parent: bool = True, skip_pid: int = None):
|
||||||
"""Kill the process and all its child processes."""
|
"""Kill the process and all its child processes."""
|
||||||
|
# Remove sigchld handler to avoid spammy logs.
|
||||||
|
if threading.current_thread() is threading.main_thread():
|
||||||
|
signal.signal(signal.SIGCHLD, signal.SIG_DFL)
|
||||||
|
|
||||||
if parent_pid is None:
|
if parent_pid is None:
|
||||||
parent_pid = os.getpid()
|
parent_pid = os.getpid()
|
||||||
include_parent = False
|
include_parent = False
|
||||||
@@ -499,9 +505,6 @@ def kill_process_tree(parent_pid, include_parent: bool = True, skip_pid: int = N
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
if include_parent:
|
if include_parent:
|
||||||
if parent_pid == os.getpid():
|
|
||||||
sys.exit(0)
|
|
||||||
else:
|
|
||||||
try:
|
try:
|
||||||
itself.kill()
|
itself.kill()
|
||||||
|
|
||||||
@@ -1215,7 +1218,11 @@ def cuda_device_count_stateless() -> int:
|
|||||||
return _cuda_device_count_stateless(os.environ.get("CUDA_VISIBLE_DEVICES", None))
|
return _cuda_device_count_stateless(os.environ.get("CUDA_VISIBLE_DEVICES", None))
|
||||||
|
|
||||||
|
|
||||||
def dataclass_to_string_truncated(data, max_length=2048):
|
def dataclass_to_string_truncated(
|
||||||
|
data, max_length=2048, skip_names: Optional[Set[str]] = None
|
||||||
|
):
|
||||||
|
if skip_names is None:
|
||||||
|
skip_names = set()
|
||||||
if isinstance(data, str):
|
if isinstance(data, str):
|
||||||
if len(data) > max_length:
|
if len(data) > max_length:
|
||||||
half_length = max_length // 2
|
half_length = max_length // 2
|
||||||
@@ -1234,6 +1241,7 @@ def dataclass_to_string_truncated(data, max_length=2048):
|
|||||||
+ ", ".join(
|
+ ", ".join(
|
||||||
f"'{k}': {dataclass_to_string_truncated(v, max_length)}"
|
f"'{k}': {dataclass_to_string_truncated(v, max_length)}"
|
||||||
for k, v in data.items()
|
for k, v in data.items()
|
||||||
|
if k not in skip_names
|
||||||
)
|
)
|
||||||
+ "}"
|
+ "}"
|
||||||
)
|
)
|
||||||
@@ -1244,6 +1252,7 @@ def dataclass_to_string_truncated(data, max_length=2048):
|
|||||||
+ ", ".join(
|
+ ", ".join(
|
||||||
f"{f.name}={dataclass_to_string_truncated(getattr(data, f.name), max_length)}"
|
f"{f.name}={dataclass_to_string_truncated(getattr(data, f.name), max_length)}"
|
||||||
for f in fields
|
for f in fields
|
||||||
|
if f.name not in skip_names
|
||||||
)
|
)
|
||||||
+ ")"
|
+ ")"
|
||||||
)
|
)
|
||||||
@@ -1322,9 +1331,9 @@ def pyspy_dump_schedulers():
|
|||||||
result = subprocess.run(
|
result = subprocess.run(
|
||||||
cmd, shell=True, capture_output=True, text=True, check=True
|
cmd, shell=True, capture_output=True, text=True, check=True
|
||||||
)
|
)
|
||||||
logger.info(f"Profile for PID {pid}:\n{result.stdout}")
|
logger.error(f"Pyspy dump for PID {pid}:\n{result.stdout}")
|
||||||
except subprocess.CalledProcessError as e:
|
except subprocess.CalledProcessError as e:
|
||||||
logger.info(f"Failed to profile PID {pid}. Error: {e.stderr}")
|
logger.error(f"Pyspy failed to dump PID {pid}. Error: {e.stderr}")
|
||||||
|
|
||||||
|
|
||||||
def kill_itself_when_parent_died():
|
def kill_itself_when_parent_died():
|
||||||
@@ -1448,6 +1457,10 @@ def launch_dummy_health_check_server(host, port):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def create_checksum(directory: str):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
def set_cuda_arch():
|
def set_cuda_arch():
|
||||||
if is_flashinfer_available():
|
if is_flashinfer_available():
|
||||||
capability = torch.cuda.get_device_capability()
|
capability = torch.cuda.get_device_capability()
|
||||||
|
|||||||
47
python/sglang/srt/warmup.py
Normal file
47
python/sglang/srt/warmup.py
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
import logging
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import tqdm
|
||||||
|
|
||||||
|
from sglang.srt.managers.io_struct import GenerateReqInput
|
||||||
|
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
||||||
|
|
||||||
|
logger = logging.getLogger(__file__)
|
||||||
|
|
||||||
|
_warmup_registry = {}
|
||||||
|
|
||||||
|
|
||||||
|
def warmup(name: str) -> callable:
|
||||||
|
def decorator(fn: callable):
|
||||||
|
_warmup_registry[name] = fn
|
||||||
|
return fn
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
async def execute_warmups(warmup_names: List[str], tokenizer_manager: TokenizerManager):
|
||||||
|
for warmup_name in warmup_names:
|
||||||
|
if warmup_name not in _warmup_registry:
|
||||||
|
logger.warning(f"Could not find custom warmup {warmup_name}")
|
||||||
|
continue
|
||||||
|
logger.info(f"Running warmup {warmup_name}")
|
||||||
|
await _warmup_registry[warmup_name](tokenizer_manager)
|
||||||
|
|
||||||
|
|
||||||
|
@warmup("voice_chat")
|
||||||
|
async def voice_chat(tokenizer_manager: TokenizerManager):
|
||||||
|
# this warms up the fused_moe triton kernels and caches them
|
||||||
|
# if we don't do this we break real time inference for voice chat
|
||||||
|
for i in tqdm.trange(1, 512):
|
||||||
|
size = i * 4
|
||||||
|
generate_req_input = GenerateReqInput(
|
||||||
|
input_ids=(np.random.randint(2**16, size=[size])).tolist(),
|
||||||
|
sampling_params={
|
||||||
|
"max_new_tokens": 30,
|
||||||
|
"temperature": 0.8,
|
||||||
|
"stop_token_ids": [1],
|
||||||
|
"min_p": 0.0,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
await tokenizer_manager.generate_request(generate_req_input, None).__anext__()
|
||||||
@@ -15,7 +15,7 @@
|
|||||||
import multiprocessing as mp
|
import multiprocessing as mp
|
||||||
import os
|
import os
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import List, Union
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
@@ -56,6 +56,13 @@ def get_top_logprobs(logits, k):
|
|||||||
return logprobs
|
return logprobs
|
||||||
|
|
||||||
|
|
||||||
|
def get_token_ids_logprobs(logits, token_ids):
|
||||||
|
logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
|
||||||
|
del logits
|
||||||
|
logprobs = logprobs[..., token_ids]
|
||||||
|
return logprobs
|
||||||
|
|
||||||
|
|
||||||
def _get_sentence_transformer_embedding_model(model_path, torch_dtype):
|
def _get_sentence_transformer_embedding_model(model_path, torch_dtype):
|
||||||
from sentence_transformers import SentenceTransformer
|
from sentence_transformers import SentenceTransformer
|
||||||
from sentence_transformers.util import is_sentence_transformer_model
|
from sentence_transformers.util import is_sentence_transformer_model
|
||||||
@@ -84,8 +91,13 @@ class ModelOutput:
|
|||||||
output_ids: List[int] = None
|
output_ids: List[int] = None
|
||||||
top_input_logprobs: List[torch.Tensor] = None
|
top_input_logprobs: List[torch.Tensor] = None
|
||||||
top_output_logprobs: List[torch.Tensor] = None
|
top_output_logprobs: List[torch.Tensor] = None
|
||||||
|
top_output_logprob_idx: List[List[int]] = None
|
||||||
embed_logits: List[torch.Tensor] = None
|
embed_logits: List[torch.Tensor] = None
|
||||||
scores: List[float] = None
|
scores: List[float] = None
|
||||||
|
input_token_logprobs_lst: List[List[Tuple[float, int, None]]] = None
|
||||||
|
output_token_logprobs_lst: List[List[Tuple[float, int, None]]] = None
|
||||||
|
token_ids_input_logprobs: List[torch.Tensor] = None
|
||||||
|
token_ids_output_logprobs: List[torch.Tensor] = None
|
||||||
|
|
||||||
|
|
||||||
class HFRunner:
|
class HFRunner:
|
||||||
@@ -157,7 +169,7 @@ class HFRunner:
|
|||||||
|
|
||||||
# Run forward
|
# Run forward
|
||||||
while True:
|
while True:
|
||||||
prompts, max_new_tokens, lora_paths = in_queue.get()
|
prompts, max_new_tokens, lora_paths, token_ids_logprob = in_queue.get()
|
||||||
if lora_paths is not None:
|
if lora_paths is not None:
|
||||||
assert len(prompts) == len(lora_paths)
|
assert len(prompts) == len(lora_paths)
|
||||||
|
|
||||||
@@ -165,16 +177,16 @@ class HFRunner:
|
|||||||
if self.model_type == "generation":
|
if self.model_type == "generation":
|
||||||
out_queue.put(
|
out_queue.put(
|
||||||
self.forward_generation_raw(
|
self.forward_generation_raw(
|
||||||
|
base_model=self.base_model,
|
||||||
prompts=prompts,
|
prompts=prompts,
|
||||||
max_new_tokens=max_new_tokens,
|
max_new_tokens=max_new_tokens,
|
||||||
base_model=self.base_model,
|
|
||||||
tokenizer=self.tokenizer,
|
tokenizer=self.tokenizer,
|
||||||
lora_paths=lora_paths,
|
lora_paths=lora_paths,
|
||||||
torch_dtype=torch_dtype,
|
torch_dtype=torch_dtype,
|
||||||
output_str_only=self.output_str_only,
|
output_str_only=self.output_str_only,
|
||||||
|
token_ids_logprob=token_ids_logprob,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
elif self.model_type == "embedding":
|
elif self.model_type == "embedding":
|
||||||
assert not self.output_str_only
|
assert not self.output_str_only
|
||||||
logits = self.model.encode(prompts).tolist()
|
logits = self.model.encode(prompts).tolist()
|
||||||
@@ -199,10 +211,11 @@ class HFRunner:
|
|||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
|
prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
|
||||||
max_new_tokens=8,
|
max_new_tokens: int = 8,
|
||||||
lora_paths=None,
|
lora_paths: Optional[List[str]] = None,
|
||||||
|
token_ids_logprob: Optional[int] = None,
|
||||||
):
|
):
|
||||||
self.in_queue.put((prompts, max_new_tokens, lora_paths))
|
self.in_queue.put((prompts, max_new_tokens, lora_paths, token_ids_logprob))
|
||||||
return self.out_queue.get()
|
return self.out_queue.get()
|
||||||
|
|
||||||
def terminate(self):
|
def terminate(self):
|
||||||
@@ -218,17 +231,24 @@ class HFRunner:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def forward_generation_raw(
|
def forward_generation_raw(
|
||||||
prompts: Union[List[str], List[torch.Tensor]],
|
|
||||||
max_new_tokens,
|
|
||||||
base_model,
|
base_model,
|
||||||
|
prompts: Union[List[str], List[torch.Tensor]],
|
||||||
|
max_new_tokens: int,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
lora_paths,
|
|
||||||
torch_dtype: torch.dtype,
|
torch_dtype: torch.dtype,
|
||||||
output_str_only: bool,
|
lora_paths: Optional[List[str]] = None,
|
||||||
|
output_str_only: bool = False,
|
||||||
|
token_ids_logprob: Optional[int] = None,
|
||||||
) -> ModelOutput:
|
) -> ModelOutput:
|
||||||
output_strs = []
|
output_strs = []
|
||||||
top_input_logprobs = []
|
top_input_logprobs = []
|
||||||
top_output_logprobs = []
|
top_output_logprobs = []
|
||||||
|
if token_ids_logprob is not None:
|
||||||
|
token_ids_input_logprobs = []
|
||||||
|
token_ids_output_logprobs = []
|
||||||
|
else:
|
||||||
|
token_ids_input_logprobs = token_ids_output_logprobs = None
|
||||||
|
|
||||||
for i, p in enumerate(prompts):
|
for i, p in enumerate(prompts):
|
||||||
if isinstance(p, str):
|
if isinstance(p, str):
|
||||||
input_ids = tokenizer.encode(p, return_tensors="pt").cuda()
|
input_ids = tokenizer.encode(p, return_tensors="pt").cuda()
|
||||||
@@ -275,18 +295,33 @@ class HFRunner:
|
|||||||
for logits in outputs.scores
|
for logits in outputs.scores
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
if token_ids_logprob is not None:
|
||||||
|
token_ids_output_logprobs.append(
|
||||||
|
[
|
||||||
|
get_token_ids_logprobs(
|
||||||
|
logits[0], token_ids_logprob
|
||||||
|
).tolist()
|
||||||
|
for logits in outputs.scores
|
||||||
|
]
|
||||||
|
)
|
||||||
del outputs
|
del outputs
|
||||||
|
|
||||||
input_logits = model.forward(input_ids).logits[0]
|
input_logits = model.forward(input_ids).logits[0]
|
||||||
top_input_logprobs.append(
|
top_input_logprobs.append(
|
||||||
get_top_logprobs(input_logits, NUM_TOP_LOGPROBS).tolist()
|
get_top_logprobs(input_logits, NUM_TOP_LOGPROBS).tolist()
|
||||||
)
|
)
|
||||||
|
if token_ids_logprob is not None:
|
||||||
|
token_ids_input_logprobs.append(
|
||||||
|
get_token_ids_logprobs(input_logits, token_ids_logprob).tolist()
|
||||||
|
)
|
||||||
del input_logits
|
del input_logits
|
||||||
|
|
||||||
return ModelOutput(
|
return ModelOutput(
|
||||||
output_strs=output_strs,
|
output_strs=output_strs,
|
||||||
top_input_logprobs=top_input_logprobs,
|
top_input_logprobs=top_input_logprobs,
|
||||||
top_output_logprobs=top_output_logprobs,
|
top_output_logprobs=top_output_logprobs,
|
||||||
|
token_ids_input_logprobs=token_ids_input_logprobs,
|
||||||
|
token_ids_output_logprobs=token_ids_output_logprobs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -303,11 +338,31 @@ class SRTRunner:
|
|||||||
lora_backend: str = "triton",
|
lora_backend: str = "triton",
|
||||||
disable_cuda_graph: bool = False,
|
disable_cuda_graph: bool = False,
|
||||||
disable_radix_cache: bool = False,
|
disable_radix_cache: bool = False,
|
||||||
|
chunked_prefill_size: Optional[int] = None,
|
||||||
|
dp_size: int = 1,
|
||||||
|
tokenizer_path: Optional[str] = None,
|
||||||
|
enable_ep_moe: bool = False,
|
||||||
mem_fraction_static: float = 0.65,
|
mem_fraction_static: float = 0.65,
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
|
speculative_draft_model_path: Optional[str] = None,
|
||||||
|
speculative_algorithm: Optional[str] = None,
|
||||||
|
speculative_num_steps: Optional[int] = None,
|
||||||
|
speculative_eagle_topk: Optional[int] = None,
|
||||||
|
speculative_num_draft_tokens: Optional[int] = None,
|
||||||
|
disable_overlap_schedule: bool = False,
|
||||||
):
|
):
|
||||||
self.model_type = model_type
|
self.model_type = model_type
|
||||||
self.is_generation = model_type == "generation"
|
self.is_generation = model_type == "generation"
|
||||||
|
enable_dp_attention = dp_size > 1
|
||||||
|
|
||||||
|
spec_kwargs = {}
|
||||||
|
if speculative_draft_model_path:
|
||||||
|
spec_kwargs["speculative_draft_model_path"] = speculative_draft_model_path
|
||||||
|
spec_kwargs["speculative_algorithm"] = speculative_algorithm
|
||||||
|
spec_kwargs["speculative_num_steps"] = speculative_num_steps
|
||||||
|
spec_kwargs["speculative_eagle_topk"] = speculative_eagle_topk
|
||||||
|
spec_kwargs["speculative_num_draft_tokens"] = speculative_num_draft_tokens
|
||||||
|
|
||||||
self.engine = Engine(
|
self.engine = Engine(
|
||||||
model_path=model_path,
|
model_path=model_path,
|
||||||
tp_size=tp_size,
|
tp_size=tp_size,
|
||||||
@@ -321,21 +376,41 @@ class SRTRunner:
|
|||||||
lora_backend=lora_backend,
|
lora_backend=lora_backend,
|
||||||
disable_cuda_graph=disable_cuda_graph,
|
disable_cuda_graph=disable_cuda_graph,
|
||||||
disable_radix_cache=disable_radix_cache,
|
disable_radix_cache=disable_radix_cache,
|
||||||
|
chunked_prefill_size=chunked_prefill_size,
|
||||||
|
enable_dp_attention=enable_dp_attention,
|
||||||
|
dp_size=dp_size,
|
||||||
|
tokenizer_path=tokenizer_path,
|
||||||
|
enable_ep_moe=enable_ep_moe,
|
||||||
|
disable_overlap_schedule=disable_overlap_schedule,
|
||||||
|
cuda_graph_max_bs=4,
|
||||||
|
**spec_kwargs,
|
||||||
)
|
)
|
||||||
self.tokenizer = get_tokenizer(model_path, trust_remote_code=trust_remote_code)
|
|
||||||
|
if tokenizer_path is None:
|
||||||
|
self.tokenizer = get_tokenizer(
|
||||||
|
model_path, trust_remote_code=trust_remote_code
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.tokenizer = None
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
|
prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
|
||||||
max_new_tokens=8,
|
max_new_tokens: int = 8,
|
||||||
lora_paths=None,
|
lora_paths: Optional[List[str]] = None,
|
||||||
|
logprob_start_len: int = 0,
|
||||||
|
top_k: Optional[int] = None,
|
||||||
|
token_ids_logprob: Optional[List[int]] = None,
|
||||||
):
|
):
|
||||||
if self.is_generation:
|
if self.is_generation:
|
||||||
return self.forward_generation_raw(
|
return self.forward_generation_raw(
|
||||||
|
engine=self.engine,
|
||||||
prompts=prompts,
|
prompts=prompts,
|
||||||
max_new_tokens=max_new_tokens,
|
max_new_tokens=max_new_tokens,
|
||||||
lora_paths=lora_paths,
|
lora_paths=lora_paths,
|
||||||
engine=self.engine,
|
logprob_start_len=logprob_start_len,
|
||||||
|
top_k=top_k,
|
||||||
|
token_ids_logprob=token_ids_logprob,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
response = self.engine.encode(prompts)
|
response = self.engine.encode(prompts)
|
||||||
@@ -358,10 +433,10 @@ class SRTRunner:
|
|||||||
"""
|
"""
|
||||||
if self.is_generation:
|
if self.is_generation:
|
||||||
return self.batch_forward_generation_raw(
|
return self.batch_forward_generation_raw(
|
||||||
|
engine=self.engine,
|
||||||
prompts=prompts,
|
prompts=prompts,
|
||||||
max_new_tokens=max_new_tokens,
|
max_new_tokens=max_new_tokens,
|
||||||
lora_paths=lora_paths,
|
lora_paths=lora_paths,
|
||||||
engine=self.engine,
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
response = self.engine.encode(prompts)
|
response = self.engine.encode(prompts)
|
||||||
@@ -381,24 +456,43 @@ class SRTRunner:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def forward_generation_raw(
|
def forward_generation_raw(
|
||||||
|
engine: Engine,
|
||||||
prompts: Union[List[str], List[torch.Tensor]],
|
prompts: Union[List[str], List[torch.Tensor]],
|
||||||
max_new_tokens,
|
max_new_tokens: int = 8,
|
||||||
lora_paths,
|
lora_paths: Optional[List[str]] = None,
|
||||||
engine,
|
logprob_start_len: int = 0,
|
||||||
|
top_k: Optional[int] = None,
|
||||||
|
token_ids_logprob: Optional[List[int]] = None,
|
||||||
):
|
):
|
||||||
# the return value contains logprobs from prefill
|
# the return value contains logprobs from prefill
|
||||||
output_strs = []
|
output_strs = []
|
||||||
|
output_ids = []
|
||||||
|
# Input logprobs. Note that the last item in input logprob is equivalent to
|
||||||
|
# the first item in the output logprob.
|
||||||
top_input_logprobs = []
|
top_input_logprobs = []
|
||||||
|
input_token_logprobs_lst = []
|
||||||
top_output_logprobs = []
|
top_output_logprobs = []
|
||||||
|
output_token_logprobs_lst = []
|
||||||
|
top_output_logprob_idx = []
|
||||||
|
if token_ids_logprob is not None:
|
||||||
|
token_ids_input_logprobs = []
|
||||||
|
token_ids_output_logprobs = []
|
||||||
|
else:
|
||||||
|
token_ids_input_logprobs = token_ids_output_logprobs = None
|
||||||
|
|
||||||
sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
|
sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
|
||||||
|
if top_k:
|
||||||
|
sampling_params["top_k"] = top_k
|
||||||
|
|
||||||
for i, prompt in enumerate(prompts):
|
for i, prompt in enumerate(prompts):
|
||||||
response = engine.generate(
|
response = engine.generate(
|
||||||
prompt,
|
prompt,
|
||||||
lora_path=lora_paths[i] if lora_paths else None,
|
lora_path=lora_paths[i] if lora_paths else None,
|
||||||
sampling_params=sampling_params,
|
sampling_params=sampling_params,
|
||||||
return_logprob=True,
|
return_logprob=True,
|
||||||
logprob_start_len=0,
|
logprob_start_len=logprob_start_len,
|
||||||
top_logprobs_num=NUM_TOP_LOGPROBS,
|
top_logprobs_num=NUM_TOP_LOGPROBS,
|
||||||
|
token_ids_logprob=token_ids_logprob,
|
||||||
)
|
)
|
||||||
text = response["text"]
|
text = response["text"]
|
||||||
|
|
||||||
@@ -408,12 +502,36 @@ class SRTRunner:
|
|||||||
"Received an empty text response. Please verify your input or model configuration."
|
"Received an empty text response. Please verify your input or model configuration."
|
||||||
)
|
)
|
||||||
output_strs.append(text)
|
output_strs.append(text)
|
||||||
|
# output_ids.append(response["output_ids"])
|
||||||
|
|
||||||
|
input_token_logprobs = response["meta_info"]["input_token_logprobs"]
|
||||||
|
output_token_logprobs = response["meta_info"]["output_token_logprobs"]
|
||||||
|
# print(i, input_token_logprobs)
|
||||||
|
# print(i, output_token_logprobs)
|
||||||
|
logprobs = response["meta_info"]["input_top_logprobs"]
|
||||||
|
if token_ids_logprob is not None:
|
||||||
|
input_token_ids_logprobs = response["meta_info"][
|
||||||
|
"input_token_ids_logprobs"
|
||||||
|
][1:]
|
||||||
|
else:
|
||||||
|
input_token_ids_logprobs = None
|
||||||
|
|
||||||
|
num_prompt_tokens = response["meta_info"]["prompt_tokens"]
|
||||||
|
assert len(input_token_logprobs) == num_prompt_tokens - logprob_start_len
|
||||||
|
assert len(logprobs) == num_prompt_tokens - logprob_start_len
|
||||||
|
|
||||||
|
# The first token logprob has no meaning in sglang.
|
||||||
|
input_token_logprobs = input_token_logprobs[1:]
|
||||||
|
logprobs = logprobs[1:]
|
||||||
|
assert len(input_token_logprobs) == len(logprobs)
|
||||||
|
|
||||||
|
input_token_logprobs_lst.append(
|
||||||
|
input_token_logprobs + [output_token_logprobs[0]]
|
||||||
|
)
|
||||||
|
output_token_logprobs_lst.append(output_token_logprobs)
|
||||||
|
|
||||||
top_input_logprobs.append(
|
top_input_logprobs.append(
|
||||||
[
|
[[tup[0] for tup in x[:NUM_TOP_LOGPROBS]] for x in logprobs]
|
||||||
[tup[0] for tup in x[:NUM_TOP_LOGPROBS]]
|
|
||||||
for x in response["meta_info"]["input_top_logprobs"][1:]
|
|
||||||
]
|
|
||||||
+ [
|
+ [
|
||||||
[
|
[
|
||||||
tup[0]
|
tup[0]
|
||||||
@@ -429,11 +547,41 @@ class SRTRunner:
|
|||||||
for x in response["meta_info"]["output_top_logprobs"]
|
for x in response["meta_info"]["output_top_logprobs"]
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
top_output_logprob_idx.append(
|
||||||
|
[
|
||||||
|
[tup[1] for tup in x[:NUM_TOP_LOGPROBS]]
|
||||||
|
for x in response["meta_info"]["output_top_logprobs"]
|
||||||
|
]
|
||||||
|
)
|
||||||
|
if token_ids_logprob is not None:
|
||||||
|
token_ids_input_logprobs.append(
|
||||||
|
[[tup[0] for tup in x] for x in input_token_ids_logprobs]
|
||||||
|
+ [
|
||||||
|
[
|
||||||
|
tup[0]
|
||||||
|
for tup in response["meta_info"][
|
||||||
|
"output_token_ids_logprobs"
|
||||||
|
][0]
|
||||||
|
]
|
||||||
|
]
|
||||||
|
)
|
||||||
|
token_ids_output_logprobs.append(
|
||||||
|
[
|
||||||
|
[tup[0] for tup in x]
|
||||||
|
for x in response["meta_info"]["output_token_ids_logprobs"]
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
return ModelOutput(
|
return ModelOutput(
|
||||||
output_strs=output_strs,
|
output_strs=output_strs,
|
||||||
|
output_ids=output_ids,
|
||||||
top_input_logprobs=top_input_logprobs,
|
top_input_logprobs=top_input_logprobs,
|
||||||
top_output_logprobs=top_output_logprobs,
|
top_output_logprobs=top_output_logprobs,
|
||||||
|
input_token_logprobs_lst=input_token_logprobs_lst,
|
||||||
|
output_token_logprobs_lst=output_token_logprobs_lst,
|
||||||
|
top_output_logprob_idx=top_output_logprob_idx,
|
||||||
|
token_ids_input_logprobs=token_ids_input_logprobs,
|
||||||
|
token_ids_output_logprobs=token_ids_output_logprobs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
88
python/sglang/test/send_one.py
Normal file
88
python/sglang/test/send_one.py
Normal file
@@ -0,0 +1,88 @@
|
|||||||
|
"""
|
||||||
|
Run one test prompt.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python3 -m sglang.test.send_one
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
|
||||||
|
def send_one_prompt(args):
|
||||||
|
if args.image:
|
||||||
|
args.prompt = (
|
||||||
|
"Human: Describe this image in a very short sentence.\n\nAssistant:"
|
||||||
|
)
|
||||||
|
image_data = "https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png"
|
||||||
|
else:
|
||||||
|
image_data = None
|
||||||
|
|
||||||
|
response = requests.post(
|
||||||
|
"http://localhost:30000/generate",
|
||||||
|
json={
|
||||||
|
"text": args.prompt,
|
||||||
|
"image_data": image_data,
|
||||||
|
"sampling_params": {
|
||||||
|
"temperature": args.temperature,
|
||||||
|
"max_new_tokens": args.max_new_tokens,
|
||||||
|
"frequency_penalty": args.frequency_penalty,
|
||||||
|
"presence_penalty": args.presence_penalty,
|
||||||
|
},
|
||||||
|
"return_logprob": args.return_logprob,
|
||||||
|
"stream": args.stream,
|
||||||
|
},
|
||||||
|
stream=args.stream,
|
||||||
|
)
|
||||||
|
|
||||||
|
if args.stream:
|
||||||
|
for chunk in response.iter_lines(decode_unicode=False):
|
||||||
|
chunk = chunk.decode("utf-8")
|
||||||
|
if chunk and chunk.startswith("data:"):
|
||||||
|
if chunk == "data: [DONE]":
|
||||||
|
break
|
||||||
|
ret = json.loads(chunk[5:].strip("\n"))
|
||||||
|
else:
|
||||||
|
ret = response.json()
|
||||||
|
|
||||||
|
latency = ret["meta_info"]["e2e_latency"]
|
||||||
|
|
||||||
|
if "spec_verify_ct" in ret["meta_info"]:
|
||||||
|
acc_length = (
|
||||||
|
ret["meta_info"]["completion_tokens"] / ret["meta_info"]["spec_verify_ct"]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
acc_length = 1.0
|
||||||
|
|
||||||
|
speed = ret["meta_info"]["completion_tokens"] / latency
|
||||||
|
|
||||||
|
print(ret["text"])
|
||||||
|
print()
|
||||||
|
print(f"{acc_length=:.2f}")
|
||||||
|
print(f"{speed=:.2f} token/s")
|
||||||
|
|
||||||
|
return acc_length, speed
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--temperature", type=float, default=0.0)
|
||||||
|
parser.add_argument("--max-new-tokens", type=int, default=512)
|
||||||
|
parser.add_argument("--frequency-penalty", type=float, default=0.0)
|
||||||
|
parser.add_argument("--presence-penalty", type=float, default=0.0)
|
||||||
|
parser.add_argument("--return-logprob", action="store_true")
|
||||||
|
parser.add_argument(
|
||||||
|
"--prompt",
|
||||||
|
type=str,
|
||||||
|
default="Human: Give me a fully functional FastAPI server. Show the python code.\n\nAssistant:",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--image",
|
||||||
|
action="store_true",
|
||||||
|
)
|
||||||
|
parser.add_argument("--stream", action="store_true")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
send_one_prompt(args)
|
||||||
@@ -8,10 +8,11 @@ import random
|
|||||||
import subprocess
|
import subprocess
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
|
import unittest
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
from typing import Callable, List, Optional
|
from typing import Callable, List, Optional, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import requests
|
import requests
|
||||||
@@ -408,26 +409,49 @@ def popen_launch_server(
|
|||||||
other_args: list[str] = (),
|
other_args: list[str] = (),
|
||||||
env: Optional[dict] = None,
|
env: Optional[dict] = None,
|
||||||
return_stdout_stderr: Optional[tuple] = None,
|
return_stdout_stderr: Optional[tuple] = None,
|
||||||
|
pd_seperated: bool = False,
|
||||||
):
|
):
|
||||||
_, host, port = base_url.split(":")
|
_, host, port = base_url.split(":")
|
||||||
host = host[2:]
|
host = host[2:]
|
||||||
|
|
||||||
|
if pd_seperated:
|
||||||
|
command = "sglang.launch_pd_server"
|
||||||
|
else:
|
||||||
|
command = "sglang.launch_server"
|
||||||
|
|
||||||
command = [
|
command = [
|
||||||
"python3",
|
"python3",
|
||||||
"-m",
|
"-m",
|
||||||
"sglang.launch_server",
|
command,
|
||||||
"--model-path",
|
"--model-path",
|
||||||
model,
|
model,
|
||||||
|
*[str(x) for x in other_args],
|
||||||
|
]
|
||||||
|
|
||||||
|
if pd_seperated:
|
||||||
|
command.extend(
|
||||||
|
[
|
||||||
|
"--lb-host",
|
||||||
|
host,
|
||||||
|
"--lb-port",
|
||||||
|
port,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
command.extend(
|
||||||
|
[
|
||||||
"--host",
|
"--host",
|
||||||
host,
|
host,
|
||||||
"--port",
|
"--port",
|
||||||
port,
|
port,
|
||||||
*other_args,
|
|
||||||
]
|
]
|
||||||
|
)
|
||||||
|
|
||||||
if api_key:
|
if api_key:
|
||||||
command += ["--api-key", api_key]
|
command += ["--api-key", api_key]
|
||||||
|
|
||||||
|
print(f"command={' '.join(command)}")
|
||||||
|
|
||||||
if return_stdout_stderr:
|
if return_stdout_stderr:
|
||||||
process = subprocess.Popen(
|
process = subprocess.Popen(
|
||||||
command,
|
command,
|
||||||
@@ -456,6 +480,8 @@ def popen_launch_server(
|
|||||||
except requests.RequestException:
|
except requests.RequestException:
|
||||||
pass
|
pass
|
||||||
time.sleep(10)
|
time.sleep(10)
|
||||||
|
|
||||||
|
kill_process_tree(process.pid)
|
||||||
raise TimeoutError("Server failed to start within the timeout period.")
|
raise TimeoutError("Server failed to start within the timeout period.")
|
||||||
|
|
||||||
|
|
||||||
@@ -488,9 +514,11 @@ def run_unittest_files(files: List[str], timeout_per_file: float):
|
|||||||
success = True
|
success = True
|
||||||
|
|
||||||
for filename in files:
|
for filename in files:
|
||||||
global process
|
process = None
|
||||||
|
|
||||||
def run_one_file(filename):
|
def run_one_file(filename):
|
||||||
|
nonlocal process
|
||||||
|
|
||||||
filename = os.path.join(os.getcwd(), filename)
|
filename = os.path.join(os.getcwd(), filename)
|
||||||
print(f"\n\nRun:\npython3 {filename}\n\n", flush=True)
|
print(f"\n\nRun:\npython3 {filename}\n\n", flush=True)
|
||||||
process = subprocess.Popen(
|
process = subprocess.Popen(
|
||||||
@@ -534,11 +562,14 @@ def get_benchmark_args(
|
|||||||
dataset_path="",
|
dataset_path="",
|
||||||
tokenizer="",
|
tokenizer="",
|
||||||
num_prompts=500,
|
num_prompts=500,
|
||||||
|
sharegpt_output_len=None,
|
||||||
random_input_len=4096,
|
random_input_len=4096,
|
||||||
random_output_len=2048,
|
random_output_len=2048,
|
||||||
|
sharegpt_context_len=None,
|
||||||
request_rate=float("inf"),
|
request_rate=float("inf"),
|
||||||
disable_stream=False,
|
disable_stream=False,
|
||||||
disable_ignore_eos=False,
|
disable_ignore_eos=False,
|
||||||
|
pd_seperated: bool = False,
|
||||||
):
|
):
|
||||||
return SimpleNamespace(
|
return SimpleNamespace(
|
||||||
backend="sglang",
|
backend="sglang",
|
||||||
@@ -550,8 +581,8 @@ def get_benchmark_args(
|
|||||||
model=None,
|
model=None,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
num_prompts=num_prompts,
|
num_prompts=num_prompts,
|
||||||
sharegpt_output_len=None,
|
sharegpt_output_len=sharegpt_output_len,
|
||||||
sharegpt_context_len=None,
|
sharegpt_context_len=sharegpt_context_len,
|
||||||
random_input_len=random_input_len,
|
random_input_len=random_input_len,
|
||||||
random_output_len=random_output_len,
|
random_output_len=random_output_len,
|
||||||
random_range_ratio=0.0,
|
random_range_ratio=0.0,
|
||||||
@@ -567,6 +598,8 @@ def get_benchmark_args(
|
|||||||
apply_chat_template=False,
|
apply_chat_template=False,
|
||||||
profile=None,
|
profile=None,
|
||||||
lora_name=None,
|
lora_name=None,
|
||||||
|
prompt_suffix="",
|
||||||
|
pd_seperated=pd_seperated,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -580,6 +613,7 @@ def run_bench_serving(
|
|||||||
tokenizer=None,
|
tokenizer=None,
|
||||||
random_input_len=4096,
|
random_input_len=4096,
|
||||||
random_output_len=2048,
|
random_output_len=2048,
|
||||||
|
sharegpt_context_len=None,
|
||||||
disable_stream=False,
|
disable_stream=False,
|
||||||
disable_ignore_eos=False,
|
disable_ignore_eos=False,
|
||||||
need_warmup=False,
|
need_warmup=False,
|
||||||
@@ -602,6 +636,7 @@ def run_bench_serving(
|
|||||||
num_prompts=num_prompts,
|
num_prompts=num_prompts,
|
||||||
random_input_len=random_input_len,
|
random_input_len=random_input_len,
|
||||||
random_output_len=random_output_len,
|
random_output_len=random_output_len,
|
||||||
|
sharegpt_context_len=sharegpt_context_len,
|
||||||
request_rate=request_rate,
|
request_rate=request_rate,
|
||||||
disable_stream=disable_stream,
|
disable_stream=disable_stream,
|
||||||
disable_ignore_eos=disable_ignore_eos,
|
disable_ignore_eos=disable_ignore_eos,
|
||||||
@@ -626,6 +661,7 @@ def run_bench_serving_multi(
|
|||||||
other_server_args,
|
other_server_args,
|
||||||
benchmark_args,
|
benchmark_args,
|
||||||
need_warmup=False,
|
need_warmup=False,
|
||||||
|
pd_seperated=False,
|
||||||
):
|
):
|
||||||
# Launch the server
|
# Launch the server
|
||||||
process = popen_launch_server(
|
process = popen_launch_server(
|
||||||
@@ -633,6 +669,7 @@ def run_bench_serving_multi(
|
|||||||
base_url,
|
base_url,
|
||||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
other_args=other_server_args,
|
other_args=other_server_args,
|
||||||
|
pd_seperated=pd_seperated,
|
||||||
)
|
)
|
||||||
|
|
||||||
# run benchmark for all
|
# run benchmark for all
|
||||||
@@ -665,7 +702,7 @@ def run_bench_one_batch(model, other_args):
|
|||||||
"128",
|
"128",
|
||||||
"--output",
|
"--output",
|
||||||
"8",
|
"8",
|
||||||
*other_args,
|
*[str(x) for x in other_args],
|
||||||
]
|
]
|
||||||
process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
||||||
|
|
||||||
@@ -816,7 +853,7 @@ def run_command_and_capture_output(command, env: Optional[dict] = None):
|
|||||||
stdout = open(STDOUT_FILENAME, "w")
|
stdout = open(STDOUT_FILENAME, "w")
|
||||||
stderr = open(STDERR_FILENAME, "w")
|
stderr = open(STDERR_FILENAME, "w")
|
||||||
process = subprocess.Popen(
|
process = subprocess.Popen(
|
||||||
command, stdout=stdout, stderr=stderr, env=env, text=True
|
command, stdout=stdout, stderr=stdout, env=env, text=True
|
||||||
)
|
)
|
||||||
|
|
||||||
# Launch a thread to stream the output
|
# Launch a thread to stream the output
|
||||||
@@ -914,3 +951,78 @@ def run_mulit_request_test(
|
|||||||
def write_github_step_summary(content):
|
def write_github_step_summary(content):
|
||||||
with open(os.environ["GITHUB_STEP_SUMMARY"], "a") as f:
|
with open(os.environ["GITHUB_STEP_SUMMARY"], "a") as f:
|
||||||
f.write(content)
|
f.write(content)
|
||||||
|
|
||||||
|
|
||||||
|
def run_logprob_check(self: unittest.TestCase, arg: Tuple):
|
||||||
|
(
|
||||||
|
input_len,
|
||||||
|
output_len,
|
||||||
|
temperature,
|
||||||
|
logprob_start_len,
|
||||||
|
return_logprob,
|
||||||
|
top_logprobs_num,
|
||||||
|
) = arg
|
||||||
|
input_ids = list(range(input_len))
|
||||||
|
|
||||||
|
response = requests.post(
|
||||||
|
self.base_url + "/generate",
|
||||||
|
json={
|
||||||
|
"input_ids": input_ids,
|
||||||
|
"sampling_params": {
|
||||||
|
"temperature": temperature,
|
||||||
|
"max_new_tokens": output_len,
|
||||||
|
"ignore_eos": True,
|
||||||
|
},
|
||||||
|
"return_logprob": return_logprob,
|
||||||
|
"logprob_start_len": logprob_start_len,
|
||||||
|
"top_logprobs_num": top_logprobs_num,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
response_json = response.json()
|
||||||
|
|
||||||
|
res = response_json
|
||||||
|
self.assertEqual(res["meta_info"]["prompt_tokens"], input_len)
|
||||||
|
self.assertEqual(res["meta_info"]["completion_tokens"], output_len)
|
||||||
|
|
||||||
|
# Test the number of tokens are correct
|
||||||
|
if return_logprob:
|
||||||
|
self.assertEqual(
|
||||||
|
len(res["meta_info"]["input_token_logprobs"]) + logprob_start_len,
|
||||||
|
res["meta_info"]["prompt_tokens"],
|
||||||
|
)
|
||||||
|
self.assertEqual(len(res["meta_info"]["output_token_logprobs"]), output_len)
|
||||||
|
|
||||||
|
if top_logprobs_num:
|
||||||
|
self.assertEqual(
|
||||||
|
len(res["meta_info"]["input_top_logprobs"]) + logprob_start_len,
|
||||||
|
res["meta_info"]["prompt_tokens"],
|
||||||
|
)
|
||||||
|
self.assertEqual(len(res["meta_info"]["output_top_logprobs"]), output_len)
|
||||||
|
|
||||||
|
for i in range(output_len):
|
||||||
|
self.assertEqual(
|
||||||
|
len(res["meta_info"]["output_top_logprobs"][i]),
|
||||||
|
top_logprobs_num,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test the top-1 tokens are the same as output tokens if temperature == 0
|
||||||
|
if temperature == 0:
|
||||||
|
rank = 0
|
||||||
|
while rank < len(res["meta_info"]["output_top_logprobs"][i]):
|
||||||
|
try:
|
||||||
|
self.assertListEqual(
|
||||||
|
res["meta_info"]["output_token_logprobs"][i],
|
||||||
|
res["meta_info"]["output_top_logprobs"][i][rank],
|
||||||
|
)
|
||||||
|
break
|
||||||
|
except AssertionError:
|
||||||
|
# There's a tie. Allow the second item in this case.
|
||||||
|
if (
|
||||||
|
res["meta_info"]["output_top_logprobs"][i][rank][0]
|
||||||
|
== res["meta_info"]["output_top_logprobs"][i][rank + 1][
|
||||||
|
0
|
||||||
|
]
|
||||||
|
):
|
||||||
|
rank += 1
|
||||||
|
else:
|
||||||
|
raise
|
||||||
|
|||||||
@@ -1,14 +1,5 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
|
|
||||||
# Check if sudo is available
|
|
||||||
if command -v sudo >/dev/null 2>&1; then
|
|
||||||
sudo apt-get update
|
|
||||||
sudo apt-get install -y lsof
|
|
||||||
else
|
|
||||||
apt-get update
|
|
||||||
apt-get install -y lsof
|
|
||||||
fi
|
|
||||||
|
|
||||||
# Show current GPU status
|
# Show current GPU status
|
||||||
nvidia-smi
|
nvidia-smi
|
||||||
|
|
||||||
@@ -20,6 +11,14 @@ kill -9 $(ps aux | grep 'sglang.data_parallel' | grep -v 'grep' | awk '{print $2
|
|||||||
|
|
||||||
# Clean all GPU processes if any argument is provided
|
# Clean all GPU processes if any argument is provided
|
||||||
if [ $# -gt 0 ]; then
|
if [ $# -gt 0 ]; then
|
||||||
|
# Check if sudo is available
|
||||||
|
if command -v sudo >/dev/null 2>&1; then
|
||||||
|
sudo apt-get update
|
||||||
|
sudo apt-get install -y lsof
|
||||||
|
else
|
||||||
|
apt-get update
|
||||||
|
apt-get install -y lsof
|
||||||
|
fi
|
||||||
kill -9 $(nvidia-smi | sed -n '/Processes:/,$p' | grep " [0-9]" | awk '{print $5}') 2>/dev/null
|
kill -9 $(nvidia-smi | sed -n '/Processes:/,$p' | grep " [0-9]" | awk '{print $5}') 2>/dev/null
|
||||||
lsof /dev/nvidia* | awk '{print $2}' | xargs kill -9 2>/dev/null
|
lsof /dev/nvidia* | awk '{print $2}' | xargs kill -9 2>/dev/null
|
||||||
fi
|
fi
|
||||||
|
|||||||
257
scripts/playground/bench_speculative.py
Normal file
257
scripts/playground/bench_speculative.py
Normal file
@@ -0,0 +1,257 @@
|
|||||||
|
"""
|
||||||
|
Usage:
|
||||||
|
# single GPU
|
||||||
|
python3 bench_speculative.py --model-path meta-llama/Llama-2-7b-chat-hf --speculative-draft-model-path lmzheng/sglang-EAGLE-llama2-chat-7B
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from sglang.bench_serving import benchmark, set_global_args
|
||||||
|
from sglang.srt.server_args import ServerArgs
|
||||||
|
from sglang.test.test_utils import (
|
||||||
|
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
kill_process_tree,
|
||||||
|
popen_launch_server,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def node0_print(msg):
|
||||||
|
if server_args.node_rank == 0:
|
||||||
|
print(msg)
|
||||||
|
|
||||||
|
|
||||||
|
prompts = [
|
||||||
|
"Human: Give me a fully functional FastAPI server. Show the full, long python code without stop.\n\nAssistant:",
|
||||||
|
"Human: Imagine you are an experienced Ethereum developer tasked with creating a smart contract for a blockchain messenger. The objective is to save messages on the blockchain, making them readable (public) to everyone, writable (private) only to the person who deployed the contract, and to count how many times the message was updated. Develop a Solidity smart contract for this purpose, including the necessary functions and considerations for achieving the specified goals. Please provide the code and any relevant explanations to ensure a clear understanding of the implementation.\n\nAssistant:",
|
||||||
|
"Human: Write a travel blog post to Hawaii.\n\nAssistant:",
|
||||||
|
"Human: I want you to act as an English translator, spelling corrector and improver. I will speak to you in any language and you will detect the language, translate it and answer in the corrected and improved version of my text, in English. I want you to replace my simplified A0-level words and sentences with more beautiful and elegant, upper level English words and sentences. Keep the meaning same, but make them more literary. My first sentence is 'istanbulu cok seviyom burada olmak cok guzel'. Answer in more than 5000 words.\n\nAssistant:",
|
||||||
|
"Human: I want you to act as a storyteller. You will come up with entertaining stories that are engaging, imaginative and captivating for the audience. It can be fairy tales, educational stories or any other type of stories which has the potential to capture people's attention and imagination. Depending on the target audience, you may choose specific themes or topics for your storytelling session e.g., if it’s children then you can talk about animals; If it’s adults then history-based tales might engage them better etc. Answer in more than 5000 words. My first request is 'I need an interesting story on perseverance.'\n\nAssistant:",
|
||||||
|
"Human: Solve x^2 = -1. Think step-by-step. Give me a long detailed explanation. \n\nAssistant:",
|
||||||
|
"Human: Tell me about the president of the USA in wikipedia style.\n\nAssistant:",
|
||||||
|
"Human: Hello? Who are you? Write code, math, and poem to explanin yourself.\n\nAssistant:",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class FakeTokenizer:
|
||||||
|
def encode(self, text: str, add_special_tokens: bool = False):
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
def send_one_batch(base_url, num_prompts, batch_size):
|
||||||
|
padded_prompts = (prompts * ((num_prompts + len(prompts) - 1) // len(prompts)))[
|
||||||
|
:num_prompts
|
||||||
|
]
|
||||||
|
|
||||||
|
# format: (prompt, input_len, output len). We set input_len as a dummy value 0.
|
||||||
|
input_requests = [(p, 0, 512) for p in padded_prompts]
|
||||||
|
|
||||||
|
# We need to set some dummy values in order to call `benchmark` below.
|
||||||
|
args = SimpleNamespace(
|
||||||
|
disable_ignore_eos=False,
|
||||||
|
disable_stream=False,
|
||||||
|
return_logprob=False,
|
||||||
|
backend="sglang",
|
||||||
|
dataset_name="custom",
|
||||||
|
num_prompts=None,
|
||||||
|
sharegpt_output_len=None,
|
||||||
|
random_input_len=None,
|
||||||
|
random_output_len=None,
|
||||||
|
random_range_ratio=None,
|
||||||
|
output_file=None,
|
||||||
|
)
|
||||||
|
set_global_args(args)
|
||||||
|
tokenizer = FakeTokenizer()
|
||||||
|
|
||||||
|
# Run benchmark
|
||||||
|
results = asyncio.run(
|
||||||
|
benchmark(
|
||||||
|
backend="sglang",
|
||||||
|
api_url=f"{base_url}/generate",
|
||||||
|
base_url=base_url,
|
||||||
|
model_id="default",
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
input_requests=input_requests,
|
||||||
|
request_rate=float("inf"),
|
||||||
|
max_concurrency=batch_size,
|
||||||
|
disable_tqdm=False,
|
||||||
|
lora_name=None,
|
||||||
|
extra_request_body={},
|
||||||
|
profile=None,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert results["completed"] == len(input_requests)
|
||||||
|
acc_length = results["accept_length"] or 1.0
|
||||||
|
avg_output_token = results["total_output_tokens"] / results["completed"]
|
||||||
|
|
||||||
|
server_info = requests.get(base_url + "/get_server_info").json()
|
||||||
|
# We use 20% percentile instead of median on purpose
|
||||||
|
step_time = np.percentile(server_info["step_time_dict"][str(batch_size)], 20)
|
||||||
|
speed = 1 / step_time * acc_length
|
||||||
|
|
||||||
|
return (
|
||||||
|
round(acc_length, 3),
|
||||||
|
round(step_time, 5),
|
||||||
|
round(speed, 3),
|
||||||
|
avg_output_token,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def main(args, server_args):
|
||||||
|
base_url = "http://127.0.0.1:20000"
|
||||||
|
|
||||||
|
configs = []
|
||||||
|
for batch_size in args.batch_size:
|
||||||
|
for steps in args.steps:
|
||||||
|
for topk in args.topk:
|
||||||
|
for num_draft_tokens in args.num_draft_tokens:
|
||||||
|
if steps * topk + 1 < num_draft_tokens:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if (steps == 0 or topk == 0 or num_draft_tokens == 0) and (
|
||||||
|
steps + topk + num_draft_tokens != 0
|
||||||
|
):
|
||||||
|
# steps == 0 and topk == 0 and num_draft_tokens == 0 is a special case for non-speculative decoding.
|
||||||
|
continue
|
||||||
|
|
||||||
|
configs.append((batch_size, steps, topk, num_draft_tokens))
|
||||||
|
|
||||||
|
for i in range(args.start, args.end or len(configs)):
|
||||||
|
batch_size, steps, topk, num_draft_tokens = configs[i]
|
||||||
|
|
||||||
|
node0_print(
|
||||||
|
f"Start {i=}: {batch_size=}, {steps=}, {topk=}, {num_draft_tokens=}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create an LLM.
|
||||||
|
if steps == 0:
|
||||||
|
other_args = []
|
||||||
|
else:
|
||||||
|
other_args = [
|
||||||
|
"--speculative-algorithm",
|
||||||
|
"EAGLE",
|
||||||
|
"--speculative-num-steps",
|
||||||
|
steps,
|
||||||
|
"--speculative-eagle-topk",
|
||||||
|
topk,
|
||||||
|
"--speculative-num-draft-tokens",
|
||||||
|
num_draft_tokens,
|
||||||
|
]
|
||||||
|
if server_args.speculative_draft_model_path is not None:
|
||||||
|
other_args.extend(
|
||||||
|
[
|
||||||
|
"--speculative-draft-model-path",
|
||||||
|
server_args.speculative_draft_model_path,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
other_args.extend(
|
||||||
|
[
|
||||||
|
"--cuda-graph-max-bs",
|
||||||
|
batch_size,
|
||||||
|
"--mem-fraction-static",
|
||||||
|
server_args.mem_fraction_static,
|
||||||
|
"--tp-size",
|
||||||
|
server_args.tp_size,
|
||||||
|
"--max-running-requests",
|
||||||
|
batch_size,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
if server_args.quantization:
|
||||||
|
other_args.extend(
|
||||||
|
[
|
||||||
|
"--quantization",
|
||||||
|
server_args.quantization,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
process = popen_launch_server(
|
||||||
|
args.model_path,
|
||||||
|
base_url,
|
||||||
|
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
other_args=other_args,
|
||||||
|
env={
|
||||||
|
"SGLANG_RECORD_STEP_TIME": "1",
|
||||||
|
**os.environ,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Warmup
|
||||||
|
send_one_batch(base_url, batch_size, batch_size)
|
||||||
|
|
||||||
|
# Benchmark
|
||||||
|
acc_length, step_time, speed, completion_tokens = send_one_batch(
|
||||||
|
base_url, max(args.num_prompts, batch_size), batch_size
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
kill_process_tree(process.pid)
|
||||||
|
|
||||||
|
node0_print(
|
||||||
|
f"Finish {i=}: {batch_size=}, {steps=}, {topk=}, {num_draft_tokens=}, {speed=:.2f} token/s, step_time={step_time * 1000:.2f} ms"
|
||||||
|
)
|
||||||
|
|
||||||
|
record = {
|
||||||
|
"batch_size": batch_size,
|
||||||
|
"steps": steps,
|
||||||
|
"topk": topk,
|
||||||
|
"num_draft_tokens": num_draft_tokens,
|
||||||
|
"acc_length": acc_length,
|
||||||
|
"step_time": step_time,
|
||||||
|
"speed": speed,
|
||||||
|
"completion_tokens": completion_tokens,
|
||||||
|
}
|
||||||
|
|
||||||
|
with open(args.output, "a") as fout:
|
||||||
|
fout.write(json.dumps(record) + "\n")
|
||||||
|
|
||||||
|
# Wait for the server to shutdown
|
||||||
|
time.sleep(5)
|
||||||
|
|
||||||
|
|
||||||
|
# The __main__ condition is necessary here because we use "spawn" to create subprocesses
|
||||||
|
# Spawn starts a fresh program every time, if there is no __main__, it will run into infinite loop to keep spawning processes from sgl.Engine
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
ServerArgs.add_cli_args(parser)
|
||||||
|
parser.add_argument(
|
||||||
|
"--batch-size",
|
||||||
|
type=int,
|
||||||
|
nargs="+",
|
||||||
|
default=(1, 2, 4, 8, 16),
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--steps",
|
||||||
|
type=int,
|
||||||
|
nargs="+",
|
||||||
|
default=(0, 1, 3, 5, 7), # use (0, 1, 2, 3, 4) for large batch size
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--topk",
|
||||||
|
type=int,
|
||||||
|
nargs="+",
|
||||||
|
default=(0, 1, 2, 4, 8),
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--num_draft_tokens",
|
||||||
|
type=int,
|
||||||
|
nargs="+",
|
||||||
|
default=(0, 2, 4, 8, 16, 32), # use (0, 2, 4, 8) for large batch size
|
||||||
|
)
|
||||||
|
parser.add_argument("--num-prompts", type=int, default=16)
|
||||||
|
parser.add_argument("--start", type=int, default=0)
|
||||||
|
parser.add_argument("--end", type=int)
|
||||||
|
parser.add_argument("--output", type=str, default="output.jsonl")
|
||||||
|
args = parser.parse_args()
|
||||||
|
server_args: ServerArgs = ServerArgs.from_cli_args(args)
|
||||||
|
|
||||||
|
main(args, server_args)
|
||||||
@@ -111,6 +111,8 @@ else:
|
|||||||
"cublas_grouped_gemm",
|
"cublas_grouped_gemm",
|
||||||
"custom_dispose",
|
"custom_dispose",
|
||||||
"custom_reduce",
|
"custom_reduce",
|
||||||
|
"build_tree_kernel_efficient",
|
||||||
|
"build_tree_kernel",
|
||||||
"fp8_blockwise_scaled_mm",
|
"fp8_blockwise_scaled_mm",
|
||||||
"fp8_scaled_mm",
|
"fp8_scaled_mm",
|
||||||
"fused_add_rmsnorm",
|
"fused_add_rmsnorm",
|
||||||
@@ -127,12 +129,10 @@ else:
|
|||||||
"register_graph_buffers",
|
"register_graph_buffers",
|
||||||
"rmsnorm",
|
"rmsnorm",
|
||||||
"sampling_scaling_penalties",
|
"sampling_scaling_penalties",
|
||||||
|
"sgl_per_token_group_quant_fp8",
|
||||||
"silu_and_mul",
|
"silu_and_mul",
|
||||||
"top_k_renorm_prob",
|
"top_k_renorm_prob",
|
||||||
"top_k_top_p_sampling_from_probs",
|
"top_k_top_p_sampling_from_probs",
|
||||||
"top_p_renorm_prob",
|
"top_p_renorm_prob",
|
||||||
"tree_speculative_sampling_target_only",
|
"tree_speculative_sampling_target_only",
|
||||||
"build_tree_kernel_efficient",
|
|
||||||
"build_tree_kernel",
|
|
||||||
"sgl_per_token_group_quant_fp8",
|
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -30,7 +30,9 @@ class TestSRTBackend(unittest.TestCase):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def setUpClass(cls):
|
def setUpClass(cls):
|
||||||
cls.backend = sgl.Runtime(model_path=DEFAULT_MODEL_NAME_FOR_TEST)
|
cls.backend = sgl.Runtime(
|
||||||
|
model_path=DEFAULT_MODEL_NAME_FOR_TEST, cuda_graph_max_bs=4
|
||||||
|
)
|
||||||
sgl.set_default_backend(cls.backend)
|
sgl.set_default_backend(cls.backend)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@@ -12,7 +12,6 @@ suites = {
|
|||||||
"models/test_generation_models.py",
|
"models/test_generation_models.py",
|
||||||
"models/test_qwen_models.py",
|
"models/test_qwen_models.py",
|
||||||
"models/test_reward_models.py",
|
"models/test_reward_models.py",
|
||||||
"sampling/penaltylib",
|
|
||||||
"test_abort.py",
|
"test_abort.py",
|
||||||
"test_chunked_prefill.py",
|
"test_chunked_prefill.py",
|
||||||
"test_custom_allreduce.py",
|
"test_custom_allreduce.py",
|
||||||
@@ -31,6 +30,7 @@ suites = {
|
|||||||
"test_no_chunked_prefill.py",
|
"test_no_chunked_prefill.py",
|
||||||
"test_no_overlap_scheduler.py",
|
"test_no_overlap_scheduler.py",
|
||||||
"test_openai_server.py",
|
"test_openai_server.py",
|
||||||
|
"test_penalty.py",
|
||||||
"test_pytorch_sampling_backend.py",
|
"test_pytorch_sampling_backend.py",
|
||||||
"test_radix_attention.py",
|
"test_radix_attention.py",
|
||||||
"test_regex_constrained.py",
|
"test_regex_constrained.py",
|
||||||
@@ -38,7 +38,8 @@ suites = {
|
|||||||
"test_request_length_validation.py",
|
"test_request_length_validation.py",
|
||||||
"test_retract_decode.py",
|
"test_retract_decode.py",
|
||||||
"test_server_args.py",
|
"test_server_args.py",
|
||||||
"test_session_control.py",
|
# Disabled temporarily
|
||||||
|
# "test_session_control.py",
|
||||||
"test_skip_tokenizer_init.py",
|
"test_skip_tokenizer_init.py",
|
||||||
"test_srt_engine.py",
|
"test_srt_engine.py",
|
||||||
"test_srt_endpoint.py",
|
"test_srt_endpoint.py",
|
||||||
@@ -64,9 +65,6 @@ suites = {
|
|||||||
# Disable temporarily
|
# Disable temporarily
|
||||||
# "test_nightly_math_eval.py",
|
# "test_nightly_math_eval.py",
|
||||||
],
|
],
|
||||||
"sampling/penaltylib": glob.glob(
|
|
||||||
"sampling/penaltylib/**/test_*.py", recursive=True
|
|
||||||
),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# Expand suite
|
# Expand suite
|
||||||
@@ -83,7 +81,7 @@ if __name__ == "__main__":
|
|||||||
arg_parser.add_argument(
|
arg_parser.add_argument(
|
||||||
"--timeout-per-file",
|
"--timeout-per-file",
|
||||||
type=int,
|
type=int,
|
||||||
default=2000,
|
default=1800,
|
||||||
help="The time limit for running one file in seconds.",
|
help="The time limit for running one file in seconds.",
|
||||||
)
|
)
|
||||||
arg_parser.add_argument(
|
arg_parser.add_argument(
|
||||||
|
|||||||
@@ -1,97 +0,0 @@
|
|||||||
import unittest
|
|
||||||
from typing import List
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from sglang.srt.sampling.penaltylib.penalizers.frequency_penalty import (
|
|
||||||
BatchedFrequencyPenalizer,
|
|
||||||
)
|
|
||||||
from sglang.test.srt.sampling.penaltylib.utils import (
|
|
||||||
BaseBatchedPenalizerTest,
|
|
||||||
MockSamplingParams,
|
|
||||||
Step,
|
|
||||||
StepType,
|
|
||||||
Subject,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class BaseBatchedFrequencyPenalizerTest(BaseBatchedPenalizerTest):
|
|
||||||
Penalizer = BatchedFrequencyPenalizer
|
|
||||||
frequency_penalty: float
|
|
||||||
|
|
||||||
def setUp(self):
|
|
||||||
if self.__class__ == BaseBatchedFrequencyPenalizerTest:
|
|
||||||
self.skipTest("Base class for frequency_penalty tests")
|
|
||||||
|
|
||||||
super().setUp()
|
|
||||||
|
|
||||||
def _create_subject(self, frequency_penalty: float) -> Subject:
|
|
||||||
return Subject(
|
|
||||||
sampling_params=MockSamplingParams(
|
|
||||||
frequency_penalty=frequency_penalty,
|
|
||||||
),
|
|
||||||
steps=[
|
|
||||||
Step(
|
|
||||||
type=StepType.INPUT,
|
|
||||||
token_ids=[0, 1, 2],
|
|
||||||
expected_tensors={
|
|
||||||
"frequency_penalties": self.tensor(
|
|
||||||
[[frequency_penalty] * self.vocab_size], dtype=torch.float32
|
|
||||||
),
|
|
||||||
"cumulated_frequency_penalties": self.tensor(
|
|
||||||
[[0.0] * self.vocab_size], dtype=torch.float32
|
|
||||||
),
|
|
||||||
},
|
|
||||||
expected_logits=self.tensor(
|
|
||||||
[[1] * self.vocab_size], dtype=torch.float32
|
|
||||||
),
|
|
||||||
),
|
|
||||||
Step(
|
|
||||||
type=StepType.OUTPUT,
|
|
||||||
token_ids=[
|
|
||||||
1,
|
|
||||||
2,
|
|
||||||
2,
|
|
||||||
], # This is the output ids of one request in three steps.
|
|
||||||
expected_tensors={
|
|
||||||
"frequency_penalties": self.tensor(
|
|
||||||
[[frequency_penalty] * self.vocab_size], dtype=torch.float32
|
|
||||||
),
|
|
||||||
"cumulated_frequency_penalties": self.tensor(
|
|
||||||
[
|
|
||||||
[
|
|
||||||
frequency_penalty * i if i in {1, 2} else 0.0
|
|
||||||
for i in range(self.vocab_size)
|
|
||||||
],
|
|
||||||
],
|
|
||||||
dtype=torch.float32,
|
|
||||||
),
|
|
||||||
},
|
|
||||||
expected_logits=self.tensor(
|
|
||||||
[
|
|
||||||
[
|
|
||||||
1.0 - frequency_penalty * i if i in {1, 2} else 1.0
|
|
||||||
for i in range(self.vocab_size)
|
|
||||||
],
|
|
||||||
],
|
|
||||||
dtype=torch.float32,
|
|
||||||
),
|
|
||||||
),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
def create_test_subjects(self) -> List[Subject]:
|
|
||||||
self.enabled = self._create_subject(frequency_penalty=self.frequency_penalty)
|
|
||||||
self.disabled = self._create_subject(frequency_penalty=0.0)
|
|
||||||
|
|
||||||
|
|
||||||
class TestBatchedFrequencyPenalizerPositiveValue(BaseBatchedFrequencyPenalizerTest):
|
|
||||||
frequency_penalty = 0.12
|
|
||||||
|
|
||||||
|
|
||||||
class TestBatchedFrequencyPenalizerNegativeValue(BaseBatchedFrequencyPenalizerTest):
|
|
||||||
frequency_penalty = -0.12
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
unittest.main()
|
|
||||||
@@ -1,152 +0,0 @@
|
|||||||
import unittest
|
|
||||||
from typing import List
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from sglang.srt.sampling.penaltylib.penalizers.min_new_tokens import (
|
|
||||||
BatchedMinNewTokensPenalizer,
|
|
||||||
)
|
|
||||||
from sglang.test.srt.sampling.penaltylib.utils import (
|
|
||||||
BaseBatchedPenalizerTest,
|
|
||||||
MockSamplingParams,
|
|
||||||
Step,
|
|
||||||
StepType,
|
|
||||||
Subject,
|
|
||||||
)
|
|
||||||
|
|
||||||
MIN_NEW_TOKENS = 2
|
|
||||||
EOS_TOKEN_ID = 4
|
|
||||||
STOP_TOKEN_ID = 3
|
|
||||||
|
|
||||||
ALL_STOP_TOKEN_IDS = {STOP_TOKEN_ID, EOS_TOKEN_ID}
|
|
||||||
|
|
||||||
|
|
||||||
class TestBatchedMinNewTokensPenalizer(BaseBatchedPenalizerTest):
|
|
||||||
Penalizer = BatchedMinNewTokensPenalizer
|
|
||||||
|
|
||||||
def _create_subject(self, min_new_tokens: int) -> Subject:
|
|
||||||
return Subject(
|
|
||||||
eos_token_id=EOS_TOKEN_ID,
|
|
||||||
sampling_params=MockSamplingParams(
|
|
||||||
min_new_tokens=min_new_tokens,
|
|
||||||
stop_token_ids={STOP_TOKEN_ID},
|
|
||||||
),
|
|
||||||
steps=[
|
|
||||||
Step(
|
|
||||||
type=StepType.INPUT,
|
|
||||||
token_ids=[0, 1, 2],
|
|
||||||
expected_tensors={
|
|
||||||
"min_new_tokens": self.tensor(
|
|
||||||
[[min_new_tokens]], dtype=torch.int32
|
|
||||||
),
|
|
||||||
"stop_token_penalties": self.tensor(
|
|
||||||
[
|
|
||||||
[
|
|
||||||
float("-inf") if i in ALL_STOP_TOKEN_IDS else 0
|
|
||||||
for i in range(self.vocab_size)
|
|
||||||
]
|
|
||||||
],
|
|
||||||
dtype=torch.float32,
|
|
||||||
),
|
|
||||||
"len_output_tokens": self.tensor([[0]], dtype=torch.int32),
|
|
||||||
},
|
|
||||||
expected_logits=(
|
|
||||||
self.tensor(
|
|
||||||
[
|
|
||||||
[
|
|
||||||
float("-inf") if i in ALL_STOP_TOKEN_IDS else 1
|
|
||||||
for i in range(self.vocab_size)
|
|
||||||
]
|
|
||||||
],
|
|
||||||
dtype=torch.float32,
|
|
||||||
)
|
|
||||||
if min_new_tokens > 0
|
|
||||||
else torch.ones(
|
|
||||||
(1, self.vocab_size),
|
|
||||||
dtype=torch.float32,
|
|
||||||
device=self.device,
|
|
||||||
)
|
|
||||||
),
|
|
||||||
),
|
|
||||||
Step(
|
|
||||||
type=StepType.OUTPUT,
|
|
||||||
token_ids=[0],
|
|
||||||
expected_tensors={
|
|
||||||
"min_new_tokens": self.tensor(
|
|
||||||
[[min_new_tokens]], dtype=torch.int32
|
|
||||||
),
|
|
||||||
"stop_token_penalties": self.tensor(
|
|
||||||
[
|
|
||||||
[
|
|
||||||
float("-inf") if i in ALL_STOP_TOKEN_IDS else 0
|
|
||||||
for i in range(self.vocab_size)
|
|
||||||
]
|
|
||||||
],
|
|
||||||
dtype=torch.float32,
|
|
||||||
),
|
|
||||||
"len_output_tokens": self.tensor([[1]], dtype=torch.int32),
|
|
||||||
},
|
|
||||||
expected_logits=(
|
|
||||||
self.tensor(
|
|
||||||
[
|
|
||||||
[
|
|
||||||
float("-inf") if i in ALL_STOP_TOKEN_IDS else 1
|
|
||||||
for i in range(self.vocab_size)
|
|
||||||
]
|
|
||||||
],
|
|
||||||
dtype=torch.float32,
|
|
||||||
)
|
|
||||||
if min_new_tokens > 1
|
|
||||||
else torch.ones(
|
|
||||||
(1, self.vocab_size),
|
|
||||||
dtype=torch.float32,
|
|
||||||
device=self.device,
|
|
||||||
)
|
|
||||||
),
|
|
||||||
),
|
|
||||||
Step(
|
|
||||||
type=StepType.OUTPUT,
|
|
||||||
token_ids=[0],
|
|
||||||
expected_tensors={
|
|
||||||
"min_new_tokens": self.tensor(
|
|
||||||
[[min_new_tokens]], dtype=torch.int32
|
|
||||||
),
|
|
||||||
"stop_token_penalties": self.tensor(
|
|
||||||
[
|
|
||||||
[
|
|
||||||
float("-inf") if i in ALL_STOP_TOKEN_IDS else 0
|
|
||||||
for i in range(self.vocab_size)
|
|
||||||
]
|
|
||||||
],
|
|
||||||
dtype=torch.float32,
|
|
||||||
),
|
|
||||||
"len_output_tokens": self.tensor([[2]], dtype=torch.int32),
|
|
||||||
},
|
|
||||||
expected_logits=(
|
|
||||||
self.tensor(
|
|
||||||
[
|
|
||||||
[
|
|
||||||
float("-inf") if i in ALL_STOP_TOKEN_IDS else 1
|
|
||||||
for i in range(self.vocab_size)
|
|
||||||
]
|
|
||||||
],
|
|
||||||
dtype=torch.float32,
|
|
||||||
)
|
|
||||||
if min_new_tokens > 2
|
|
||||||
else torch.ones(
|
|
||||||
(1, self.vocab_size),
|
|
||||||
dtype=torch.float32,
|
|
||||||
device=self.device,
|
|
||||||
)
|
|
||||||
),
|
|
||||||
),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
def create_test_subjects(self) -> List[Subject]:
|
|
||||||
self.enabled = self._create_subject(min_new_tokens=MIN_NEW_TOKENS)
|
|
||||||
self.disabled = self._create_subject(min_new_tokens=0.0)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
unittest.main()
|
|
||||||
@@ -1,93 +0,0 @@
|
|||||||
import unittest
|
|
||||||
from typing import List
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from sglang.srt.sampling.penaltylib.penalizers.presence_penalty import (
|
|
||||||
BatchedPresencePenalizer,
|
|
||||||
)
|
|
||||||
from sglang.test.srt.sampling.penaltylib.utils import (
|
|
||||||
BaseBatchedPenalizerTest,
|
|
||||||
MockSamplingParams,
|
|
||||||
Step,
|
|
||||||
StepType,
|
|
||||||
Subject,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class BaseBatchedPresencePenalizerTest(BaseBatchedPenalizerTest):
|
|
||||||
Penalizer = BatchedPresencePenalizer
|
|
||||||
presence_penalty: float
|
|
||||||
|
|
||||||
def setUp(self):
|
|
||||||
if self.__class__ == BaseBatchedPresencePenalizerTest:
|
|
||||||
self.skipTest("Base class for presence_penalty tests")
|
|
||||||
|
|
||||||
super().setUp()
|
|
||||||
|
|
||||||
def _create_subject(self, presence_penalty: float) -> Subject:
|
|
||||||
return Subject(
|
|
||||||
sampling_params=MockSamplingParams(
|
|
||||||
presence_penalty=presence_penalty,
|
|
||||||
),
|
|
||||||
steps=[
|
|
||||||
Step(
|
|
||||||
type=StepType.INPUT,
|
|
||||||
token_ids=[0, 1, 2],
|
|
||||||
expected_tensors={
|
|
||||||
"presence_penalties": self.tensor(
|
|
||||||
[[presence_penalty] * self.vocab_size], dtype=torch.float32
|
|
||||||
),
|
|
||||||
"cumulated_presence_penalties": self.tensor(
|
|
||||||
[[0.0] * self.vocab_size], dtype=torch.float32
|
|
||||||
),
|
|
||||||
},
|
|
||||||
expected_logits=self.tensor(
|
|
||||||
[[1] * self.vocab_size], dtype=torch.float32
|
|
||||||
),
|
|
||||||
),
|
|
||||||
Step(
|
|
||||||
type=StepType.OUTPUT,
|
|
||||||
token_ids=[1, 2, 2],
|
|
||||||
expected_tensors={
|
|
||||||
"presence_penalties": self.tensor(
|
|
||||||
[[presence_penalty] * self.vocab_size], dtype=torch.float32
|
|
||||||
),
|
|
||||||
"cumulated_presence_penalties": self.tensor(
|
|
||||||
[
|
|
||||||
[
|
|
||||||
presence_penalty if i in {1, 2} else 0.0
|
|
||||||
for i in range(self.vocab_size)
|
|
||||||
],
|
|
||||||
],
|
|
||||||
dtype=torch.float32,
|
|
||||||
),
|
|
||||||
},
|
|
||||||
expected_logits=self.tensor(
|
|
||||||
[
|
|
||||||
[
|
|
||||||
1.0 - presence_penalty if i in {1, 2} else 1.0
|
|
||||||
for i in range(self.vocab_size)
|
|
||||||
],
|
|
||||||
],
|
|
||||||
dtype=torch.float32,
|
|
||||||
),
|
|
||||||
),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
def create_test_subjects(self) -> List[Subject]:
|
|
||||||
self.enabled = self._create_subject(presence_penalty=self.presence_penalty)
|
|
||||||
self.disabled = self._create_subject(presence_penalty=0.0)
|
|
||||||
|
|
||||||
|
|
||||||
class TestBatchedPresencePenalizerPositiveValue(BaseBatchedPresencePenalizerTest):
|
|
||||||
presence_penalty = 0.12
|
|
||||||
|
|
||||||
|
|
||||||
class TestBatchedPresencePenalizerNegativeValue(BaseBatchedPresencePenalizerTest):
|
|
||||||
presence_penalty = -0.12
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
unittest.main()
|
|
||||||
@@ -1,87 +0,0 @@
|
|||||||
import unittest
|
|
||||||
from typing import List
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from sglang.srt.sampling.penaltylib.penalizers.repetition_penalty import (
|
|
||||||
BatchedRepetitionPenalizer,
|
|
||||||
)
|
|
||||||
from sglang.test.srt.sampling.penaltylib.utils import (
|
|
||||||
BaseBatchedPenalizerTest,
|
|
||||||
MockSamplingParams,
|
|
||||||
Step,
|
|
||||||
StepType,
|
|
||||||
Subject,
|
|
||||||
)
|
|
||||||
|
|
||||||
REPETITION_PENALTY = 2.0
|
|
||||||
|
|
||||||
|
|
||||||
class TestBatchedRepetitionPenalizer(BaseBatchedPenalizerTest):
|
|
||||||
Penalizer = BatchedRepetitionPenalizer
|
|
||||||
|
|
||||||
def _create_subject(self, repetition_penalty: float) -> Subject:
|
|
||||||
l = 1.0 / repetition_penalty
|
|
||||||
return Subject(
|
|
||||||
sampling_params=MockSamplingParams(
|
|
||||||
repetition_penalty=repetition_penalty,
|
|
||||||
),
|
|
||||||
steps=[
|
|
||||||
Step(
|
|
||||||
type=StepType.INPUT,
|
|
||||||
token_ids=[0, 1, 2],
|
|
||||||
expected_tensors={
|
|
||||||
"repetition_penalties": self.tensor(
|
|
||||||
[[repetition_penalty] * self.vocab_size],
|
|
||||||
dtype=torch.float32,
|
|
||||||
),
|
|
||||||
"cumulated_repetition_penalties": (
|
|
||||||
self.tensor(
|
|
||||||
[[2.0, 2.0, 2.0, 1.0, 1.0]], dtype=torch.float32
|
|
||||||
)
|
|
||||||
if repetition_penalty != 1.0
|
|
||||||
else self.tensor(
|
|
||||||
[[1.0] * self.vocab_size], dtype=torch.float32
|
|
||||||
)
|
|
||||||
),
|
|
||||||
},
|
|
||||||
expected_logits=(
|
|
||||||
self.tensor([[l, l, l, 1.0, 1.0]], dtype=torch.float32)
|
|
||||||
if repetition_penalty != 1.0
|
|
||||||
else self.tensor([[1.0] * self.vocab_size], dtype=torch.float32)
|
|
||||||
),
|
|
||||||
),
|
|
||||||
Step(
|
|
||||||
type=StepType.OUTPUT,
|
|
||||||
token_ids=[0, 1, 3],
|
|
||||||
expected_tensors={
|
|
||||||
"repetition_penalties": self.tensor(
|
|
||||||
[[repetition_penalty] * self.vocab_size],
|
|
||||||
dtype=torch.float32,
|
|
||||||
),
|
|
||||||
"cumulated_repetition_penalties": (
|
|
||||||
self.tensor(
|
|
||||||
[[2.0, 2.0, 2.0, 2.0, 1.0]], dtype=torch.float32
|
|
||||||
)
|
|
||||||
if repetition_penalty != 1.0
|
|
||||||
else self.tensor(
|
|
||||||
[[1.0] * self.vocab_size], dtype=torch.float32
|
|
||||||
)
|
|
||||||
),
|
|
||||||
},
|
|
||||||
expected_logits=(
|
|
||||||
self.tensor([[l, l, l, l, 1.0]], dtype=torch.float32)
|
|
||||||
if repetition_penalty != 1.0
|
|
||||||
else self.tensor([[1.0] * self.vocab_size], dtype=torch.float32)
|
|
||||||
),
|
|
||||||
),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
def create_test_subjects(self) -> List[Subject]:
|
|
||||||
self.enabled = self._create_subject(repetition_penalty=REPETITION_PENALTY)
|
|
||||||
self.disabled = self._create_subject(repetition_penalty=1.0)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
unittest.main()
|
|
||||||
@@ -1,114 +0,0 @@
|
|||||||
import json
|
|
||||||
import unittest
|
|
||||||
from multiprocessing import Process
|
|
||||||
|
|
||||||
import requests
|
|
||||||
|
|
||||||
from sglang.srt.utils import kill_process_tree
|
|
||||||
from sglang.test.test_utils import (
|
|
||||||
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
|
||||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
|
||||||
DEFAULT_URL_FOR_TEST,
|
|
||||||
popen_launch_server,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TestBatchPenalizerE2E(unittest.TestCase):
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def setUpClass(cls):
|
|
||||||
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
|
||||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
|
||||||
cls.process = popen_launch_server(
|
|
||||||
cls.model,
|
|
||||||
cls.base_url,
|
|
||||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
|
||||||
other_args=(
|
|
||||||
"--random-seed",
|
|
||||||
"0",
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def tearDownClass(cls):
|
|
||||||
kill_process_tree(cls.process.pid)
|
|
||||||
|
|
||||||
def run_decode(
|
|
||||||
self,
|
|
||||||
return_logprob=True,
|
|
||||||
top_logprobs_num=5,
|
|
||||||
return_text=True,
|
|
||||||
n=1,
|
|
||||||
**sampling_params,
|
|
||||||
):
|
|
||||||
response = requests.post(
|
|
||||||
self.base_url + "/generate",
|
|
||||||
json={
|
|
||||||
# prompt that is supposed to generate < 32 tokens
|
|
||||||
"text": "<|start_header_id|>user<|end_header_id|>\n\nWhat is the answer for 1 + 1 = ?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
|
|
||||||
"sampling_params": {
|
|
||||||
"max_new_tokens": 32,
|
|
||||||
"n": n,
|
|
||||||
**sampling_params,
|
|
||||||
},
|
|
||||||
"stream": False,
|
|
||||||
"return_logprob": return_logprob,
|
|
||||||
"top_logprobs_num": top_logprobs_num,
|
|
||||||
"return_text_in_logprobs": return_text,
|
|
||||||
"logprob_start_len": 0,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
assert response.status_code == 200, "Request failed: " + response.text
|
|
||||||
|
|
||||||
def test_default_values(self):
|
|
||||||
self.run_decode()
|
|
||||||
|
|
||||||
def test_mixed(self):
|
|
||||||
"""
|
|
||||||
Sends two requests with one with penalizers disabled, and the other with penalizers enabled.
|
|
||||||
This will cause two different {ScheduleBatch} to be initialized and eventually gets merged.
|
|
||||||
|
|
||||||
Merging batch with penalizers enabled with enabled, or disabled is trivial. However disabled + enabled is not.
|
|
||||||
This is because the penalizer will not be prepared if it is not required, then it will be prepared during the merge.
|
|
||||||
|
|
||||||
This test triggers the merge of disabled + enabled.
|
|
||||||
"""
|
|
||||||
|
|
||||||
processes = []
|
|
||||||
|
|
||||||
p = Process(
|
|
||||||
target=self.run_decode,
|
|
||||||
)
|
|
||||||
processes.append(p)
|
|
||||||
p.start()
|
|
||||||
|
|
||||||
p = Process(
|
|
||||||
target=self.run_decode,
|
|
||||||
kwargs={
|
|
||||||
"frequency_penalty": 2,
|
|
||||||
"min_new_tokens": 16,
|
|
||||||
"presence_penalty": 2,
|
|
||||||
"repetition_penalty": 2,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
processes.append(p)
|
|
||||||
p.start()
|
|
||||||
|
|
||||||
for p in processes:
|
|
||||||
p.join()
|
|
||||||
|
|
||||||
def test_frequency_penalty(self):
|
|
||||||
self.run_decode(frequency_penalty=2)
|
|
||||||
|
|
||||||
def test_min_new_tokens(self):
|
|
||||||
self.run_decode(min_new_tokens=16)
|
|
||||||
|
|
||||||
def test_presence_penalty(self):
|
|
||||||
self.run_decode(presence_penalty=2)
|
|
||||||
|
|
||||||
def test_repetition_penalty(self):
|
|
||||||
self.run_decode(repetition_penalty=2)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
unittest.main(verbosity=3)
|
|
||||||
@@ -138,6 +138,7 @@ class TestBenchServing(unittest.TestCase):
|
|||||||
model=DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
|
model=DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
|
||||||
num_prompts=50,
|
num_prompts=50,
|
||||||
request_rate=1,
|
request_rate=1,
|
||||||
|
sharegpt_context_len=3072,
|
||||||
disable_ignore_eos=True,
|
disable_ignore_eos=True,
|
||||||
dataset_name="sharegpt",
|
dataset_name="sharegpt",
|
||||||
other_server_args=[
|
other_server_args=[
|
||||||
@@ -148,22 +149,23 @@ class TestBenchServing(unittest.TestCase):
|
|||||||
"--speculative-num-steps",
|
"--speculative-num-steps",
|
||||||
"5",
|
"5",
|
||||||
"--speculative-eagle-topk",
|
"--speculative-eagle-topk",
|
||||||
"8",
|
"4",
|
||||||
"--speculative-num-draft-tokens",
|
"--speculative-num-draft-tokens",
|
||||||
"64",
|
"16",
|
||||||
"--mem-fraction-static",
|
"--mem-fraction-static",
|
||||||
"0.7",
|
"0.7",
|
||||||
"--cuda-graph-max-bs",
|
|
||||||
"32",
|
|
||||||
],
|
],
|
||||||
|
need_warmup=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
if is_in_ci():
|
if is_in_ci():
|
||||||
write_github_step_summary(
|
write_github_step_summary(
|
||||||
f"### test_online_latency_eagle\n"
|
f"### test_online_latency_eagle\n"
|
||||||
f'median_e2e_latency_ms : {res["median_e2e_latency_ms"]:.2f} ms\n'
|
f'median_e2e_latency_ms : {res["median_e2e_latency_ms"]:.2f} ms\n'
|
||||||
|
f'accept_length : {res["accept_length"]:.2f} \n'
|
||||||
)
|
)
|
||||||
self.assertLess(res["median_e2e_latency_ms"], 450)
|
self.assertLess(res["median_e2e_latency_ms"], 700)
|
||||||
|
self.assertGreater(res["accept_length"], 2.50)
|
||||||
|
|
||||||
def test_moe_offline_throughput_default(self):
|
def test_moe_offline_throughput_default(self):
|
||||||
res = run_bench_serving(
|
res = run_bench_serving(
|
||||||
|
|||||||
@@ -12,7 +12,9 @@ from sglang.test.test_utils import (
|
|||||||
DEFAULT_MODEL_NAME_FOR_TEST,
|
DEFAULT_MODEL_NAME_FOR_TEST,
|
||||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
DEFAULT_URL_FOR_TEST,
|
DEFAULT_URL_FOR_TEST,
|
||||||
|
is_in_ci,
|
||||||
popen_launch_server,
|
popen_launch_server,
|
||||||
|
write_github_step_summary,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -44,6 +46,9 @@ class TestEvalAccuracyLarge(unittest.TestCase):
|
|||||||
metrics = run_eval(args)
|
metrics = run_eval(args)
|
||||||
self.assertGreater(metrics["score"], 0.71)
|
self.assertGreater(metrics["score"], 0.71)
|
||||||
|
|
||||||
|
if is_in_ci():
|
||||||
|
write_github_step_summary(f"### test_mmlu\n" f'{metrics["score"]=:.4f}\n')
|
||||||
|
|
||||||
def test_human_eval(self):
|
def test_human_eval(self):
|
||||||
args = SimpleNamespace(
|
args = SimpleNamespace(
|
||||||
base_url=self.base_url,
|
base_url=self.base_url,
|
||||||
@@ -56,6 +61,11 @@ class TestEvalAccuracyLarge(unittest.TestCase):
|
|||||||
metrics = run_eval(args)
|
metrics = run_eval(args)
|
||||||
self.assertGreater(metrics["score"], 0.64)
|
self.assertGreater(metrics["score"], 0.64)
|
||||||
|
|
||||||
|
if is_in_ci():
|
||||||
|
write_github_step_summary(
|
||||||
|
f"### test_human_eval\n" f'{metrics["score"]=:.4f}\n'
|
||||||
|
)
|
||||||
|
|
||||||
def test_mgsm_en(self):
|
def test_mgsm_en(self):
|
||||||
args = SimpleNamespace(
|
args = SimpleNamespace(
|
||||||
base_url=self.base_url,
|
base_url=self.base_url,
|
||||||
@@ -68,6 +78,11 @@ class TestEvalAccuracyLarge(unittest.TestCase):
|
|||||||
metrics = run_eval(args)
|
metrics = run_eval(args)
|
||||||
self.assertGreater(metrics["score"], 0.835)
|
self.assertGreater(metrics["score"], 0.835)
|
||||||
|
|
||||||
|
if is_in_ci():
|
||||||
|
write_github_step_summary(
|
||||||
|
f"### test_mgsm_en\n" f'{metrics["score"]=:.4f}\n'
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
27
test/srt/test_health_check.py
Normal file
27
test/srt/test_health_check.py
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
import unittest
|
||||||
|
|
||||||
|
from sglang.test.test_utils import (
|
||||||
|
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
||||||
|
DEFAULT_URL_FOR_TEST,
|
||||||
|
popen_launch_server,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestHealthCheck(unittest.TestCase):
|
||||||
|
def test_health_check(self):
|
||||||
|
"""Test that metrics endpoint returns data when enabled"""
|
||||||
|
with self.assertRaises(TimeoutError):
|
||||||
|
popen_launch_server(
|
||||||
|
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
||||||
|
DEFAULT_URL_FOR_TEST,
|
||||||
|
timeout=60,
|
||||||
|
other_args=[
|
||||||
|
"--disable-cuda-graph",
|
||||||
|
"--json-model-override-args",
|
||||||
|
'{"architectures": ["LlamaForCausalLMForHealthTest"]}',
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
@@ -49,7 +49,7 @@ class TestHiddenState(unittest.TestCase):
|
|||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
hf_out = model(
|
hf_out = model(
|
||||||
torch.tensor(
|
torch.tensor(
|
||||||
[input_id + output["token_ids"][:-1]], device=model.device
|
[input_id + output["output_ids"][:-1]], device=model.device
|
||||||
),
|
),
|
||||||
output_hidden_states=True,
|
output_hidden_states=True,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -56,11 +56,13 @@ class TestEnableMetrics(unittest.TestCase):
|
|||||||
"sglang:gen_throughput",
|
"sglang:gen_throughput",
|
||||||
"sglang:num_queue_reqs",
|
"sglang:num_queue_reqs",
|
||||||
"sglang:cache_hit_rate",
|
"sglang:cache_hit_rate",
|
||||||
|
"sglang:spec_accept_length",
|
||||||
"sglang:prompt_tokens_total",
|
"sglang:prompt_tokens_total",
|
||||||
"sglang:generation_tokens_total",
|
"sglang:generation_tokens_total",
|
||||||
"sglang:num_requests_total",
|
"sglang:num_requests_total",
|
||||||
"sglang:time_to_first_token_seconds",
|
"sglang:time_to_first_token_seconds",
|
||||||
"sglang:time_per_output_token_seconds",
|
"sglang:time_per_output_token_seconds",
|
||||||
|
"sglang:inter_token_latency_seconds",
|
||||||
"sglang:e2e_request_latency_seconds",
|
"sglang:e2e_request_latency_seconds",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@@ -141,7 +141,7 @@ class TestDeepseekV3MTP(unittest.TestCase):
|
|||||||
metrics = run_eval_few_shot_gsm8k(args)
|
metrics = run_eval_few_shot_gsm8k(args)
|
||||||
print(metrics)
|
print(metrics)
|
||||||
|
|
||||||
self.assertGreater(metrics["accuracy"], 0.62)
|
self.assertGreater(metrics["accuracy"], 0.60)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
91
test/srt/test_penalty.py
Normal file
91
test/srt/test_penalty.py
Normal file
@@ -0,0 +1,91 @@
|
|||||||
|
import json
|
||||||
|
import random
|
||||||
|
import unittest
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from sglang.srt.utils import kill_process_tree
|
||||||
|
from sglang.test.test_utils import (
|
||||||
|
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
||||||
|
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
DEFAULT_URL_FOR_TEST,
|
||||||
|
popen_launch_server,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestPenalty(unittest.TestCase):
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||||
|
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||||
|
cls.process = popen_launch_server(
|
||||||
|
cls.model,
|
||||||
|
cls.base_url,
|
||||||
|
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def tearDownClass(cls):
|
||||||
|
kill_process_tree(cls.process.pid)
|
||||||
|
|
||||||
|
def run_decode(self, sampling_params):
|
||||||
|
return_logprob = True
|
||||||
|
top_logprobs_num = 5
|
||||||
|
return_text = True
|
||||||
|
n = 1
|
||||||
|
|
||||||
|
response = requests.post(
|
||||||
|
self.base_url + "/generate",
|
||||||
|
json={
|
||||||
|
# prompt that is supposed to generate < 32 tokens
|
||||||
|
"text": "<|start_header_id|>user<|end_header_id|>\n\nWhat is the answer for 1 + 1 = ?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
|
||||||
|
"sampling_params": {
|
||||||
|
"max_new_tokens": 32,
|
||||||
|
"n": n,
|
||||||
|
**sampling_params,
|
||||||
|
},
|
||||||
|
"return_logprob": return_logprob,
|
||||||
|
"top_logprobs_num": top_logprobs_num,
|
||||||
|
"return_text_in_logprobs": return_text,
|
||||||
|
"logprob_start_len": 0,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
self.assertEqual(response.status_code, 200)
|
||||||
|
print(json.dumps(response.json()))
|
||||||
|
print("=" * 100)
|
||||||
|
|
||||||
|
def test_default_values(self):
|
||||||
|
self.run_decode({})
|
||||||
|
|
||||||
|
def test_frequency_penalty(self):
|
||||||
|
self.run_decode({"frequency_penalty": 2})
|
||||||
|
|
||||||
|
def test_min_new_tokens(self):
|
||||||
|
self.run_decode({"min_new_tokens": 16})
|
||||||
|
|
||||||
|
def test_presence_penalty(self):
|
||||||
|
self.run_decode({"presence_penalty": 2})
|
||||||
|
|
||||||
|
def test_mixed(self):
|
||||||
|
args = [
|
||||||
|
{},
|
||||||
|
{},
|
||||||
|
{},
|
||||||
|
{"frequency_penalty": 2},
|
||||||
|
{"min_new_tokens": 16},
|
||||||
|
{"presence_penalty": 1},
|
||||||
|
{"frequency_penalty": 0.2},
|
||||||
|
{"min_new_tokens": 8},
|
||||||
|
{"presence_penalty": 0.4},
|
||||||
|
{"presence_penalty": 0.4, "frequency_penalty": 2},
|
||||||
|
{"min_new_tokens": 12, "frequency_penalty": 2},
|
||||||
|
]
|
||||||
|
random.shuffle(args * 5)
|
||||||
|
with ThreadPoolExecutor(8) as executor:
|
||||||
|
list(executor.map(self.run_decode, args))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main(verbosity=3)
|
||||||
@@ -70,7 +70,10 @@ class TestSessionControl(unittest.TestCase):
|
|||||||
|
|
||||||
first_rid = None
|
first_rid = None
|
||||||
outputs_from_session = []
|
outputs_from_session = []
|
||||||
|
logprobs_from_session = []
|
||||||
|
cur_logprob_start_len = 0
|
||||||
for i, chunk_ids in enumerate(chunks_ids):
|
for i, chunk_ids in enumerate(chunks_ids):
|
||||||
|
max_new_tokens = gen_len if i > 0 else 1 # prefill only for the first chunk
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
self.base_url + "/generate",
|
self.base_url + "/generate",
|
||||||
json={
|
json={
|
||||||
@@ -83,12 +86,12 @@ class TestSessionControl(unittest.TestCase):
|
|||||||
},
|
},
|
||||||
"sampling_params": {
|
"sampling_params": {
|
||||||
"temperature": 0,
|
"temperature": 0,
|
||||||
"max_new_tokens": (
|
"max_new_tokens": max_new_tokens,
|
||||||
gen_len if i > 0 else 1
|
|
||||||
), # prefill only for the first chunk
|
|
||||||
"no_stop_trim": True,
|
"no_stop_trim": True,
|
||||||
"skip_special_tokens": False,
|
"skip_special_tokens": False,
|
||||||
},
|
},
|
||||||
|
"return_logprob": True,
|
||||||
|
"logprob_start_len": cur_logprob_start_len - 1,
|
||||||
},
|
},
|
||||||
).json()
|
).json()
|
||||||
rid = response["meta_info"]["id"]
|
rid = response["meta_info"]["id"]
|
||||||
@@ -96,8 +99,39 @@ class TestSessionControl(unittest.TestCase):
|
|||||||
first_rid = rid
|
first_rid = rid
|
||||||
if i > 0:
|
if i > 0:
|
||||||
outputs_from_session.append(response["text"])
|
outputs_from_session.append(response["text"])
|
||||||
|
logprobs_from_session.extend(
|
||||||
|
[
|
||||||
|
round(sublist[0], 2)
|
||||||
|
for sublist in response["meta_info"]["output_token_logprobs"]
|
||||||
|
]
|
||||||
|
)
|
||||||
|
cur_logprob_start_len += len(chunk_ids) + max_new_tokens
|
||||||
|
|
||||||
|
# query with a logprob_start_len longer than the request, should see error
|
||||||
|
response = requests.post(
|
||||||
|
self.base_url + "/generate",
|
||||||
|
json={
|
||||||
|
"input_ids": chunk_ids,
|
||||||
|
"session_params": {
|
||||||
|
"id": session_id,
|
||||||
|
"rid": rid,
|
||||||
|
"offset": -1,
|
||||||
|
"replace": True,
|
||||||
|
},
|
||||||
|
"sampling_params": {
|
||||||
|
"temperature": 0,
|
||||||
|
"max_new_tokens": max_new_tokens,
|
||||||
|
"no_stop_trim": True,
|
||||||
|
"skip_special_tokens": False,
|
||||||
|
},
|
||||||
|
"return_logprob": True,
|
||||||
|
"logprob_start_len": cur_logprob_start_len + len(chunk_ids),
|
||||||
|
},
|
||||||
|
).json()
|
||||||
|
assert "Request with a lower logprob_start_len" in response["error"]["message"]
|
||||||
|
|
||||||
# backtrack to the first request and regenerate
|
# backtrack to the first request and regenerate
|
||||||
|
cur_logprob_start_len = 0
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
self.base_url + "/generate",
|
self.base_url + "/generate",
|
||||||
json={
|
json={
|
||||||
@@ -114,9 +148,17 @@ class TestSessionControl(unittest.TestCase):
|
|||||||
"no_stop_trim": True,
|
"no_stop_trim": True,
|
||||||
"skip_special_tokens": False,
|
"skip_special_tokens": False,
|
||||||
},
|
},
|
||||||
|
"return_logprob": True,
|
||||||
|
"logprob_start_len": cur_logprob_start_len,
|
||||||
},
|
},
|
||||||
).json()
|
).json()
|
||||||
outputs_from_session.append(response["text"])
|
outputs_from_session.append(response["text"])
|
||||||
|
logprobs_from_session.extend(
|
||||||
|
[
|
||||||
|
round(sublist[0], 2)
|
||||||
|
for sublist in response["meta_info"]["output_token_logprobs"]
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
# query with a non-existing rid (the last one should be disappeared becuase of backtrack), should see abort
|
# query with a non-existing rid (the last one should be disappeared becuase of backtrack), should see abort
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
@@ -135,6 +177,7 @@ class TestSessionControl(unittest.TestCase):
|
|||||||
"no_stop_trim": True,
|
"no_stop_trim": True,
|
||||||
"skip_special_tokens": False,
|
"skip_special_tokens": False,
|
||||||
},
|
},
|
||||||
|
"return_logprob": True,
|
||||||
},
|
},
|
||||||
).json()
|
).json()
|
||||||
assert response["meta_info"]["finish_reason"]["type"] == "abort"
|
assert response["meta_info"]["finish_reason"]["type"] == "abort"
|
||||||
@@ -162,6 +205,7 @@ class TestSessionControl(unittest.TestCase):
|
|||||||
"no_stop_trim": True,
|
"no_stop_trim": True,
|
||||||
"skip_special_tokens": False,
|
"skip_special_tokens": False,
|
||||||
},
|
},
|
||||||
|
"return_logprob": True,
|
||||||
},
|
},
|
||||||
).json()
|
).json()
|
||||||
assert response["meta_info"]["finish_reason"]["type"] == "abort"
|
assert response["meta_info"]["finish_reason"]["type"] == "abort"
|
||||||
@@ -172,6 +216,7 @@ class TestSessionControl(unittest.TestCase):
|
|||||||
input_ids_first_req = None
|
input_ids_first_req = None
|
||||||
input_ids = []
|
input_ids = []
|
||||||
outputs_normal = []
|
outputs_normal = []
|
||||||
|
logprobs_normal = []
|
||||||
for i, chunk_ids in enumerate(chunks_ids):
|
for i, chunk_ids in enumerate(chunks_ids):
|
||||||
input_ids += chunk_ids
|
input_ids += chunk_ids
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
@@ -186,6 +231,7 @@ class TestSessionControl(unittest.TestCase):
|
|||||||
"no_stop_trim": True,
|
"no_stop_trim": True,
|
||||||
"skip_special_tokens": False,
|
"skip_special_tokens": False,
|
||||||
},
|
},
|
||||||
|
"return_logprob": True,
|
||||||
},
|
},
|
||||||
).json()
|
).json()
|
||||||
if i > 0:
|
if i > 0:
|
||||||
@@ -194,6 +240,12 @@ class TestSessionControl(unittest.TestCase):
|
|||||||
output_ids = output_ids[1:]
|
output_ids = output_ids[1:]
|
||||||
input_ids += output_ids[:-1]
|
input_ids += output_ids[:-1]
|
||||||
outputs_normal.append(response["text"])
|
outputs_normal.append(response["text"])
|
||||||
|
logprobs_normal.extend(
|
||||||
|
[
|
||||||
|
round(sublist[0], 2)
|
||||||
|
for sublist in response["meta_info"]["output_token_logprobs"]
|
||||||
|
]
|
||||||
|
)
|
||||||
if i == 0:
|
if i == 0:
|
||||||
input_ids_first_req = input_ids.copy()
|
input_ids_first_req = input_ids.copy()
|
||||||
|
|
||||||
@@ -208,17 +260,31 @@ class TestSessionControl(unittest.TestCase):
|
|||||||
"no_stop_trim": True,
|
"no_stop_trim": True,
|
||||||
"skip_special_tokens": False,
|
"skip_special_tokens": False,
|
||||||
},
|
},
|
||||||
|
"return_logprob": True,
|
||||||
},
|
},
|
||||||
).json()
|
).json()
|
||||||
outputs_normal.append(response["text"])
|
outputs_normal.append(response["text"])
|
||||||
|
logprobs_normal.extend(
|
||||||
|
[
|
||||||
|
round(sublist[0], 2)
|
||||||
|
for sublist in response["meta_info"]["output_token_logprobs"]
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
print("outputs from chunked queries with session control:")
|
print("outputs from chunked queries with session control:")
|
||||||
print(outputs_from_session)
|
print(outputs_from_session)
|
||||||
print("outputs from normal queries:")
|
print("outputs from normal queries:")
|
||||||
print(outputs_normal)
|
print(outputs_normal)
|
||||||
assert (
|
assert outputs_from_session == outputs_normal
|
||||||
outputs_from_session == outputs_normal
|
print("logprobs from chunked queries with session control:")
|
||||||
), f"outputs_from_session: {outputs_from_session}, outputs_normal: {outputs_normal}"
|
print(logprobs_from_session)
|
||||||
|
print("logprobs from normal queries:")
|
||||||
|
print(logprobs_normal)
|
||||||
|
assert len(logprobs_from_session) == len(
|
||||||
|
logprobs_normal
|
||||||
|
), "logprobs must have equal length"
|
||||||
|
for a, b in zip(logprobs_from_session, logprobs_normal):
|
||||||
|
assert abs(a - b) <= 0.1, f"logprobs {a} and {b} differ by more than 0.1"
|
||||||
|
|
||||||
async def async_generate(self, payload):
|
async def async_generate(self, payload):
|
||||||
url = self.base_url + "/generate"
|
url = self.base_url + "/generate"
|
||||||
|
|||||||
@@ -1,3 +1,8 @@
|
|||||||
|
"""
|
||||||
|
python3 -m unittest test_skip_tokenizer_init.TestSkipTokenizerInit.test_parallel_sample
|
||||||
|
python3 -m unittest test_skip_tokenizer_init.TestSkipTokenizerInit.run_decode_stream
|
||||||
|
"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
@@ -12,42 +17,26 @@ from sglang.test.test_utils import (
|
|||||||
popen_launch_server,
|
popen_launch_server,
|
||||||
)
|
)
|
||||||
|
|
||||||
_server_process = None
|
|
||||||
_base_url = None
|
|
||||||
_tokenizer = None
|
|
||||||
|
|
||||||
|
|
||||||
def setUpModule():
|
|
||||||
"""
|
|
||||||
Launch the server once before all tests and initialize the tokenizer.
|
|
||||||
"""
|
|
||||||
global _server_process, _base_url, _tokenizer
|
|
||||||
_server_process = popen_launch_server(
|
|
||||||
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
|
||||||
DEFAULT_URL_FOR_TEST,
|
|
||||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
|
||||||
other_args=["--skip-tokenizer-init"],
|
|
||||||
)
|
|
||||||
_base_url = DEFAULT_URL_FOR_TEST
|
|
||||||
|
|
||||||
_tokenizer = AutoTokenizer.from_pretrained(
|
|
||||||
DEFAULT_SMALL_MODEL_NAME_FOR_TEST, use_fast=False
|
|
||||||
)
|
|
||||||
print(">>> setUpModule: Server launched, tokenizer ready")
|
|
||||||
|
|
||||||
|
|
||||||
def tearDownModule():
|
|
||||||
"""
|
|
||||||
Terminate the server once after all tests have completed.
|
|
||||||
"""
|
|
||||||
global _server_process
|
|
||||||
if _server_process is not None:
|
|
||||||
kill_process_tree(_server_process.pid)
|
|
||||||
_server_process = None
|
|
||||||
print(">>> tearDownModule: Server terminated")
|
|
||||||
|
|
||||||
|
|
||||||
class TestSkipTokenizerInit(unittest.TestCase):
|
class TestSkipTokenizerInit(unittest.TestCase):
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||||
|
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||||
|
cls.process = popen_launch_server(
|
||||||
|
cls.model,
|
||||||
|
cls.base_url,
|
||||||
|
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
other_args=["--skip-tokenizer-init", "--stream-output"],
|
||||||
|
)
|
||||||
|
cls.tokenizer = AutoTokenizer.from_pretrained(
|
||||||
|
DEFAULT_SMALL_MODEL_NAME_FOR_TEST, use_fast=False
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def tearDownClass(cls):
|
||||||
|
kill_process_tree(cls.process.pid)
|
||||||
|
|
||||||
def run_decode(
|
def run_decode(
|
||||||
self,
|
self,
|
||||||
prompt_text="The capital of France is",
|
prompt_text="The capital of France is",
|
||||||
@@ -56,19 +45,19 @@ class TestSkipTokenizerInit(unittest.TestCase):
|
|||||||
top_logprobs_num=0,
|
top_logprobs_num=0,
|
||||||
n=1,
|
n=1,
|
||||||
):
|
):
|
||||||
input_ids = _tokenizer(prompt_text, return_tensors="pt")["input_ids"][
|
input_ids = self.tokenizer(prompt_text, return_tensors="pt")["input_ids"][
|
||||||
0
|
0
|
||||||
].tolist()
|
].tolist()
|
||||||
|
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
_base_url + "/generate",
|
self.base_url + "/generate",
|
||||||
json={
|
json={
|
||||||
"input_ids": input_ids,
|
"input_ids": input_ids,
|
||||||
"sampling_params": {
|
"sampling_params": {
|
||||||
"temperature": 0 if n == 1 else 0.5,
|
"temperature": 0 if n == 1 else 0.5,
|
||||||
"max_new_tokens": max_new_tokens,
|
"max_new_tokens": max_new_tokens,
|
||||||
"n": n,
|
"n": n,
|
||||||
"stop_token_ids": [_tokenizer.eos_token_id],
|
"stop_token_ids": [self.tokenizer.eos_token_id],
|
||||||
},
|
},
|
||||||
"stream": False,
|
"stream": False,
|
||||||
"return_logprob": return_logprob,
|
"return_logprob": return_logprob,
|
||||||
@@ -83,13 +72,13 @@ class TestSkipTokenizerInit(unittest.TestCase):
|
|||||||
if item["meta_info"]["finish_reason"]["type"] == "stop":
|
if item["meta_info"]["finish_reason"]["type"] == "stop":
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
item["meta_info"]["finish_reason"]["matched"],
|
item["meta_info"]["finish_reason"]["matched"],
|
||||||
_tokenizer.eos_token_id,
|
self.tokenizer.eos_token_id,
|
||||||
)
|
)
|
||||||
elif item["meta_info"]["finish_reason"]["type"] == "length":
|
elif item["meta_info"]["finish_reason"]["type"] == "length":
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
len(item["token_ids"]), item["meta_info"]["completion_tokens"]
|
len(item["output_ids"]), item["meta_info"]["completion_tokens"]
|
||||||
)
|
)
|
||||||
self.assertEqual(len(item["token_ids"]), max_new_tokens)
|
self.assertEqual(len(item["output_ids"]), max_new_tokens)
|
||||||
self.assertEqual(item["meta_info"]["prompt_tokens"], len(input_ids))
|
self.assertEqual(item["meta_info"]["prompt_tokens"], len(input_ids))
|
||||||
|
|
||||||
if return_logprob:
|
if return_logprob:
|
||||||
@@ -113,6 +102,63 @@ class TestSkipTokenizerInit(unittest.TestCase):
|
|||||||
|
|
||||||
print("=" * 100)
|
print("=" * 100)
|
||||||
|
|
||||||
|
def run_decode_stream(self, return_logprob=False, top_logprobs_num=0, n=1):
|
||||||
|
max_new_tokens = 32
|
||||||
|
input_ids = [128000, 791, 6864, 315, 9822, 374] # The capital of France is
|
||||||
|
requests.post(self.base_url + "/flush_cache")
|
||||||
|
response = requests.post(
|
||||||
|
self.base_url + "/generate",
|
||||||
|
json={
|
||||||
|
"input_ids": input_ids,
|
||||||
|
"sampling_params": {
|
||||||
|
"temperature": 0 if n == 1 else 0.5,
|
||||||
|
"max_new_tokens": max_new_tokens,
|
||||||
|
"n": n,
|
||||||
|
"stop_token_ids": [119690],
|
||||||
|
},
|
||||||
|
"stream": False,
|
||||||
|
"return_logprob": return_logprob,
|
||||||
|
"top_logprobs_num": top_logprobs_num,
|
||||||
|
"logprob_start_len": 0,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
ret = response.json()
|
||||||
|
print(json.dumps(ret))
|
||||||
|
output_ids = ret["output_ids"]
|
||||||
|
|
||||||
|
requests.post(self.base_url + "/flush_cache")
|
||||||
|
response_stream = requests.post(
|
||||||
|
self.base_url + "/generate",
|
||||||
|
json={
|
||||||
|
"input_ids": input_ids,
|
||||||
|
"sampling_params": {
|
||||||
|
"temperature": 0 if n == 1 else 0.5,
|
||||||
|
"max_new_tokens": max_new_tokens,
|
||||||
|
"n": n,
|
||||||
|
"stop_token_ids": [119690],
|
||||||
|
},
|
||||||
|
"stream": True,
|
||||||
|
"return_logprob": return_logprob,
|
||||||
|
"top_logprobs_num": top_logprobs_num,
|
||||||
|
"logprob_start_len": 0,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
ret = response.json()
|
||||||
|
output_ids = ret["output_ids"]
|
||||||
|
print("output from non-streaming request:")
|
||||||
|
print(output_ids)
|
||||||
|
|
||||||
|
response_stream_json = []
|
||||||
|
for line in response_stream.iter_lines():
|
||||||
|
if line.startswith(b"data: ") and line[6:] != b"[DONE]":
|
||||||
|
response_stream_json.append(json.loads(line[6:]))
|
||||||
|
out_stream_ids = []
|
||||||
|
for x in response_stream_json:
|
||||||
|
out_stream_ids += x["output_ids"]
|
||||||
|
print("output from streaming request:")
|
||||||
|
print(out_stream_ids)
|
||||||
|
assert output_ids == out_stream_ids
|
||||||
|
|
||||||
def test_simple_decode(self):
|
def test_simple_decode(self):
|
||||||
self.run_decode()
|
self.run_decode()
|
||||||
|
|
||||||
@@ -126,6 +172,9 @@ class TestSkipTokenizerInit(unittest.TestCase):
|
|||||||
def test_eos_behavior(self):
|
def test_eos_behavior(self):
|
||||||
self.run_decode(max_new_tokens=256)
|
self.run_decode(max_new_tokens=256)
|
||||||
|
|
||||||
|
def test_simple_decode_stream(self):
|
||||||
|
self.run_decode_stream()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import random
|
|||||||
import time
|
import time
|
||||||
import unittest
|
import unittest
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
from functools import partial
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -20,6 +21,7 @@ from sglang.test.test_utils import (
|
|||||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
DEFAULT_URL_FOR_TEST,
|
DEFAULT_URL_FOR_TEST,
|
||||||
popen_launch_server,
|
popen_launch_server,
|
||||||
|
run_logprob_check,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -35,7 +37,9 @@ class TestSRTEndpoint(unittest.TestCase):
|
|||||||
other_args=(
|
other_args=(
|
||||||
"--enable-custom-logit-processor",
|
"--enable-custom-logit-processor",
|
||||||
"--mem-fraction-static",
|
"--mem-fraction-static",
|
||||||
"0.8",
|
"0.7",
|
||||||
|
"--cuda-graph-max-bs",
|
||||||
|
"8",
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -131,7 +135,7 @@ class TestSRTEndpoint(unittest.TestCase):
|
|||||||
for i, res in enumerate(response_json):
|
for i, res in enumerate(response_json):
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
res["meta_info"]["prompt_tokens"],
|
res["meta_info"]["prompt_tokens"],
|
||||||
logprob_start_len + 1 + len(res["meta_info"]["input_token_logprobs"]),
|
logprob_start_len + len(res["meta_info"]["input_token_logprobs"]),
|
||||||
)
|
)
|
||||||
assert prompts[i].endswith(
|
assert prompts[i].endswith(
|
||||||
"".join([x[-1] for x in res["meta_info"]["input_token_logprobs"]])
|
"".join([x[-1] for x in res["meta_info"]["input_token_logprobs"]])
|
||||||
@@ -235,83 +239,15 @@ class TestSRTEndpoint(unittest.TestCase):
|
|||||||
|
|
||||||
diff = np.abs(output_logprobs - output_logprobs_score)
|
diff = np.abs(output_logprobs - output_logprobs_score)
|
||||||
max_diff = np.max(diff)
|
max_diff = np.max(diff)
|
||||||
self.assertLess(max_diff, 0.25)
|
self.assertLess(max_diff, 0.35)
|
||||||
|
|
||||||
def run_logprob_check(self, arg):
|
|
||||||
(
|
|
||||||
input_len,
|
|
||||||
output_len,
|
|
||||||
temperature,
|
|
||||||
logprob_start_len,
|
|
||||||
return_logprob,
|
|
||||||
top_logprobs_num,
|
|
||||||
) = arg
|
|
||||||
input_ids = list(range(input_len))
|
|
||||||
|
|
||||||
response = requests.post(
|
|
||||||
self.base_url + "/generate",
|
|
||||||
json={
|
|
||||||
"input_ids": input_ids,
|
|
||||||
"sampling_params": {
|
|
||||||
"temperature": temperature,
|
|
||||||
"max_new_tokens": output_len,
|
|
||||||
},
|
|
||||||
"return_logprob": return_logprob,
|
|
||||||
"logprob_start_len": logprob_start_len,
|
|
||||||
"top_logprobs_num": top_logprobs_num,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
response_json = response.json()
|
|
||||||
|
|
||||||
res = response_json
|
|
||||||
self.assertEqual(res["meta_info"]["prompt_tokens"], input_len)
|
|
||||||
self.assertEqual(res["meta_info"]["completion_tokens"], output_len)
|
|
||||||
|
|
||||||
# Test the number of tokens are correct
|
|
||||||
if return_logprob:
|
|
||||||
# This is because if logprob_start_len == 0, we added a padding for the first token.
|
|
||||||
# In other cases, we do not add the padding
|
|
||||||
delta = 0 if logprob_start_len == 0 else 1
|
|
||||||
|
|
||||||
self.assertEqual(
|
|
||||||
len(res["meta_info"]["input_token_logprobs"])
|
|
||||||
+ logprob_start_len
|
|
||||||
+ delta,
|
|
||||||
res["meta_info"]["prompt_tokens"],
|
|
||||||
)
|
|
||||||
self.assertEqual(len(res["meta_info"]["output_token_logprobs"]), output_len)
|
|
||||||
|
|
||||||
if top_logprobs_num:
|
|
||||||
self.assertEqual(
|
|
||||||
len(res["meta_info"]["input_top_logprobs"])
|
|
||||||
+ logprob_start_len
|
|
||||||
+ delta,
|
|
||||||
res["meta_info"]["prompt_tokens"],
|
|
||||||
)
|
|
||||||
self.assertEqual(
|
|
||||||
len(res["meta_info"]["output_top_logprobs"]), output_len
|
|
||||||
)
|
|
||||||
|
|
||||||
for i in range(output_len):
|
|
||||||
self.assertEqual(
|
|
||||||
len(res["meta_info"]["output_top_logprobs"][i]),
|
|
||||||
top_logprobs_num,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Test the top-1 tokens are the same as output tokens if temperature == 0
|
|
||||||
if temperature == 0:
|
|
||||||
self.assertListEqual(
|
|
||||||
res["meta_info"]["output_token_logprobs"][i],
|
|
||||||
res["meta_info"]["output_top_logprobs"][i][0],
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_logprob_mixed(self):
|
def test_logprob_mixed(self):
|
||||||
args = []
|
args = []
|
||||||
temperature = 0
|
temperature = 0
|
||||||
# input_len, output_len, temperature, logprob_start_len, return_logprob, top_logprobs_num
|
# input_len, output_len, temperature, logprob_start_len, return_logprob, top_logprobs_num
|
||||||
for input_len in [1000, 2000]:
|
for input_len in [1000, 5000, 10000, 50000]:
|
||||||
for output_len in [4, 8]:
|
for output_len in [4, 8]:
|
||||||
for logprob_start_len in [0, 500, 1000]:
|
for logprob_start_len in [0, 500, 2500, 5000, 25000]:
|
||||||
for return_logprob in [True, False]:
|
for return_logprob in [True, False]:
|
||||||
for top_logprobs_num in [0, 5]:
|
for top_logprobs_num in [0, 5]:
|
||||||
|
|
||||||
@@ -331,8 +267,9 @@ class TestSRTEndpoint(unittest.TestCase):
|
|||||||
|
|
||||||
random.shuffle(args)
|
random.shuffle(args)
|
||||||
|
|
||||||
|
func = partial(run_logprob_check, self)
|
||||||
with ThreadPoolExecutor(8) as executor:
|
with ThreadPoolExecutor(8) as executor:
|
||||||
list(executor.map(self.run_logprob_check, args))
|
list(executor.map(func, args))
|
||||||
|
|
||||||
def test_logprob_grammar(self):
|
def test_logprob_grammar(self):
|
||||||
prompts = "Question: Is Paris the Capital of France? Answer:"
|
prompts = "Question: Is Paris the Capital of France? Answer:"
|
||||||
@@ -427,6 +364,77 @@ class TestSRTEndpoint(unittest.TestCase):
|
|||||||
f"{target_token_id=}\n{sampled_tokens=}\n{custom_response=}",
|
f"{target_token_id=}\n{sampled_tokens=}\n{custom_response=}",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def run_stateful_custom_logit_processor(
|
||||||
|
self, first_token_id: int | None, delay: int = 2
|
||||||
|
):
|
||||||
|
"""Test custom logit processor with custom params and state.
|
||||||
|
|
||||||
|
Should sample the first `delay` tokens normally, then output first_token_id and consecutive tokens after that.
|
||||||
|
If first_token_id is None, the custom logit processor won't be passed in.
|
||||||
|
"""
|
||||||
|
|
||||||
|
custom_params = {"token_id": first_token_id, "delay": 2}
|
||||||
|
|
||||||
|
class DeterministicStatefulLogitProcessor(CustomLogitProcessor):
|
||||||
|
"""A dummy logit processor that changes the logits to always
|
||||||
|
sample the given token id.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __call__(self, logits, custom_param_list):
|
||||||
|
assert logits.shape[0] == len(custom_param_list)
|
||||||
|
|
||||||
|
for i, param_dict in enumerate(custom_param_list):
|
||||||
|
if param_dict["delay"] > 0:
|
||||||
|
param_dict["delay"] -= 1
|
||||||
|
continue
|
||||||
|
if param_dict["delay"] == 0:
|
||||||
|
param_dict["delay"] -= 1
|
||||||
|
force_token = param_dict["token_id"]
|
||||||
|
else:
|
||||||
|
output_ids = param_dict["__req__"].output_ids
|
||||||
|
force_token = output_ids[-1] + 1
|
||||||
|
# Mask all other tokens
|
||||||
|
logits[i, :] = -float("inf")
|
||||||
|
# Assign highest probability to the specified token
|
||||||
|
logits[i, force_token] = 0.0
|
||||||
|
return logits
|
||||||
|
|
||||||
|
prompts = "Question: Is Paris the Capital of France? Answer:"
|
||||||
|
|
||||||
|
# Base case json data to be posted to the server.
|
||||||
|
base_json = {
|
||||||
|
"text": prompts,
|
||||||
|
"sampling_params": {"temperature": 0.0},
|
||||||
|
"return_logprob": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Custom json data with custom logit processor and params.
|
||||||
|
custom_json = base_json.copy()
|
||||||
|
# Only set the custom logit processor if target_token_id is not None.
|
||||||
|
if first_token_id is not None:
|
||||||
|
custom_json["custom_logit_processor"] = (
|
||||||
|
DeterministicStatefulLogitProcessor().to_str()
|
||||||
|
)
|
||||||
|
custom_json["sampling_params"]["custom_params"] = custom_params
|
||||||
|
|
||||||
|
custom_response = requests.post(
|
||||||
|
self.base_url + "/generate",
|
||||||
|
json=custom_json,
|
||||||
|
).json()
|
||||||
|
|
||||||
|
output_token_logprobs = custom_response["meta_info"]["output_token_logprobs"]
|
||||||
|
sampled_tokens = [x[1] for x in output_token_logprobs]
|
||||||
|
# The logit processor should always sample the given token as the logits is deterministic.
|
||||||
|
if first_token_id is not None:
|
||||||
|
self.assertTrue(
|
||||||
|
all(
|
||||||
|
x == custom_params["token_id"] + k
|
||||||
|
for k, x in enumerate(sampled_tokens[custom_params["delay"] :])
|
||||||
|
),
|
||||||
|
# Print the detailed test case info if the test fails.
|
||||||
|
f"{first_token_id=}\n{sampled_tokens=}\n{custom_response=}",
|
||||||
|
)
|
||||||
|
|
||||||
def test_custom_logit_processor(self):
|
def test_custom_logit_processor(self):
|
||||||
"""Test custom logit processor with a single request."""
|
"""Test custom logit processor with a single request."""
|
||||||
self.run_custom_logit_processor(target_token_id=5)
|
self.run_custom_logit_processor(target_token_id=5)
|
||||||
@@ -438,6 +446,19 @@ class TestSRTEndpoint(unittest.TestCase):
|
|||||||
with ThreadPoolExecutor(len(target_token_ids)) as executor:
|
with ThreadPoolExecutor(len(target_token_ids)) as executor:
|
||||||
list(executor.map(self.run_custom_logit_processor, target_token_ids))
|
list(executor.map(self.run_custom_logit_processor, target_token_ids))
|
||||||
|
|
||||||
|
def test_stateful_custom_logit_processor(self):
|
||||||
|
"""Test custom logit processor with a single request."""
|
||||||
|
self.run_stateful_custom_logit_processor(first_token_id=5)
|
||||||
|
|
||||||
|
def test_stateful_custom_logit_processor_batch_mixed(self):
|
||||||
|
"""Test a batch of requests mixed of requests with and without custom logit processor."""
|
||||||
|
target_token_ids = list(range(32)) + [None] * 16
|
||||||
|
random.shuffle(target_token_ids)
|
||||||
|
with ThreadPoolExecutor(len(target_token_ids)) as executor:
|
||||||
|
list(
|
||||||
|
executor.map(self.run_stateful_custom_logit_processor, target_token_ids)
|
||||||
|
)
|
||||||
|
|
||||||
def test_cache_tokens(self):
|
def test_cache_tokens(self):
|
||||||
for _ in range(2):
|
for _ in range(2):
|
||||||
time.sleep(1)
|
time.sleep(1)
|
||||||
@@ -476,6 +497,21 @@ class TestSRTEndpoint(unittest.TestCase):
|
|||||||
version = response_json["version"]
|
version = response_json["version"]
|
||||||
self.assertIsInstance(version, str)
|
self.assertIsInstance(version, str)
|
||||||
|
|
||||||
|
def test_get_server_info_concurrent(self):
|
||||||
|
"""Make sure the concurrent get_server_info doesn't crash the server."""
|
||||||
|
tp = ThreadPoolExecutor(max_workers=30)
|
||||||
|
|
||||||
|
def s():
|
||||||
|
server_info = requests.get(self.base_url + "/get_server_info")
|
||||||
|
server_info.json()
|
||||||
|
|
||||||
|
futures = []
|
||||||
|
for _ in range(4):
|
||||||
|
futures.append(tp.submit(s))
|
||||||
|
|
||||||
|
for f in futures:
|
||||||
|
f.result()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
@@ -168,9 +168,9 @@ def _run_subprocess(
|
|||||||
hf_tokenizer = get_tokenizer(model_path, trust_remote_code=True)
|
hf_tokenizer = get_tokenizer(model_path, trust_remote_code=True)
|
||||||
|
|
||||||
hf_outputs = HFRunner.forward_generation_raw(
|
hf_outputs = HFRunner.forward_generation_raw(
|
||||||
|
base_model=hf_model,
|
||||||
prompts=_PROMPTS,
|
prompts=_PROMPTS,
|
||||||
max_new_tokens=_MAX_NEW_TOKENS,
|
max_new_tokens=_MAX_NEW_TOKENS,
|
||||||
base_model=hf_model,
|
|
||||||
tokenizer=hf_tokenizer,
|
tokenizer=hf_tokenizer,
|
||||||
lora_paths=None,
|
lora_paths=None,
|
||||||
torch_dtype=_TORCH_DTYPE,
|
torch_dtype=_TORCH_DTYPE,
|
||||||
|
|||||||
Reference in New Issue
Block a user