Support penalty in overlap mode; return logprob with chunked prefill; improve benchmark scripts (#3988)
Co-authored-by: SangBin Cho <rkooo567@gmail.com> Co-authored-by: dhou-xai <dhou@x.ai> Co-authored-by: Hanming Lu <hanming_lu@berkeley.edu>
This commit is contained in:
@@ -96,7 +96,10 @@ dev_cpu = ["sglang[all_cpu]", "sglang[test]"]
|
||||
"Bug Tracker" = "https://github.com/sgl-project/sglang/issues"
|
||||
|
||||
[tool.setuptools.package-data]
|
||||
"sglang" = ["srt/layers/moe/fused_moe_triton/configs/*.json", "srt/layers/quantization/configs/*.json"]
|
||||
"sglang" = [
|
||||
"srt/layers/moe/fused_moe_triton/configs/*.json",
|
||||
"srt/layers/quantization/configs/*.json",
|
||||
]
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
exclude = [
|
||||
|
||||
@@ -8,8 +8,10 @@
|
||||
- `bench_one_batch.py`: Benchmark the latency of running a single static batch without a server.
|
||||
- `bench_one_batch_server.py`: Benchmark the latency of running a single batch with a server.
|
||||
- `bench_serving.py`: Benchmark online serving with dynamic requests.
|
||||
- `check_env.py`: Check the environment variables.
|
||||
- `check_env.py`: Check the environment variables and dependencies.
|
||||
- `global_config.py`: The global configs and constants.
|
||||
- `launch_server.py`: The entry point for launching the local server.
|
||||
- `llama3_eval.py`: Evaluation of Llama 3 using the Meta Llama dataset.
|
||||
- `profiler.py`: Profile a running server.
|
||||
- `utils.py`: Common utilities.
|
||||
- `version.py`: Version info.
|
||||
|
||||
@@ -56,6 +56,7 @@ class BenchArgs:
|
||||
profile: bool = False
|
||||
skip_warmup: bool = False
|
||||
do_not_exit: bool = False
|
||||
prompt_suffix: str = ""
|
||||
|
||||
@staticmethod
|
||||
def add_cli_args(parser: argparse.ArgumentParser):
|
||||
@@ -177,6 +178,12 @@ class BenchArgs:
|
||||
action="store_true",
|
||||
help="Do not exit the program. This is useful for nsys profile with --duration and --delay.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prompt-suffix",
|
||||
type=str,
|
||||
default="",
|
||||
help="Suffix applied to the end of all user prompts, followed by assistant prompt suffix.",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_cli_args(cls, args: argparse.Namespace):
|
||||
@@ -216,6 +223,10 @@ def throughput_test_once(
|
||||
]
|
||||
|
||||
if profile:
|
||||
assert (
|
||||
"SGLANG_TORCH_PROFILER_DIR" in os.environ
|
||||
), "Please set SGLANG_TORCH_PROFILER_DIR."
|
||||
os.makedirs(os.environ["SGLANG_TORCH_PROFILER_DIR"], exist_ok=True)
|
||||
backend.start_profile()
|
||||
|
||||
st = time.perf_counter()
|
||||
@@ -229,6 +240,8 @@ def throughput_test_once(
|
||||
if backend_name == "runtime":
|
||||
gen_out = json.loads(gen_out)
|
||||
|
||||
server_info = backend.get_server_info()
|
||||
|
||||
measurement_results["total_latency"] = latency
|
||||
measurement_results["total_output_tokens"] = sum(
|
||||
o["meta_info"]["completion_tokens"] for o in gen_out
|
||||
@@ -246,6 +259,7 @@ def throughput_test_once(
|
||||
measurement_results["total_input_tokens"]
|
||||
+ measurement_results["total_output_tokens"]
|
||||
) / latency
|
||||
measurement_results["last_gen_throughput"] = server_info["last_gen_throughput"]
|
||||
|
||||
return measurement_results
|
||||
|
||||
@@ -361,6 +375,11 @@ def throughput_test(
|
||||
print(
|
||||
"{:<40} {:<10}".format("Total generated tokens:", result["total_output_tokens"])
|
||||
)
|
||||
print(
|
||||
"{:<40} {:<10.2f}".format(
|
||||
"Last generation throughput (tok/s):", result["last_gen_throughput"]
|
||||
)
|
||||
)
|
||||
print(
|
||||
"{:<40} {:<10.2f}".format(
|
||||
"Request throughput (req/s):", result["request_throughput"]
|
||||
|
||||
@@ -8,7 +8,6 @@ Usage:
|
||||
python3 -m sglang.bench_serving --backend sglang --num-prompt 10
|
||||
|
||||
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompts 3000 --random-input 1024 --random-output 1024 --random-range-ratio 0.5
|
||||
python3 -m sglang.bench_serving --backend sglang --dataset-name random --request-rate-range 1,2,4,8,16,32 --random-input 4096 --random-output 1024 --random-range-ratio 0.125 --multi
|
||||
"""
|
||||
|
||||
import argparse
|
||||
@@ -71,6 +70,10 @@ def remove_prefix(text: str, prefix: str) -> str:
|
||||
return text[len(prefix) :] if text.startswith(prefix) else text
|
||||
|
||||
|
||||
def remove_suffix(text: str, suffix: str) -> str:
    """Return *text* with a trailing *suffix* removed, if present.

    Args:
        text: The string to trim.
        suffix: The ending to strip. An empty suffix leaves *text* unchanged.

    Returns:
        *text* without the trailing *suffix*, or *text* itself when it does
        not end with *suffix*.
    """
    # Guard the empty suffix explicitly: every string "ends with" "", and
    # ``text[:-0]`` is ``text[:0]``, which would wrongly return "".
    if suffix and text.endswith(suffix):
        return text[: -len(suffix)]
    return text
|
||||
|
||||
|
||||
def get_auth_headers() -> Dict[str, str]:
|
||||
api_key = os.environ.get("OPENAI_API_KEY")
|
||||
if api_key:
|
||||
@@ -79,7 +82,7 @@ def get_auth_headers() -> Dict[str, str]:
|
||||
return {}
|
||||
|
||||
|
||||
# trt llm not support ignore_eos
|
||||
# trt llm does not support ignore_eos
|
||||
# https://github.com/triton-inference-server/tensorrtllm_backend/issues/505
|
||||
async def async_request_trt_llm(
|
||||
request_func_input: RequestFuncInput,
|
||||
@@ -179,6 +182,7 @@ async def async_request_openai_completions(
|
||||
output.prompt_len = request_func_input.prompt_len
|
||||
|
||||
generated_text = ""
|
||||
output_len = request_func_input.output_len
|
||||
ttft = 0.0
|
||||
st = time.perf_counter()
|
||||
most_recent_timestamp = st
|
||||
@@ -215,11 +219,14 @@ async def async_request_openai_completions(
|
||||
|
||||
most_recent_timestamp = timestamp
|
||||
generated_text += data["choices"][0]["text"]
|
||||
output_len = data.get("usage", {}).get(
|
||||
"completion_tokens", output_len
|
||||
)
|
||||
|
||||
output.generated_text = generated_text
|
||||
output.success = True
|
||||
output.latency = latency
|
||||
output.output_len = request_func_input.output_len
|
||||
output.output_len = output_len
|
||||
else:
|
||||
output.error = response.reason or ""
|
||||
output.success = False
|
||||
@@ -339,9 +346,11 @@ async def async_request_sglang_generate(
|
||||
output.prompt_len = request_func_input.prompt_len
|
||||
|
||||
generated_text = ""
|
||||
output_len = request_func_input.output_len
|
||||
ttft = 0.0
|
||||
st = time.perf_counter()
|
||||
most_recent_timestamp = st
|
||||
last_output_len = 0
|
||||
try:
|
||||
async with session.post(
|
||||
url=api_url, json=payload, headers=headers
|
||||
@@ -365,6 +374,9 @@ async def async_request_sglang_generate(
|
||||
# want to check a token was generated
|
||||
if data["text"]:
|
||||
timestamp = time.perf_counter()
|
||||
generated_text = data["text"]
|
||||
output_len = data["meta_info"]["completion_tokens"]
|
||||
|
||||
# First token
|
||||
if ttft == 0.0:
|
||||
ttft = time.perf_counter() - st
|
||||
@@ -372,7 +384,13 @@ async def async_request_sglang_generate(
|
||||
|
||||
# Decoding phase
|
||||
else:
|
||||
output.itl.append(timestamp - most_recent_timestamp)
|
||||
num_new_tokens = output_len - last_output_len
|
||||
if num_new_tokens == 0:
|
||||
continue
|
||||
adjust_itl = (
|
||||
timestamp - most_recent_timestamp
|
||||
) / num_new_tokens
|
||||
output.itl.extend([adjust_itl] * num_new_tokens)
|
||||
|
||||
most_recent_timestamp = timestamp
|
||||
generated_text = data["text"]
|
||||
@@ -380,7 +398,7 @@ async def async_request_sglang_generate(
|
||||
output.generated_text = generated_text
|
||||
output.success = True
|
||||
output.latency = latency
|
||||
output.output_len = request_func_input.output_len
|
||||
output.output_len = output_len
|
||||
else:
|
||||
output.error = response.reason or ""
|
||||
output.success = False
|
||||
@@ -388,6 +406,7 @@ async def async_request_sglang_generate(
|
||||
output.success = False
|
||||
exc_info = sys.exc_info()
|
||||
output.error = "".join(traceback.format_exception(*exc_info))
|
||||
print(f"{output.error=}")
|
||||
|
||||
if pbar:
|
||||
pbar.update(1)
|
||||
@@ -461,6 +480,7 @@ def get_dataset(args, tokenizer):
|
||||
tokenizer=tokenizer,
|
||||
fixed_output_len=args.sharegpt_output_len,
|
||||
context_len=args.sharegpt_context_len,
|
||||
prompt_suffix=args.prompt_suffix,
|
||||
apply_chat_template=args.apply_chat_template,
|
||||
)
|
||||
elif args.dataset_name == "random":
|
||||
@@ -521,7 +541,9 @@ class BenchmarkMetrics:
|
||||
mean_itl_ms: float
|
||||
median_itl_ms: float
|
||||
std_itl_ms: float
|
||||
p95_itl_ms: float
|
||||
p99_itl_ms: float
|
||||
max_itl_ms: float
|
||||
mean_e2e_latency_ms: float
|
||||
median_e2e_latency_ms: float
|
||||
std_e2e_latency_ms: float
|
||||
@@ -572,6 +594,7 @@ def sample_sharegpt_requests(
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
fixed_output_len: Optional[int] = None,
|
||||
context_len: Optional[int] = None,
|
||||
prompt_suffix: Optional[str] = "",
|
||||
apply_chat_template=False,
|
||||
) -> List[Tuple[str, int, int]]:
|
||||
if fixed_output_len is not None and fixed_output_len < 4:
|
||||
@@ -584,11 +607,19 @@ def sample_sharegpt_requests(
|
||||
# Load the dataset.
|
||||
with open(dataset_path) as f:
|
||||
dataset = json.load(f)
|
||||
|
||||
# Filter out the conversations with less than 2 turns.
|
||||
dataset = [data for data in dataset if len(data["conversations"]) >= 2]
|
||||
dataset = [
|
||||
data
|
||||
for data in dataset
|
||||
if len(data.get("conversations", data.get("conversation", []))) >= 2
|
||||
]
|
||||
# Only keep the first two turns of each conversation.
|
||||
dataset = [
|
||||
(data["conversations"][0]["value"], data["conversations"][1]["value"])
|
||||
(
|
||||
data.get("conversations", data.get("conversation", []))[0]["value"],
|
||||
data.get("conversations", data.get("conversation", []))[1]["value"],
|
||||
)
|
||||
for data in dataset
|
||||
]
|
||||
|
||||
@@ -603,6 +634,8 @@ def sample_sharegpt_requests(
|
||||
|
||||
# Tokenize the prompts and completions.
|
||||
prompt = dataset[i][0]
|
||||
if prompt_suffix:
|
||||
prompt = prompt
|
||||
|
||||
if apply_chat_template:
|
||||
prompt = tokenizer.apply_chat_template(
|
||||
@@ -666,10 +699,17 @@ def sample_random_requests(
|
||||
with open(dataset_path) as f:
|
||||
dataset = json.load(f)
|
||||
# Filter out the conversations with less than 2 turns.
|
||||
dataset = [data for data in dataset if len(data["conversations"]) >= 2]
|
||||
dataset = [
|
||||
data
|
||||
for data in dataset
|
||||
if len(data.get("conversations", data.get("conversation", []))) >= 2
|
||||
]
|
||||
# Only keep the first two turns of each conversation.
|
||||
dataset = [
|
||||
(data["conversations"][0]["value"], data["conversations"][1]["value"])
|
||||
(
|
||||
data.get("conversations", data.get("conversation", []))[0]["value"],
|
||||
data.get("conversations", data.get("conversation", []))[1]["value"],
|
||||
)
|
||||
for data in dataset
|
||||
]
|
||||
# Shuffle the dataset.
|
||||
@@ -895,7 +935,9 @@ def calculate_metrics(
|
||||
mean_itl_ms=np.mean(itls or 0) * 1000,
|
||||
median_itl_ms=np.median(itls or 0) * 1000,
|
||||
std_itl_ms=np.std(itls or 0) * 1000,
|
||||
p95_itl_ms=np.percentile(itls or 0, 95) * 1000,
|
||||
p99_itl_ms=np.percentile(itls or 0, 99) * 1000,
|
||||
max_itl_ms=np.max(itls or 0) * 1000,
|
||||
mean_e2e_latency_ms=np.mean(e2e_latencies) * 1000,
|
||||
median_e2e_latency_ms=np.median(e2e_latencies) * 1000,
|
||||
std_e2e_latency_ms=np.std(e2e_latencies) * 1000,
|
||||
@@ -919,6 +961,7 @@ async def benchmark(
|
||||
lora_name: str,
|
||||
extra_request_body: Dict[str, Any],
|
||||
profile: bool,
|
||||
pd_seperated: bool = False,
|
||||
):
|
||||
if backend in ASYNC_REQUEST_FUNCS:
|
||||
request_func = ASYNC_REQUEST_FUNCS[backend]
|
||||
@@ -1004,6 +1047,17 @@ async def benchmark(
|
||||
if pbar is not None:
|
||||
pbar.close()
|
||||
|
||||
if "sglang" in backend:
|
||||
server_info = requests.get(base_url + "/get_server_info")
|
||||
if pd_seperated:
|
||||
accept_length = server_info.json()["decode"][0].get(
|
||||
"avg_spec_accept_length", None
|
||||
)
|
||||
else:
|
||||
accept_length = server_info.json().get("avg_spec_accept_length", None)
|
||||
else:
|
||||
accept_length = None
|
||||
|
||||
# Compute metrics and print results
|
||||
benchmark_duration = time.perf_counter() - benchmark_start_time
|
||||
metrics, output_lens = calculate_metrics(
|
||||
@@ -1053,6 +1107,8 @@ async def benchmark(
|
||||
)
|
||||
)
|
||||
print("{:<40} {:<10.2f}".format("Concurrency:", metrics.concurrency))
|
||||
if accept_length:
|
||||
print("{:<40} {:<10.2f}".format("Accept length:", accept_length))
|
||||
print("{s:{c}^{n}}".format(s="End-to-End Latency", n=50, c="-"))
|
||||
print(
|
||||
"{:<40} {:<10.2f}".format("Mean E2E Latency (ms):", metrics.mean_e2e_latency_ms)
|
||||
@@ -1066,16 +1122,12 @@ async def benchmark(
|
||||
print("{:<40} {:<10.2f}".format("Mean TTFT (ms):", metrics.mean_ttft_ms))
|
||||
print("{:<40} {:<10.2f}".format("Median TTFT (ms):", metrics.median_ttft_ms))
|
||||
print("{:<40} {:<10.2f}".format("P99 TTFT (ms):", metrics.p99_ttft_ms))
|
||||
print(
|
||||
"{s:{c}^{n}}".format(s="Time per Output Token (excl. 1st token)", n=50, c="-")
|
||||
)
|
||||
print("{:<40} {:<10.2f}".format("Mean TPOT (ms):", metrics.mean_tpot_ms))
|
||||
print("{:<40} {:<10.2f}".format("Median TPOT (ms):", metrics.median_tpot_ms))
|
||||
print("{:<40} {:<10.2f}".format("P99 TPOT (ms):", metrics.p99_tpot_ms))
|
||||
print("{s:{c}^{n}}".format(s="Inter-token Latency", n=50, c="-"))
|
||||
print("{s:{c}^{n}}".format(s="Inter-Token Latency", n=50, c="-"))
|
||||
print("{:<40} {:<10.2f}".format("Mean ITL (ms):", metrics.mean_itl_ms))
|
||||
print("{:<40} {:<10.2f}".format("Median ITL (ms):", metrics.median_itl_ms))
|
||||
print("{:<40} {:<10.2f}".format("P95 ITL (ms):", metrics.p95_itl_ms))
|
||||
print("{:<40} {:<10.2f}".format("P99 ITL (ms):", metrics.p99_itl_ms))
|
||||
print("{:<40} {:<10.2f}".format("Max ITL (ms):", metrics.max_itl_ms))
|
||||
print("=" * 50)
|
||||
|
||||
if (
|
||||
@@ -1117,8 +1169,10 @@ async def benchmark(
|
||||
"mean_itl_ms": metrics.mean_itl_ms,
|
||||
"median_itl_ms": metrics.median_itl_ms,
|
||||
"std_itl_ms": metrics.std_itl_ms,
|
||||
"p95_itl_ms": metrics.p95_itl_ms,
|
||||
"p99_itl_ms": metrics.p99_itl_ms,
|
||||
"concurrency": metrics.concurrency,
|
||||
"accept_length": accept_length,
|
||||
}
|
||||
else:
|
||||
print(f"Error running benchmark for request rate: {request_rate}")
|
||||
@@ -1151,14 +1205,6 @@ async def benchmark(
|
||||
return result
|
||||
|
||||
|
||||
def parse_request_rate_range(request_rate_range):
    """Parse a comma-separated request-rate spec into a list of ints.

    Exactly three values are read as ``start,stop,step`` and expanded with
    ``range`` (so ``stop`` is exclusive). Any other number of values is taken
    as the explicit list of rates.

    Raises:
        ValueError: If a field is not an integer, or a three-value spec has a
            step of zero.
    """
    rates = [int(field) for field in request_rate_range.split(",")]
    if len(rates) != 3:
        return rates
    first, last, stride = rates
    return list(range(first, last, stride))
|
||||
|
||||
|
||||
def check_chat_template(model_path):
|
||||
try:
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
||||
@@ -1168,6 +1214,12 @@ def check_chat_template(model_path):
|
||||
return False
|
||||
|
||||
|
||||
def set_global_args(args_: argparse.Namespace):
    """Publish *args_* as this module's global ``args``."""
    # Equivalent to ``global args; args = args_``: globals() is the namespace
    # of the module in which this function is defined.
    globals()["args"] = args_
|
||||
|
||||
|
||||
def run_benchmark(args_: argparse.Namespace):
|
||||
global args
|
||||
args = args_
|
||||
@@ -1176,6 +1228,8 @@ def run_benchmark(args_: argparse.Namespace):
|
||||
if not hasattr(args, "max_concurrency"):
|
||||
args.max_concurrency = None
|
||||
|
||||
print(f"benchmark_args={args}")
|
||||
|
||||
# Set global environments
|
||||
set_ulimit()
|
||||
random.seed(args.seed)
|
||||
@@ -1272,49 +1326,26 @@ def run_benchmark(args_: argparse.Namespace):
|
||||
backend = args.backend
|
||||
model_id = args.model
|
||||
tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model
|
||||
|
||||
tokenizer = get_tokenizer(tokenizer_id)
|
||||
|
||||
input_requests = get_dataset(args, tokenizer)
|
||||
|
||||
if not args.multi:
|
||||
return asyncio.run(
|
||||
benchmark(
|
||||
backend=backend,
|
||||
api_url=api_url,
|
||||
base_url=base_url,
|
||||
model_id=model_id,
|
||||
tokenizer=tokenizer,
|
||||
input_requests=input_requests,
|
||||
request_rate=args.request_rate,
|
||||
max_concurrency=args.max_concurrency,
|
||||
disable_tqdm=args.disable_tqdm,
|
||||
lora_name=args.lora_name,
|
||||
extra_request_body=extra_request_body,
|
||||
profile=args.profile,
|
||||
)
|
||||
return asyncio.run(
|
||||
benchmark(
|
||||
backend=backend,
|
||||
api_url=api_url,
|
||||
base_url=base_url,
|
||||
model_id=model_id,
|
||||
tokenizer=tokenizer,
|
||||
input_requests=input_requests,
|
||||
request_rate=args.request_rate,
|
||||
max_concurrency=args.max_concurrency,
|
||||
disable_tqdm=args.disable_tqdm,
|
||||
lora_name=args.lora_name,
|
||||
extra_request_body=extra_request_body,
|
||||
profile=args.profile,
|
||||
pd_seperated=args.pd_seperated,
|
||||
)
|
||||
else:
|
||||
# Benchmark multiple rps. TODO: use a fixed duration to compute num_prompts
|
||||
request_rates = parse_request_rate_range(args.request_rate_range)
|
||||
|
||||
for rate in request_rates:
|
||||
asyncio.run(
|
||||
benchmark(
|
||||
backend=backend,
|
||||
api_url=api_url,
|
||||
base_url=base_url,
|
||||
model_id=model_id,
|
||||
tokenizer=tokenizer,
|
||||
input_requests=input_requests,
|
||||
request_rate=rate,
|
||||
max_concurrency=args.max_concurrency,
|
||||
disable_tqdm=args.disable_tqdm,
|
||||
lora_name=args.lora_name,
|
||||
extra_request_body=extra_request_body,
|
||||
profile=args.profile,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def set_ulimit(target_soft_limit=65535):
|
||||
@@ -1428,17 +1459,6 @@ if __name__ == "__main__":
|
||||
"actual request rate may be lower than specified with --request-rate, "
|
||||
"if the server is not processing requests fast enough to keep up.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--multi",
|
||||
action="store_true",
|
||||
help="Use request rate range rather than single value.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--request-rate-range",
|
||||
type=str,
|
||||
default="2,34,2",
|
||||
help="Range of request rates in the format start,stop,step. Default is 2,34,2. It also supports a list of request rates, requiring the parameters to not equal three.",
|
||||
)
|
||||
parser.add_argument("--output-file", type=str, help="Output JSONL file name.")
|
||||
parser.add_argument(
|
||||
"--disable-tqdm",
|
||||
@@ -1485,6 +1505,17 @@ if __name__ == "__main__":
|
||||
default=None,
|
||||
help="The name of LoRA adapter",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prompt-suffix",
|
||||
type=str,
|
||||
default="",
|
||||
help="Suffix applied to the end of all user prompts, followed by assistant prompt suffix.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pd-seperated",
|
||||
action="store_true",
|
||||
help="Benchmark PD disaggregation server",
|
||||
)
|
||||
|
||||
group = parser.add_argument_group("generated-shared-prefix dataset arguments")
|
||||
group.add_argument(
|
||||
|
||||
@@ -34,11 +34,9 @@ class GlobalConfig:
|
||||
self.skip_special_tokens_in_output = True
|
||||
self.spaces_between_special_tokens_in_out = True
|
||||
|
||||
# Interpreter optimization configs
|
||||
# Language frontend interpreter optimization configs
|
||||
self.enable_precache_with_tracing = True
|
||||
self.enable_parallel_encoding = True
|
||||
|
||||
self.enable_flashinfer_mla = False
|
||||
|
||||
|
||||
global_config = GlobalConfig()
|
||||
|
||||
@@ -329,7 +329,12 @@ class RuntimeEndpoint(BaseBackend):
|
||||
|
||||
def compute_normalized_prompt_logprobs(input_logprobs):
    """Return the mean of the per-token log-probabilities in *input_logprobs*.

    Args:
        input_logprobs: A sequence of indexable entries whose element ``[0]``
            is the token's log-probability, or ``None`` when no value is
            available (presumably the first prompt token; verify against the
            server's response format).

    Returns:
        The arithmetic mean of all non-``None`` log-probabilities.

    Raises:
        ZeroDivisionError: If no entry carries a log-probability.
        TypeError: If an entry's log-probability is not numeric. The offending
            input is printed before the error is re-raised.
    """
    # Filter on ``is not None`` rather than truthiness: 0.0 is a legitimate
    # log-probability (probability 1) and must count toward the mean.
    values = [x[0] for x in input_logprobs if x[0] is not None]
    try:
        return sum(values) / len(values)
    except TypeError:
        # Dump the malformed input for debugging, then let the caller decide
        # what to do instead of terminating the whole process.
        print(f"{input_logprobs=}", flush=True)
        print(f"{input_logprobs[0]=}", flush=True)
        raise
|
||||
|
||||
|
||||
class Runtime:
|
||||
|
||||
@@ -21,6 +21,7 @@ class LoadFormat(str, enum.Enum):
|
||||
BITSANDBYTES = "bitsandbytes"
|
||||
MISTRAL = "mistral"
|
||||
LAYERED = "layered"
|
||||
JAX = "jax"
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -42,13 +43,15 @@ class LoadConfig:
|
||||
ignore_patterns: The list of patterns to ignore when loading the model.
|
||||
Default to "original/**/*" to avoid repeated loading of llama's
|
||||
checkpoints.
|
||||
|
||||
decryption_key_file: If set, decrypts the output files with a password read
|
||||
from this file (after PBKDF2).
|
||||
"""
|
||||
|
||||
load_format: Union[str, LoadFormat] = LoadFormat.AUTO
|
||||
download_dir: Optional[str] = None
|
||||
model_loader_extra_config: Optional[Union[str, dict]] = field(default_factory=dict)
|
||||
ignore_patterns: Optional[Union[List[str], str]] = None
|
||||
decryption_key_file: Optional[str] = None
|
||||
|
||||
def __post_init__(self):
|
||||
model_loader_extra_config = self.model_loader_extra_config or {}
|
||||
|
||||
@@ -44,6 +44,7 @@ class ModelConfig:
|
||||
is_embedding: Optional[bool] = None,
|
||||
dtype: str = "auto",
|
||||
quantization: Optional[str] = None,
|
||||
override_config_file: Optional[str] = None,
|
||||
) -> None:
|
||||
self.model_path = model_path
|
||||
self.revision = revision
|
||||
@@ -51,11 +52,16 @@ class ModelConfig:
|
||||
|
||||
# Parse args
|
||||
self.model_override_args = json.loads(model_override_args)
|
||||
kwargs = {}
|
||||
if override_config_file and override_config_file.strip():
|
||||
kwargs["_configuration_file"] = override_config_file.strip()
|
||||
|
||||
self.hf_config = get_config(
|
||||
model_path,
|
||||
trust_remote_code=trust_remote_code,
|
||||
revision=revision,
|
||||
model_override_args=self.model_override_args,
|
||||
**kwargs,
|
||||
)
|
||||
self.hf_text_config = get_hf_text_config(self.hf_config)
|
||||
|
||||
@@ -64,6 +70,9 @@ class ModelConfig:
|
||||
self.hf_config.architectures, is_embedding
|
||||
)
|
||||
self.is_multimodal = is_multimodal_model(self.hf_config.architectures)
|
||||
self.is_multimodal_gen = is_multimodal_gen_model(self.hf_config.architectures)
|
||||
self.is_image_gen = is_image_gen_model(self.hf_config.architectures)
|
||||
self.is_audio_model = is_audio_model(self.hf_config.architectures)
|
||||
self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures)
|
||||
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
|
||||
|
||||
@@ -71,7 +80,9 @@ class ModelConfig:
|
||||
derived_context_len = get_context_length(self.hf_text_config)
|
||||
if context_length is not None:
|
||||
if context_length > derived_context_len:
|
||||
if get_bool_env_var("SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN"):
|
||||
if get_bool_env_var(
|
||||
"SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN", default="False"
|
||||
):
|
||||
logger.warning(
|
||||
f"Warning: User-specified context_length ({context_length}) is greater than the derived context_length ({derived_context_len}). "
|
||||
f"This may lead to incorrect model outputs or CUDA errors."
|
||||
@@ -416,6 +427,8 @@ def is_multimodal_model(model_architectures: List[str]):
|
||||
or "LlavaQwenForCausalLM" in model_architectures
|
||||
or "LlavaMistralForCausalLM" in model_architectures
|
||||
or "LlavaVidForCausalLM" in model_architectures
|
||||
or "Grok1VForCausalLM" in model_architectures
|
||||
or "Grok1AForCausalLM" in model_architectures
|
||||
or "MllamaForConditionalGeneration" in model_architectures
|
||||
or "Qwen2VLForConditionalGeneration" in model_architectures
|
||||
or "Qwen2_5_VLForConditionalGeneration" in model_architectures
|
||||
@@ -426,6 +439,18 @@ def is_multimodal_model(model_architectures: List[str]):
|
||||
return False
|
||||
|
||||
|
||||
def is_multimodal_gen_model(model_architectures: List[str]):
    """Report whether any architecture is a multimodal-generation model.

    Currently a stub: no architecture is recognized, so this is always False.
    """
    return False
|
||||
|
||||
|
||||
def is_image_gen_model(model_architectures: List[str]):
    """Report whether any architecture is an image-generation model.

    Currently a stub: no architecture is recognized, so this is always False.
    """
    return False
|
||||
|
||||
|
||||
def is_audio_model(model_architectures: List[str]):
    """Report whether any architecture is an audio model.

    Currently a stub: no architecture is recognized, so this is always False.
    """
    return False
|
||||
|
||||
|
||||
def is_encoder_decoder_model(model_architectures: List[str]):
    """Return True when "MllamaForConditionalGeneration" is among the architectures."""
    encoder_decoder_arch = "MllamaForConditionalGeneration"
    return encoder_decoder_arch in model_architectures
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import List, Tuple
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from xgrammar import (
|
||||
@@ -42,11 +42,16 @@ MAX_ROLLBACK_TOKENS = 200
|
||||
class XGrammarGrammar(BaseGrammarObject):
|
||||
|
||||
def __init__(
    self,
    matcher: GrammarMatcher,
    vocab_size: int,
    ctx: CompiledGrammar,
    override_stop_tokens: Optional[Union[List[int], int]] = None,
) -> None:
    """Store the matcher and the data needed to rebuild it in ``copy``.

    Args:
        matcher: The grammar matcher this object wraps.
        vocab_size: Vocabulary size, stored for later use by this object.
        ctx: The compiled grammar the matcher was built from.
        override_stop_tokens: Stop token id(s) forwarded to new matchers
            created by ``copy``. Defaults to ``None`` so callers that pass
            only the first three arguments keep working.
    """
    self.matcher = matcher
    self.vocab_size = vocab_size
    self.ctx = ctx
    self.override_stop_tokens = override_stop_tokens
    self.finished = False
|
||||
|
||||
def accept_token(self, token: int):
|
||||
@@ -96,8 +101,14 @@ class XGrammarGrammar(BaseGrammarObject):
|
||||
apply_token_bitmask_inplace(logits, vocab_mask)
|
||||
|
||||
def copy(self):
    """Return a new grammar object with a fresh matcher over the same grammar.

    The new matcher is built from ``self.ctx`` with the same rollback limit
    and the same ``override_stop_tokens``, so the copy keeps this object's
    stop-token override instead of silently dropping it.
    """
    matcher = GrammarMatcher(
        self.ctx,
        max_rollback_tokens=MAX_ROLLBACK_TOKENS,
        override_stop_tokens=self.override_stop_tokens,
    )
    return XGrammarGrammar(
        matcher, self.vocab_size, self.ctx, self.override_stop_tokens
    )
|
||||
|
||||
|
||||
class XGrammarGrammarBackend(BaseGrammarBackend):
|
||||
@@ -111,8 +122,11 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
|
||||
tokenizer_info = TokenizerInfo.from_huggingface(
|
||||
tokenizer, vocab_size=vocab_size
|
||||
)
|
||||
override_stop_tokens = None
|
||||
|
||||
self.grammar_compiler = GrammarCompiler(tokenizer_info=tokenizer_info)
|
||||
self.vocab_size = vocab_size
|
||||
self.override_stop_tokens = override_stop_tokens
|
||||
|
||||
def init_value_impl(self, key: Tuple[str, str]) -> XGrammarGrammar:
|
||||
|
||||
@@ -161,7 +175,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
|
||||
raise ValueError(f"Invalid key_type: {key_type}")
|
||||
|
||||
matcher = GrammarMatcher(ctx, max_rollback_tokens=MAX_ROLLBACK_TOKENS)
|
||||
return XGrammarGrammar(matcher, self.vocab_size, ctx)
|
||||
return XGrammarGrammar(matcher, self.vocab_size, ctx, self.override_stop_tokens)
|
||||
|
||||
def reset(self):
|
||||
if self.grammar_compiler:
|
||||
|
||||
@@ -121,6 +121,7 @@ class Engine:
|
||||
return_logprob: Optional[Union[List[bool], bool]] = False,
|
||||
logprob_start_len: Optional[Union[List[int], int]] = None,
|
||||
top_logprobs_num: Optional[Union[List[int], int]] = None,
|
||||
token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None,
|
||||
lora_path: Optional[List[Optional[str]]] = None,
|
||||
custom_logit_processor: Optional[Union[List[str], str]] = None,
|
||||
return_hidden_states: bool = False,
|
||||
@@ -142,6 +143,7 @@ class Engine:
|
||||
return_logprob=return_logprob,
|
||||
logprob_start_len=logprob_start_len,
|
||||
top_logprobs_num=top_logprobs_num,
|
||||
token_ids_logprob=token_ids_logprob,
|
||||
lora_path=lora_path,
|
||||
modalities=modalities_list,
|
||||
custom_logit_processor=custom_logit_processor,
|
||||
@@ -179,6 +181,7 @@ class Engine:
|
||||
return_logprob: Optional[Union[List[bool], bool]] = False,
|
||||
logprob_start_len: Optional[Union[List[int], int]] = None,
|
||||
top_logprobs_num: Optional[Union[List[int], int]] = None,
|
||||
token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None,
|
||||
lora_path: Optional[List[Optional[str]]] = None,
|
||||
custom_logit_processor: Optional[Union[List[str], str]] = None,
|
||||
stream: bool = False,
|
||||
@@ -195,6 +198,7 @@ class Engine:
|
||||
return_logprob=return_logprob,
|
||||
logprob_start_len=logprob_start_len,
|
||||
top_logprobs_num=top_logprobs_num,
|
||||
token_ids_logprob=token_ids_logprob,
|
||||
lora_path=lora_path,
|
||||
stream=stream,
|
||||
custom_logit_processor=custom_logit_processor,
|
||||
@@ -226,15 +230,22 @@ class Engine:
|
||||
kill_process_tree(os.getpid(), include_parent=False)
|
||||
|
||||
def start_profile(self):
    """Start profiling and block until the tokenizer manager has finished.

    ``tokenizer_manager.start_profile()`` is driven to completion on the
    current event loop. It is invoked exactly once: an additional bare call
    would only create a coroutine object that is never awaited.
    """
    loop = asyncio.get_event_loop()
    loop.run_until_complete(self.tokenizer_manager.start_profile())
|
||||
|
||||
def stop_profile(self):
|
||||
self.tokenizer_manager.stop_profile()
|
||||
|
||||
def get_server_info(self):
    """Return a flat dict describing the running server.

    The result merges, in order (later keys win on collision): the server
    args as a dict, ``self.scheduler_info``, the tokenizer manager's internal
    state, and the package ``version``.
    """
    loop = asyncio.get_event_loop()
    # get_internal_state() is awaited to completion on the current loop.
    internal_states = loop.run_until_complete(
        self.tokenizer_manager.get_internal_state()
    )

    return {
        # Converted once; spreading the same dict twice only repeats the copy.
        **dataclasses.asdict(self.tokenizer_manager.server_args),
        **self.scheduler_info,
        **internal_states,
        "version": __version__,
    }
|
||||
|
||||
@@ -323,6 +334,7 @@ def _set_envs_and_config(server_args: ServerArgs):
|
||||
os.environ["NCCL_NVLS_ENABLE"] = str(int(server_args.enable_nccl_nvls))
|
||||
os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
|
||||
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4"
|
||||
os.environ["CUDA_MODULE_LOADING"] = "AUTO"
|
||||
|
||||
# Set prometheus env vars
|
||||
if server_args.enable_metrics:
|
||||
@@ -346,12 +358,23 @@ def _set_envs_and_config(server_args: ServerArgs):
|
||||
"at https://docs.flashinfer.ai/installation.html.",
|
||||
)
|
||||
|
||||
def sigchld_handler(signum, frame):
    """SIGCHLD handler: reap one exited child and warn if it failed.

    Args:
        signum: The signal number (unused).
        frame: The interrupted stack frame (unused).
    """
    try:
        pid, status = os.waitpid(0, os.WNOHANG)
    except ChildProcessError:
        # No waitable child (e.g. it was already reaped elsewhere). An
        # exception escaping a signal handler would surface at an arbitrary
        # point in the main thread, so swallow this expected case.
        return

    # pid == 0 means no child has changed state yet; status == 0 is a clean exit.
    if pid == 0 or status == 0:
        return

    # waitpid returns an encoded wait status, not an exit code; decode it so
    # the log shows the real exit code (or -signal for a killed child).
    if os.WIFEXITED(status):
        exitcode = os.WEXITSTATUS(status)
    elif os.WIFSIGNALED(status):
        exitcode = -os.WTERMSIG(status)
    else:
        exitcode = status

    logger.warning(
        "Child process unexpectedly failed with an exit code %d. pid=%d",
        exitcode,
        pid,
    )
|
||||
|
||||
signal.signal(signal.SIGCHLD, sigchld_handler)
|
||||
|
||||
# Register the signal handler.
|
||||
# The child processes will send SIGQUIT to this process when any error happens
|
||||
# This process then clean up the whole process tree
|
||||
def sigquit_handler(signum, frame):
    """SIGQUIT handler: log the failure and kill this process's whole tree.

    Args:
        signum: The signal number (unused).
        frame: The interrupted stack frame (unused).
    """
    # A single message literal: adjacent string literals would be concatenated
    # and log the sentence twice (once with the "proces" typo).
    logger.error(
        "Received sigquit from a child process. It usually means the child failed."
    )
    kill_process_tree(os.getpid())
|
||||
|
||||
|
||||
@@ -25,11 +25,14 @@ import os
|
||||
import threading
|
||||
import time
|
||||
from http import HTTPStatus
|
||||
from typing import AsyncIterator, Dict, Optional
|
||||
from typing import AsyncIterator, Callable, Dict, Optional
|
||||
|
||||
# Fix a bug of Python threading
|
||||
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import numpy as np
|
||||
import orjson
|
||||
import requests
|
||||
import uvicorn
|
||||
@@ -49,8 +52,10 @@ from sglang.srt.managers.io_struct import (
|
||||
InitWeightsUpdateGroupReqInput,
|
||||
OpenSessionReqInput,
|
||||
ParseFunctionCallReq,
|
||||
ProfileReqInput,
|
||||
ReleaseMemoryOccupationReqInput,
|
||||
ResumeMemoryOccupationReqInput,
|
||||
SetInternalStateReq,
|
||||
UpdateWeightFromDiskReqInput,
|
||||
UpdateWeightsFromDistributedReqInput,
|
||||
VertexGenerateReqInput,
|
||||
@@ -78,22 +83,13 @@ from sglang.srt.utils import (
|
||||
kill_process_tree,
|
||||
set_uvicorn_logging_configs,
|
||||
)
|
||||
from sglang.srt.warmup import execute_warmups
|
||||
from sglang.utils import get_exception_traceback
|
||||
from sglang.version import __version__
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
||||
|
||||
# Fast API
|
||||
app = FastAPI()
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
|
||||
# Store global states
|
||||
@dataclasses.dataclass
|
||||
@@ -110,6 +106,34 @@ def set_global_state(global_state: _GlobalState):
|
||||
_global_state = global_state
|
||||
|
||||
|
||||
@asynccontextmanager
async def lifespan(fast_api_app: FastAPI):
    """FastAPI lifespan hook: run configured warmups, then start the warmup thread.

    Reads ``server_args`` (and, if present, ``warmup_thread``) from attributes
    set on *fast_api_app*. Nothing is done after the ``yield``.
    """
    args: ServerArgs = fast_api_app.server_args
    warmup_spec = args.warmups
    if warmup_spec is not None:
        # ``warmups`` is a comma-separated list of warmup names.
        warmup_names = warmup_spec.split(",")
        await execute_warmups(warmup_names, _global_state.tokenizer_manager)
        logger.info("Warmup ended")

    # The thread is optional; it is only started when one was attached to the app.
    pending_thread = getattr(fast_api_app, "warmup_thread", None)
    if pending_thread is not None:
        pending_thread.start()
    yield
|
||||
|
||||
|
||||
# Fast API
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
HEALTH_CHECK_TIMEOUT = int(os.getenv("SGLANG_HEALTH_CHECK_TIMEOUT", 20))
|
||||
|
||||
|
||||
##### Native API endpoints #####
|
||||
|
||||
|
||||
@@ -123,24 +147,48 @@ async def health() -> Response:
|
||||
async def health_generate(request: Request) -> Response:
|
||||
"""Check the health of the inference server by generating one token."""
|
||||
|
||||
sampling_params = {"max_new_tokens": 1, "temperature": 0.7}
|
||||
sampling_params = {"max_new_tokens": 1, "temperature": 0.0}
|
||||
rid = f"HEALTH_CHECK_{time.time()}"
|
||||
|
||||
if _global_state.tokenizer_manager.is_generation:
|
||||
if _global_state.tokenizer_manager.is_image_gen:
|
||||
raise NotImplementedError()
|
||||
elif _global_state.tokenizer_manager.is_generation:
|
||||
gri = GenerateReqInput(
|
||||
input_ids=[0], sampling_params=sampling_params, log_metrics=False
|
||||
rid=rid,
|
||||
input_ids=[0],
|
||||
sampling_params=sampling_params,
|
||||
log_metrics=False,
|
||||
)
|
||||
else:
|
||||
gri = EmbeddingReqInput(
|
||||
input_ids=[0], sampling_params=sampling_params, log_metrics=False
|
||||
rid=rid, input_ids=[0], sampling_params=sampling_params, log_metrics=False
|
||||
)
|
||||
|
||||
try:
|
||||
async def gen():
|
||||
async for _ in _global_state.tokenizer_manager.generate_request(gri, request):
|
||||
break
|
||||
return Response(status_code=200)
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
return Response(status_code=503)
|
||||
|
||||
tic = time.time()
|
||||
task = asyncio.create_task(gen())
|
||||
while time.time() < tic + HEALTH_CHECK_TIMEOUT:
|
||||
await asyncio.sleep(1)
|
||||
if _global_state.tokenizer_manager.last_receive_tstamp > tic:
|
||||
task.cancel()
|
||||
_global_state.tokenizer_manager.rid_to_state.pop(rid, None)
|
||||
return Response(status_code=200)
|
||||
|
||||
task.cancel()
|
||||
tic_time = time.strftime("%H:%M:%S", time.localtime(tic))
|
||||
last_receive_time = time.strftime(
|
||||
"%H:%M:%S", time.localtime(_global_state.tokenizer_manager.last_receive_tstamp)
|
||||
)
|
||||
logger.error(
|
||||
f"Health check failed. Server couldn't get a response from detokenizer for last "
|
||||
f"{HEALTH_CHECK_TIMEOUT} seconds. tic start time: {tic_time}. "
|
||||
f"last_heartbeat time: {last_receive_time}"
|
||||
)
|
||||
_global_state.tokenizer_manager.rid_to_state.pop(rid, None)
|
||||
return Response(status_code=503)
|
||||
|
||||
|
||||
@app.get("/get_model_info")
|
||||
@@ -156,13 +204,21 @@ async def get_model_info():
|
||||
|
||||
@app.get("/get_server_info")
|
||||
async def get_server_info():
|
||||
internal_states = await _global_state.tokenizer_manager.get_internal_state()
|
||||
return {
|
||||
**dataclasses.asdict(_global_state.tokenizer_manager.server_args),
|
||||
**_global_state.scheduler_info,
|
||||
**internal_states,
|
||||
"version": __version__,
|
||||
}
|
||||
|
||||
|
||||
@app.api_route("/set_internal_state", methods=["POST", "PUT"])
|
||||
async def set_internal_state(obj: SetInternalStateReq, request: Request):
|
||||
res = await _global_state.tokenizer_manager.set_internal_state(obj)
|
||||
return res
|
||||
|
||||
|
||||
# fastapi implicitly converts json in the request to obj (dataclass)
|
||||
@app.api_route("/generate", methods=["POST", "PUT"])
|
||||
async def generate_request(obj: GenerateReqInput, request: Request):
|
||||
@@ -179,6 +235,7 @@ async def generate_request(obj: GenerateReqInput, request: Request):
|
||||
) + b"\n\n"
|
||||
except ValueError as e:
|
||||
out = {"error": {"message": str(e)}}
|
||||
logger.error(f"Error: {e}")
|
||||
yield b"data: " + orjson.dumps(
|
||||
out, option=orjson.OPT_NON_STR_KEYS
|
||||
) + b"\n\n"
|
||||
@@ -236,9 +293,14 @@ async def flush_cache():
|
||||
|
||||
|
||||
@app.api_route("/start_profile", methods=["GET", "POST"])
|
||||
async def start_profile_async():
|
||||
async def start_profile_async(obj: Optional[ProfileReqInput] = None):
|
||||
"""Start profiling."""
|
||||
_global_state.tokenizer_manager.start_profile()
|
||||
if obj is None:
|
||||
obj = ProfileReqInput()
|
||||
|
||||
await _global_state.tokenizer_manager.start_profile(
|
||||
obj.output_dir, obj.num_steps, obj.activities
|
||||
)
|
||||
return Response(
|
||||
content="Start profiling.\n",
|
||||
status_code=200,
|
||||
@@ -257,11 +319,15 @@ async def stop_profile_async():
|
||||
|
||||
@app.post("/update_weights_from_disk")
|
||||
async def update_weights_from_disk(obj: UpdateWeightFromDiskReqInput, request: Request):
|
||||
"""Update the weights from disk in-place without re-launching the server."""
|
||||
success, message = await _global_state.tokenizer_manager.update_weights_from_disk(
|
||||
obj, request
|
||||
"""Update the weights from disk inplace without re-launching the server."""
|
||||
success, message, num_paused_requests = (
|
||||
await _global_state.tokenizer_manager.update_weights_from_disk(obj, request)
|
||||
)
|
||||
content = {"success": success, "message": message}
|
||||
content = {
|
||||
"success": success,
|
||||
"message": message,
|
||||
"num_paused_requests": num_paused_requests,
|
||||
}
|
||||
if success:
|
||||
return ORJSONResponse(
|
||||
content,
|
||||
@@ -323,7 +389,7 @@ async def get_weights_by_name(obj: GetWeightsByNameReqInput, request: Request):
|
||||
async def release_memory_occupation(
|
||||
obj: ReleaseMemoryOccupationReqInput, request: Request
|
||||
):
|
||||
"""Release GPU occupation temporarily"""
|
||||
"""Release GPU memory occupation temporarily."""
|
||||
try:
|
||||
await _global_state.tokenizer_manager.release_memory_occupation(obj, request)
|
||||
except Exception as e:
|
||||
@@ -334,7 +400,7 @@ async def release_memory_occupation(
|
||||
async def resume_memory_occupation(
|
||||
obj: ResumeMemoryOccupationReqInput, request: Request
|
||||
):
|
||||
"""Resume GPU occupation"""
|
||||
"""Resume GPU memory occupation."""
|
||||
try:
|
||||
await _global_state.tokenizer_manager.resume_memory_occupation(obj, request)
|
||||
except Exception as e:
|
||||
@@ -357,7 +423,7 @@ async def open_session(obj: OpenSessionReqInput, request: Request):
|
||||
|
||||
@app.api_route("/close_session", methods=["GET", "POST"])
|
||||
async def close_session(obj: CloseSessionReqInput, request: Request):
|
||||
"""Close the session"""
|
||||
"""Close the session."""
|
||||
try:
|
||||
await _global_state.tokenizer_manager.close_session(obj, request)
|
||||
return Response(status_code=200)
|
||||
@@ -367,7 +433,7 @@ async def close_session(obj: CloseSessionReqInput, request: Request):
|
||||
|
||||
@app.api_route("/configure_logging", methods=["GET", "POST"])
|
||||
async def configure_logging(obj: ConfigureLoggingReq, request: Request):
|
||||
"""Close the session"""
|
||||
"""Configure the request logging options."""
|
||||
_global_state.tokenizer_manager.configure_logging(obj)
|
||||
return Response(status_code=200)
|
||||
|
||||
@@ -511,6 +577,7 @@ def _create_error_response(e):
|
||||
def launch_server(
|
||||
server_args: ServerArgs,
|
||||
pipe_finish_writer: Optional[multiprocessing.connection.Connection] = None,
|
||||
launch_callback: Optional[Callable[[], None]] = None,
|
||||
):
|
||||
"""
|
||||
Launch SRT (SGLang Runtime) Server.
|
||||
@@ -544,21 +611,23 @@ def launch_server(
|
||||
add_prometheus_middleware(app)
|
||||
enable_func_timer()
|
||||
|
||||
# Send a warmup request
|
||||
t = threading.Thread(
|
||||
# Send a warmup request - we will create the thread and launch it
|
||||
# in the lifespan after all other warmups have fired.
|
||||
warmup_thread = threading.Thread(
|
||||
target=_wait_and_warmup,
|
||||
args=(
|
||||
server_args,
|
||||
pipe_finish_writer,
|
||||
_global_state.tokenizer_manager.image_token_id,
|
||||
launch_callback,
|
||||
),
|
||||
)
|
||||
t.start()
|
||||
app.warmup_thread = warmup_thread
|
||||
|
||||
try:
|
||||
# Update logging configs
|
||||
set_uvicorn_logging_configs()
|
||||
|
||||
app.server_args = server_args
|
||||
# Listen for HTTP requests
|
||||
uvicorn.run(
|
||||
app,
|
||||
@@ -569,10 +638,15 @@ def launch_server(
|
||||
loop="uvloop",
|
||||
)
|
||||
finally:
|
||||
t.join()
|
||||
warmup_thread.join()
|
||||
|
||||
|
||||
def _wait_and_warmup(server_args, pipe_finish_writer, image_token_text):
|
||||
def _wait_and_warmup(
|
||||
server_args: ServerArgs,
|
||||
pipe_finish_writer: Optional[multiprocessing.connection.Connection],
|
||||
image_token_text: str,
|
||||
launch_callback: Optional[Callable[[], None]] = None,
|
||||
):
|
||||
headers = {}
|
||||
url = server_args.url()
|
||||
if server_args.api_key:
|
||||
@@ -614,8 +688,16 @@ def _wait_and_warmup(server_args, pipe_finish_writer, image_token_text):
|
||||
else:
|
||||
json_data["text"] = "The capital city of France is"
|
||||
|
||||
# Debug dumping
|
||||
if server_args.debug_tensor_dump_input_file:
|
||||
json_data.pop("text", None)
|
||||
json_data["input_ids"] = np.load(
|
||||
server_args.debug_tensor_dump_input_file
|
||||
).tolist()
|
||||
json_data["sampling_params"]["max_new_tokens"] = 0
|
||||
|
||||
try:
|
||||
for _ in range(server_args.dp_size):
|
||||
for i in range(server_args.dp_size):
|
||||
res = requests.post(
|
||||
url + request_name,
|
||||
json=json_data,
|
||||
@@ -640,3 +722,9 @@ def _wait_and_warmup(server_args, pipe_finish_writer, image_token_text):
|
||||
|
||||
if server_args.delete_ckpt_after_loading:
|
||||
delete_directory(server_args.model_path)
|
||||
|
||||
if server_args.debug_tensor_dump_input_file:
|
||||
kill_process_tree(os.getpid())
|
||||
|
||||
if launch_callback is not None:
|
||||
launch_callback()
|
||||
|
||||
@@ -60,6 +60,7 @@ class VerlEngine:
|
||||
return_logprob: Optional[Union[List[bool], bool]] = False,
|
||||
logprob_start_len: Optional[Union[List[int], int]] = None,
|
||||
top_logprobs_num: Optional[Union[List[int], int]] = None,
|
||||
token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None,
|
||||
lora_path: Optional[List[Optional[str]]] = None,
|
||||
custom_logit_processor: Optional[Union[List[str], str]] = None,
|
||||
) -> Dict:
|
||||
@@ -76,6 +77,7 @@ class VerlEngine:
|
||||
return_logprob=return_logprob,
|
||||
logprob_start_len=logprob_start_len,
|
||||
top_logprobs_num=top_logprobs_num,
|
||||
token_ids_logprob=token_ids_logprob,
|
||||
lora_path=lora_path,
|
||||
custom_logit_processor=custom_logit_processor,
|
||||
)
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||
from sglang.srt.speculative.spec_info import SpecInfo
|
||||
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
||||
|
||||
|
||||
class AttentionBackend(ABC):
|
||||
@@ -31,7 +31,7 @@ class AttentionBackend(ABC):
|
||||
seq_lens: torch.Tensor,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
forward_mode: ForwardMode,
|
||||
spec_info: Optional[SpecInfo],
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||
):
|
||||
"""Init the metadata for a forward pass for capturing a cuda graph."""
|
||||
raise NotImplementedError()
|
||||
@@ -44,7 +44,7 @@ class AttentionBackend(ABC):
|
||||
seq_lens_sum: int,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
forward_mode: ForwardMode,
|
||||
spec_info: Optional[SpecInfo],
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||
):
|
||||
"""Init the metadata for a forward pass for replaying a cuda graph."""
|
||||
raise NotImplementedError()
|
||||
@@ -64,7 +64,14 @@ class AttentionBackend(ABC):
|
||||
):
|
||||
"""Run forward on an attention layer."""
|
||||
if forward_batch.forward_mode.is_decode():
|
||||
return self.forward_decode(q, k, v, layer, forward_batch, save_kv_cache)
|
||||
return self.forward_decode(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
layer,
|
||||
forward_batch,
|
||||
save_kv_cache=save_kv_cache,
|
||||
)
|
||||
else:
|
||||
return self.forward_extend(
|
||||
q,
|
||||
@@ -72,7 +79,7 @@ class AttentionBackend(ABC):
|
||||
v,
|
||||
layer,
|
||||
forward_batch,
|
||||
save_kv_cache,
|
||||
save_kv_cache=save_kv_cache,
|
||||
)
|
||||
|
||||
def forward_decode(
|
||||
|
||||
@@ -68,6 +68,7 @@ class FlashInferAttnBackend(AttentionBackend):
|
||||
model_runner: ModelRunner,
|
||||
skip_prefill: bool = False,
|
||||
kv_indptr_buf: Optional[torch.Tensor] = None,
|
||||
kv_last_page_len_buf: Optional[torch.Tensor] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -125,9 +126,14 @@ class FlashInferAttnBackend(AttentionBackend):
|
||||
assert self.num_wrappers == 1
|
||||
self.kv_indptr = [kv_indptr_buf]
|
||||
|
||||
self.kv_last_page_len = torch.ones(
|
||||
(max_bs,), dtype=torch.int32, device=model_runner.device
|
||||
)
|
||||
if kv_last_page_len_buf is None:
|
||||
self.kv_last_page_len = torch.ones(
|
||||
(max_bs,), dtype=torch.int32, device=model_runner.device
|
||||
)
|
||||
else:
|
||||
assert self.num_wrappers == 1
|
||||
self.kv_last_page_len = kv_last_page_len_buf
|
||||
|
||||
self.qo_indptr = [
|
||||
torch.zeros((max_bs + 1,), dtype=torch.int32, device=model_runner.device)
|
||||
for _ in range(self.num_wrappers)
|
||||
@@ -922,6 +928,9 @@ class FlashInferMultiStepDraftBackend:
|
||||
dtype=torch.int32,
|
||||
device=model_runner.device,
|
||||
)
|
||||
self.kv_last_page_len = torch.ones(
|
||||
(max_bs,), dtype=torch.int32, device=model_runner.device
|
||||
)
|
||||
self.attn_backends = []
|
||||
for i in range(self.speculative_num_steps):
|
||||
self.attn_backends.append(
|
||||
@@ -929,6 +938,7 @@ class FlashInferMultiStepDraftBackend:
|
||||
model_runner,
|
||||
skip_prefill=True,
|
||||
kv_indptr_buf=self.kv_indptr[i],
|
||||
kv_last_page_len_buf=self.kv_last_page_len,
|
||||
)
|
||||
)
|
||||
self.max_context_len = self.attn_backends[0].max_context_len
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
|
||||
import torch
|
||||
import triton
|
||||
@@ -15,7 +15,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMo
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||
from sglang.srt.speculative.spec_info import SpecInfo
|
||||
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
||||
|
||||
|
||||
class TritonAttnBackend(AttentionBackend):
|
||||
@@ -232,7 +232,7 @@ class TritonAttnBackend(AttentionBackend):
|
||||
seq_lens: torch.Tensor,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
forward_mode: ForwardMode,
|
||||
spec_info: Optional[SpecInfo],
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||
):
|
||||
assert encoder_lens is None, "Not supported"
|
||||
|
||||
@@ -310,7 +310,7 @@ class TritonAttnBackend(AttentionBackend):
|
||||
seq_lens_sum: int,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
forward_mode: ForwardMode,
|
||||
spec_info: Optional[SpecInfo],
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||
):
|
||||
# NOTE: encoder_lens expected to be zeros or None
|
||||
if forward_mode.is_decode_or_idle():
|
||||
|
||||
@@ -1,6 +1,21 @@
|
||||
import torch
|
||||
from __future__ import annotations
|
||||
|
||||
from sglang.srt.distributed import GroupCoordinator, get_tp_group
|
||||
import functools
|
||||
from typing import TYPE_CHECKING, Union
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from sglang.srt.distributed import (
|
||||
GroupCoordinator,
|
||||
get_tensor_model_parallel_world_size,
|
||||
get_tp_group,
|
||||
tensor_model_parallel_all_reduce,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
|
||||
_ATTN_TP_GROUP = None
|
||||
_ATTN_TP_RANK = None
|
||||
@@ -69,3 +84,129 @@ def get_attention_dp_rank():
|
||||
def get_attention_dp_size():
|
||||
assert _DP_SIZE is not None, "dp attention not initialized!"
|
||||
return _DP_SIZE
|
||||
|
||||
|
||||
def get_dp_local_info(forward_batch: ForwardBatch):
|
||||
dp_rank = get_attention_dp_rank()
|
||||
|
||||
if forward_batch.dp_local_start_pos is None:
|
||||
cumtokens = torch.cumsum(forward_batch.global_num_tokens_gpu, dim=0)
|
||||
if dp_rank == 0:
|
||||
local_start_pos = torch.zeros_like(cumtokens[0])
|
||||
else:
|
||||
local_start_pos = cumtokens[dp_rank - 1]
|
||||
local_num_tokens = forward_batch.global_num_tokens_gpu[dp_rank]
|
||||
|
||||
forward_batch.dp_local_start_pos = local_start_pos
|
||||
forward_batch.dp_local_num_tokens = local_num_tokens
|
||||
|
||||
return forward_batch.dp_local_start_pos, forward_batch.dp_local_num_tokens
|
||||
|
||||
|
||||
@triton.jit
|
||||
def memcpy_triton_kernel(
|
||||
dst_ptr,
|
||||
src_ptr,
|
||||
offset_ptr,
|
||||
sz_ptr,
|
||||
offset_src,
|
||||
chunk_size, # multiplied for offset and sz
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
pid = tl.program_id(axis=0).to(tl.int64)
|
||||
offset = tl.load(offset_ptr).to(tl.int64) * chunk_size
|
||||
sz = tl.load(sz_ptr).to(tl.int64) * chunk_size
|
||||
|
||||
start_index = pid * BLOCK_SIZE
|
||||
offs = tl.arange(0, BLOCK_SIZE)
|
||||
mask = start_index + offs < sz
|
||||
|
||||
if offset_src:
|
||||
data = tl.load(src_ptr + offset + start_index + offs, mask=mask)
|
||||
tl.store(dst_ptr + start_index + offs, data, mask=mask)
|
||||
else:
|
||||
data = tl.load(src_ptr + start_index + offs, mask=mask)
|
||||
tl.store(dst_ptr + offset + start_index + offs, data, mask=mask)
|
||||
|
||||
|
||||
def prod(x):
|
||||
return functools.reduce(lambda a, b: a * b, x, 1)
|
||||
|
||||
|
||||
def memcpy_triton(dst, src, dim, offset, sz, offset_src):
|
||||
max_size = min(src.numel(), dst.numel())
|
||||
assert dim == 0, "dim != 0 unsupported"
|
||||
assert src.shape[1:] == dst.shape[1:], "src and dst must have same shape"
|
||||
chunk_size = prod(src.shape[1:])
|
||||
BLOCK_SIZE = 8192
|
||||
grid = (triton.cdiv(max_size, BLOCK_SIZE),)
|
||||
|
||||
memcpy_triton_kernel[grid](dst, src, offset, sz, offset_src, chunk_size, BLOCK_SIZE)
|
||||
|
||||
|
||||
def dp_gather(
|
||||
global_tokens: torch.Tensor,
|
||||
local_tokens: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
layer_id: Union[str, int],
|
||||
):
|
||||
local_start_pos, local_num_tokens = get_dp_local_info(forward_batch)
|
||||
|
||||
global_tokens.fill_(0)
|
||||
assert local_tokens.is_contiguous()
|
||||
assert global_tokens.is_contiguous()
|
||||
if local_tokens.shape[0] > 0 and (
|
||||
layer_id != "embedding" or get_attention_tp_rank() == 0
|
||||
):
|
||||
assert (
|
||||
global_tokens.storage().data_ptr() != local_tokens.storage().data_ptr()
|
||||
), "aliasing between global_tokens and local_tokens not allowed"
|
||||
memcpy_triton(
|
||||
global_tokens, local_tokens, 0, local_start_pos, local_num_tokens, False
|
||||
)
|
||||
|
||||
# Input IDs are in int32. We should use inplace_all_reduce for the local case because of custom all reduce.
|
||||
NUM_GPUS_PER_NODE = 8
|
||||
if (
|
||||
not local_tokens.dtype.is_floating_point
|
||||
and get_tensor_model_parallel_world_size() <= NUM_GPUS_PER_NODE
|
||||
):
|
||||
torch.ops.sglang.inplace_all_reduce(
|
||||
global_tokens, group_name=get_tp_group().unique_name
|
||||
)
|
||||
else:
|
||||
global_tokens = tensor_model_parallel_all_reduce(global_tokens)
|
||||
|
||||
|
||||
def dp_scatter(
|
||||
local_tokens: torch.Tensor, # output
|
||||
global_tokens: torch.Tensor, # input
|
||||
forward_batch: ForwardBatch,
|
||||
):
|
||||
# local_num_tokens is not necessarily the same as local_tokens.shape[0],
|
||||
# since local_tokens may be padded for cuda graph
|
||||
local_start_pos, local_num_tokens = get_dp_local_info(forward_batch)
|
||||
local_tokens.fill_(0)
|
||||
assert local_tokens.is_contiguous()
|
||||
assert global_tokens.is_contiguous()
|
||||
if local_tokens.shape[0] > 0:
|
||||
assert (
|
||||
local_tokens.untyped_storage().data_ptr()
|
||||
!= global_tokens.untyped_storage().data_ptr()
|
||||
), "aliasing between local_tokens and global_tokens not allowed"
|
||||
memcpy_triton(
|
||||
local_tokens, global_tokens, 0, local_start_pos, local_num_tokens, True
|
||||
)
|
||||
|
||||
|
||||
def get_do_logits_dp_scatter(forward_batch: ForwardBatch):
|
||||
def do_logits_dp_scatter(logits: torch.Tensor):
|
||||
local_logits = torch.empty(
|
||||
(forward_batch.input_ids.shape[0], *logits.shape[1:]),
|
||||
dtype=logits.dtype,
|
||||
device=logits.device,
|
||||
)
|
||||
dp_scatter(local_logits, logits, forward_batch)
|
||||
return local_logits
|
||||
|
||||
return do_logits_dp_scatter
|
||||
|
||||
@@ -69,7 +69,7 @@ class RMSNorm(CustomOp):
|
||||
|
||||
variance = x.pow(2).mean(dim=-1, keepdim=True)
|
||||
x = x * torch.rsqrt(variance + self.variance_epsilon)
|
||||
x = x.to(orig_dtype) * self.weight
|
||||
x = (x * self.weight).to(orig_dtype)
|
||||
if residual is None:
|
||||
return x
|
||||
else:
|
||||
|
||||
@@ -426,13 +426,14 @@ class ColumnParallelLinear(LinearBase):
|
||||
from sglang.srt.layers.parameter import _ColumnvLLMParameter
|
||||
|
||||
if isinstance(param, _ColumnvLLMParameter):
|
||||
# FIXME: why would we need this special case?
|
||||
param.load_column_parallel_weight(
|
||||
loaded_weight,
|
||||
tp_rank=self.tp_rank,
|
||||
use_presharded_weights=self.use_presharded_weights,
|
||||
)
|
||||
else:
|
||||
# FIXME: This branch is needed to load deepseek v3 awq.
|
||||
# However, we should fix this and avoid the branching here.
|
||||
param.load_column_parallel_weight(loaded_weight)
|
||||
|
||||
def forward(self, input_):
|
||||
|
||||
@@ -26,12 +26,19 @@ from sglang.srt.distributed import (
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_gather,
|
||||
)
|
||||
from sglang.srt.layers.dp_attention import (
|
||||
dp_gather,
|
||||
dp_scatter,
|
||||
get_attention_dp_rank,
|
||||
get_attention_dp_size,
|
||||
)
|
||||
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
||||
from sglang.srt.model_executor.forward_batch_info import (
|
||||
CaptureHiddenMode,
|
||||
ForwardBatch,
|
||||
ForwardMode,
|
||||
)
|
||||
from sglang.srt.utils import dump_to_file
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -51,6 +58,9 @@ class LogitsProcessorOutput:
|
||||
# The logprobs and ids of the top-k tokens in output positions. shape: [#seq, k]
|
||||
next_token_top_logprobs_val: Optional[List] = None
|
||||
next_token_top_logprobs_idx: Optional[List] = None
|
||||
# The logprobs and ids of the requested token ids in output positions. shape: [#seq, n] (n is the number of requested token ids)
|
||||
next_token_token_ids_logprobs_val: Optional[List] = None
|
||||
next_token_token_ids_logprobs_idx: Optional[List] = None
|
||||
|
||||
## Part 3: Prefill-only. This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor
|
||||
# The logprobs of input tokens. shape: [#token]
|
||||
@@ -58,6 +68,9 @@ class LogitsProcessorOutput:
|
||||
# The logprobs and ids of the top-k tokens in input positions. shape: [#seq, #token, k]
|
||||
input_top_logprobs_val: List = None
|
||||
input_top_logprobs_idx: List = None
|
||||
# The logprobs and ids of the requested token ids in input positions. shape: [#seq, n] (n is the number of requested token ids)
|
||||
input_token_ids_logprobs_val: Optional[List] = None
|
||||
input_token_ids_logprobs_idx: Optional[List] = None
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
@@ -67,43 +80,107 @@ class LogitsMetadata:
|
||||
|
||||
extend_return_logprob: bool = False
|
||||
extend_return_top_logprob: bool = False
|
||||
extend_token_ids_logprob: bool = False
|
||||
extend_seq_lens: Optional[torch.Tensor] = None
|
||||
extend_seq_lens_cpu: Optional[List[int]] = None
|
||||
extend_logprob_start_lens_cpu: Optional[List[int]] = None
|
||||
extend_logprob_pruned_lens_cpu: Optional[List[int]] = None
|
||||
top_logprobs_nums: Optional[List[int]] = None
|
||||
extend_input_logprob_token_ids_gpu: Optional[torch.Tensor] = None
|
||||
token_ids_logprobs: Optional[List[List[int]]] = None
|
||||
|
||||
# logits and logprobs post processing
|
||||
temp_scaled_logprobs: bool = False
|
||||
temperature: torch.Tensor = None
|
||||
top_p_normalized_logprobs: bool = False
|
||||
top_p: torch.Tensor = None
|
||||
|
||||
# DP attention metadata. Not needed when DP attention is not used.
|
||||
# Number of tokens in the request.
|
||||
global_num_tokens_gpu: Optional[torch.Tensor] = None
|
||||
# The start position of local hidden states.
|
||||
dp_local_start_pos: Optional[torch.Tensor] = None
|
||||
dp_local_num_tokens: Optional[torch.Tensor] = None
|
||||
gathered_buffer: Optional[torch.Tensor] = None
|
||||
# Buffer to gather logits from all ranks.
|
||||
forward_batch_gathered_buffer: Optional[torch.Tensor] = None
|
||||
# Number of tokens to sample per DP rank
|
||||
global_num_tokens_for_logprob_cpu: Optional[torch.Tensor] = None
|
||||
global_num_tokens_for_logprob_gpu: Optional[torch.Tensor] = None
|
||||
|
||||
# for padding
|
||||
padded_static_len: int = -1
|
||||
|
||||
@classmethod
|
||||
def from_forward_batch(cls, forward_batch: ForwardBatch):
|
||||
if forward_batch.forward_mode.is_extend() and forward_batch.return_logprob:
|
||||
extend_return_logprob = True
|
||||
if (
|
||||
forward_batch.forward_mode.is_extend()
|
||||
and forward_batch.return_logprob
|
||||
and not forward_batch.forward_mode.is_target_verify()
|
||||
):
|
||||
extend_return_top_logprob = any(
|
||||
x > 0 for x in forward_batch.top_logprobs_nums
|
||||
)
|
||||
extend_logprob_pruned_lens_cpu = [
|
||||
extend_len - start_len
|
||||
for extend_len, start_len in zip(
|
||||
forward_batch.extend_seq_lens_cpu,
|
||||
forward_batch.extend_logprob_start_lens_cpu,
|
||||
)
|
||||
]
|
||||
extend_token_ids_logprob = any(
|
||||
x is not None for x in forward_batch.token_ids_logprobs
|
||||
)
|
||||
extend_return_logprob = False
|
||||
extend_logprob_pruned_lens_cpu = []
|
||||
for extend_len, start_len in zip(
|
||||
forward_batch.extend_seq_lens_cpu,
|
||||
forward_batch.extend_logprob_start_lens_cpu,
|
||||
):
|
||||
if extend_len - start_len > 0:
|
||||
extend_return_logprob = True
|
||||
extend_logprob_pruned_lens_cpu.append(extend_len - start_len)
|
||||
else:
|
||||
extend_return_logprob = extend_return_top_logprob = (
|
||||
extend_logprob_pruned_lens_cpu
|
||||
) = False
|
||||
extend_token_ids_logprob
|
||||
) = extend_logprob_pruned_lens_cpu = False
|
||||
|
||||
return cls(
|
||||
forward_mode=forward_batch.forward_mode,
|
||||
capture_hidden_mode=forward_batch.capture_hidden_mode,
|
||||
extend_return_logprob=extend_return_logprob,
|
||||
extend_return_top_logprob=extend_return_top_logprob,
|
||||
extend_token_ids_logprob=extend_token_ids_logprob,
|
||||
extend_seq_lens=forward_batch.extend_seq_lens,
|
||||
extend_seq_lens_cpu=forward_batch.extend_seq_lens_cpu,
|
||||
extend_logprob_start_lens_cpu=forward_batch.extend_logprob_start_lens_cpu,
|
||||
extend_logprob_pruned_lens_cpu=extend_logprob_pruned_lens_cpu,
|
||||
top_logprobs_nums=forward_batch.top_logprobs_nums,
|
||||
token_ids_logprobs=forward_batch.token_ids_logprobs,
|
||||
extend_input_logprob_token_ids_gpu=forward_batch.extend_input_logprob_token_ids_gpu,
|
||||
padded_static_len=forward_batch.padded_static_len,
|
||||
)
|
||||
|
||||
def compute_dp_attention_metadata(self, hidden_states: torch.Tensor):
|
||||
if self.global_num_tokens_for_logprob_cpu is None:
|
||||
# we are capturing cuda graph
|
||||
return
|
||||
|
||||
cumtokens = torch.cumsum(self.global_num_tokens_for_logprob_gpu, dim=0)
|
||||
dp_rank = get_attention_dp_rank()
|
||||
if dp_rank == 0:
|
||||
dp_local_start_pos = torch.zeros_like(
|
||||
self.global_num_tokens_for_logprob_gpu[0]
|
||||
)
|
||||
else:
|
||||
dp_local_start_pos = cumtokens[dp_rank - 1]
|
||||
dp_local_num_tokens = self.global_num_tokens_for_logprob_gpu[dp_rank]
|
||||
gathered_buffer = torch.zeros(
|
||||
(
|
||||
sum(self.global_num_tokens_for_logprob_cpu),
|
||||
hidden_states.shape[1],
|
||||
),
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states.device,
|
||||
)
|
||||
|
||||
self.dp_local_start_pos = dp_local_start_pos
|
||||
self.dp_local_num_tokens = dp_local_num_tokens
|
||||
self.gathered_buffer = gathered_buffer
|
||||
|
||||
|
||||
class LogitsProcessor(nn.Module):
|
||||
def __init__(
|
||||
@@ -115,6 +192,9 @@ class LogitsProcessor(nn.Module):
|
||||
self.do_tensor_parallel_all_gather = (
|
||||
not skip_all_gather and get_tensor_model_parallel_world_size() > 1
|
||||
)
|
||||
self.do_tensor_parallel_all_gather_dp_attn = (
|
||||
self.do_tensor_parallel_all_gather and get_attention_dp_size() != 1
|
||||
)
|
||||
self.final_logit_softcapping = getattr(
|
||||
self.config, "final_logit_softcapping", None
|
||||
)
|
||||
@@ -124,6 +204,12 @@ class LogitsProcessor(nn.Module):
|
||||
):
|
||||
self.final_logit_softcapping = None
|
||||
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
|
||||
self.debug_tensor_dump_output_folder = global_server_args_dict[
|
||||
"debug_tensor_dump_output_folder"
|
||||
]
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids,
|
||||
@@ -141,30 +227,74 @@ class LogitsProcessor(nn.Module):
|
||||
):
|
||||
pruned_states = hidden_states
|
||||
sample_indices = None
|
||||
input_logprob_indices = None
|
||||
elif (
|
||||
logits_metadata.forward_mode.is_extend()
|
||||
and not logits_metadata.extend_return_logprob
|
||||
):
|
||||
# Prefill without input logprobs.
|
||||
last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1
|
||||
if logits_metadata.padded_static_len < 0:
|
||||
last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1
|
||||
else:
|
||||
# If padded_static_len is 5 and extend_seq_lens is [2, 3],
|
||||
# then our batch looks like [t00, t01, p, p, p, t10, t11, t12, p, p]
|
||||
# and this retrieves t01 and t12, which are the valid last tokens
|
||||
idx = torch.arange(
|
||||
len(logits_metadata.extend_seq_lens),
|
||||
device=logits_metadata.extend_seq_lens.device,
|
||||
)
|
||||
last_index = (
|
||||
idx * logits_metadata.padded_static_len
|
||||
+ logits_metadata.extend_seq_lens
|
||||
- 1
|
||||
)
|
||||
pruned_states = hidden_states[last_index]
|
||||
sample_indices = None
|
||||
input_logprob_indices = None
|
||||
else:
|
||||
# Slice the requested tokens to compute logprob
|
||||
# Input logprobs are required.
|
||||
# Find 3 different indices.
|
||||
# 1. pruned_states: hidden states that we want logprobs from.
|
||||
# 2. sample_indices: Indices that have sampled tokens.
|
||||
# 3. input_logprob_indices: Indices that have input logprob tokens.
|
||||
sample_index_pt = -1
|
||||
sample_indices = []
|
||||
pt, pruned_states, pruned_input_ids = 0, [], []
|
||||
for start_len, extend_len in zip(
|
||||
input_logprob_indices_pt = 0
|
||||
input_logprob_indices = []
|
||||
pt, pruned_states = 0, []
|
||||
for extend_logprob_start_len, extend_len in zip(
|
||||
logits_metadata.extend_logprob_start_lens_cpu,
|
||||
logits_metadata.extend_seq_lens_cpu,
|
||||
):
|
||||
# extend_len == extend_logprob_start_len can happen in chunked prefill. We still
|
||||
# need to sample 1 token, but we don't want to include it in input logprob.
|
||||
if extend_len == extend_logprob_start_len:
|
||||
start_len = extend_logprob_start_len - 1
|
||||
else:
|
||||
start_len = extend_logprob_start_len
|
||||
|
||||
# We always need at least 1 token to sample because that's required
|
||||
# by a caller.
|
||||
assert extend_len > start_len
|
||||
pruned_states.append(hidden_states[pt + start_len : pt + extend_len])
|
||||
pt += extend_len
|
||||
sample_index_pt += extend_len - start_len
|
||||
sample_indices.append(sample_index_pt)
|
||||
pruned_input_ids.append(input_ids[pt + start_len : pt + extend_len])
|
||||
pt += extend_len
|
||||
input_logprob_indices.extend(
|
||||
[
|
||||
input_logprob_indices_pt + i
|
||||
for i in range(extend_len - extend_logprob_start_len)
|
||||
]
|
||||
)
|
||||
input_logprob_indices_pt += extend_len - start_len
|
||||
|
||||
pruned_states = torch.cat(pruned_states)
|
||||
sample_indices = torch.tensor(
|
||||
sample_indices, device=pruned_states.device, dtype=torch.int64
|
||||
)
|
||||
input_logprob_indices = torch.tensor(
|
||||
input_logprob_indices, device=pruned_states.device, dtype=torch.int64
|
||||
)
|
||||
|
||||
# Compute logits for both input and sampled tokens.
|
||||
logits = self._get_logits(pruned_states, lm_head, logits_metadata)
|
||||
@@ -172,28 +302,51 @@ class LogitsProcessor(nn.Module):
|
||||
logits[sample_indices] if sample_indices is not None else logits
|
||||
)
|
||||
|
||||
if (
|
||||
not logits_metadata.extend_return_logprob
|
||||
or logits_metadata.capture_hidden_mode.need_capture()
|
||||
):
|
||||
if self.debug_tensor_dump_output_folder:
|
||||
assert (
|
||||
not self.do_tensor_parallel_all_gather or get_attention_dp_size() == 1
|
||||
), "dp attention + sharded lm_head doesn't support full logits"
|
||||
full_logits = self._get_logits(hidden_states, lm_head, logits_metadata)
|
||||
dump_to_file(self.debug_tensor_dump_output_folder, "logits", full_logits)
|
||||
|
||||
hidden_states_to_store: Optional[torch.Tensor] = None
|
||||
if logits_metadata.capture_hidden_mode.need_capture():
|
||||
if logits_metadata.capture_hidden_mode.is_full():
|
||||
hidden_states_to_store = hidden_states
|
||||
elif logits_metadata.capture_hidden_mode.is_last():
|
||||
# Get the last token hidden states. If sample_indices is None,
|
||||
# pruned states only contain the last tokens already.
|
||||
hidden_states_to_store = (
|
||||
pruned_states[sample_indices] if sample_indices else pruned_states
|
||||
)
|
||||
else:
|
||||
assert False, "Should never reach"
|
||||
|
||||
if not logits_metadata.extend_return_logprob:
|
||||
# Decode mode or extend mode without return_logprob.
|
||||
return LogitsProcessorOutput(
|
||||
next_token_logits=sampled_logits,
|
||||
hidden_states=(
|
||||
hidden_states
|
||||
if logits_metadata.capture_hidden_mode.is_full()
|
||||
else (
|
||||
pruned_states
|
||||
if logits_metadata.capture_hidden_mode.is_last()
|
||||
else None
|
||||
)
|
||||
),
|
||||
hidden_states=hidden_states_to_store,
|
||||
)
|
||||
else:
|
||||
input_logprobs = logits
|
||||
input_logprobs = logits[input_logprob_indices]
|
||||
del hidden_states, logits
|
||||
|
||||
# Normalize the logprob w/o temperature, top-p
|
||||
pruned_lens = torch.tensor(
|
||||
logits_metadata.extend_logprob_pruned_lens_cpu,
|
||||
device=input_logprobs.device,
|
||||
)
|
||||
if logits_metadata.temp_scaled_logprobs:
|
||||
logits_metadata.temperature = torch.repeat_interleave(
|
||||
logits_metadata.temperature.view(-1),
|
||||
pruned_lens,
|
||||
).view(-1, 1)
|
||||
if logits_metadata.top_p_normalized_logprobs:
|
||||
logits_metadata.top_p = torch.repeat_interleave(
|
||||
logits_metadata.top_p,
|
||||
pruned_lens,
|
||||
)
|
||||
input_logprobs = self.compute_temp_top_p_normalized_logprobs(
|
||||
input_logprobs, logits_metadata
|
||||
)
|
||||
@@ -207,14 +360,18 @@ class LogitsProcessor(nn.Module):
|
||||
else:
|
||||
input_top_logprobs_val = input_top_logprobs_idx = None
|
||||
|
||||
# Get the logprob of given token id
|
||||
if logits_metadata.extend_token_ids_logprob:
|
||||
(
|
||||
input_token_ids_logprobs_val,
|
||||
input_token_ids_logprobs_idx,
|
||||
) = self.get_token_ids_logprobs(input_logprobs, logits_metadata)
|
||||
else:
|
||||
input_token_ids_logprobs_val = input_token_ids_logprobs_idx = None
|
||||
|
||||
input_token_logprobs = input_logprobs[
|
||||
torch.arange(input_logprobs.shape[0], device=input_logprobs.device),
|
||||
torch.cat(
|
||||
[
|
||||
torch.cat(pruned_input_ids)[1:],
|
||||
torch.tensor([0], device=input_logprobs.device),
|
||||
]
|
||||
),
|
||||
logits_metadata.extend_input_logprob_token_ids_gpu,
|
||||
]
|
||||
|
||||
return LogitsProcessorOutput(
|
||||
@@ -222,6 +379,9 @@ class LogitsProcessor(nn.Module):
|
||||
input_token_logprobs=input_token_logprobs,
|
||||
input_top_logprobs_val=input_top_logprobs_val,
|
||||
input_top_logprobs_idx=input_top_logprobs_idx,
|
||||
hidden_states=hidden_states_to_store,
|
||||
input_token_ids_logprobs_val=input_token_ids_logprobs_val,
|
||||
input_token_ids_logprobs_idx=input_token_ids_logprobs_idx,
|
||||
)
|
||||
|
||||
def _get_logits(
|
||||
@@ -231,10 +391,24 @@ class LogitsProcessor(nn.Module):
|
||||
logits_metadata: LogitsMetadata,
|
||||
embedding_bias: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Get logits from hidden_states."""
|
||||
"""Get logits from hidden_states.
|
||||
|
||||
If sampled_logits_only is True, it means hidden_states only contain the
|
||||
last position (e.g., extend without input logprobs). The caller should
|
||||
guarantee the given hidden_states follow this constraint.
|
||||
"""
|
||||
if self.do_tensor_parallel_all_gather_dp_attn:
|
||||
logits_metadata.compute_dp_attention_metadata(hidden_states)
|
||||
hidden_states, local_hidden_states = (
|
||||
logits_metadata.gathered_buffer,
|
||||
hidden_states.clone(),
|
||||
)
|
||||
dp_gather(hidden_states, local_hidden_states, logits_metadata, "embedding")
|
||||
|
||||
if hasattr(lm_head, "weight"):
|
||||
logits = torch.matmul(hidden_states, lm_head.weight.T)
|
||||
logits = torch.matmul(
|
||||
hidden_states.to(lm_head.weight.dtype), lm_head.weight.T
|
||||
)
|
||||
else:
|
||||
# GGUF models
|
||||
logits = lm_head.linear_method.apply(lm_head, hidden_states, embedding_bias)
|
||||
@@ -245,6 +419,17 @@ class LogitsProcessor(nn.Module):
|
||||
if self.do_tensor_parallel_all_gather:
|
||||
logits = tensor_model_parallel_all_gather(logits)
|
||||
|
||||
if self.do_tensor_parallel_all_gather_dp_attn:
|
||||
logits, global_logits = (
|
||||
torch.empty(
|
||||
(local_hidden_states.shape[0], logits.shape[1]),
|
||||
device=logits.device,
|
||||
dtype=logits.dtype,
|
||||
),
|
||||
logits,
|
||||
)
|
||||
dp_scatter(logits, global_logits, logits_metadata)
|
||||
|
||||
logits = logits[:, : self.config.vocab_size].float()
|
||||
|
||||
if self.final_logit_softcapping:
|
||||
@@ -272,21 +457,66 @@ class LogitsProcessor(nn.Module):
|
||||
continue
|
||||
|
||||
input_top_logprobs_val.append(
|
||||
[values[pt + j][:k] for j in range(pruned_len - 1)]
|
||||
[values[pt + j][:k] for j in range(pruned_len)]
|
||||
)
|
||||
input_top_logprobs_idx.append(
|
||||
[indices[pt + j][:k] for j in range(pruned_len - 1)]
|
||||
[indices[pt + j][:k] for j in range(pruned_len)]
|
||||
)
|
||||
pt += pruned_len
|
||||
|
||||
return input_top_logprobs_val, input_top_logprobs_idx
|
||||
|
||||
@staticmethod
def get_token_ids_logprobs(
    all_logprobs: torch.Tensor, logits_metadata: "LogitsMetadata"
):
    """Collect input logprobs of the requested token ids for every request.

    Rows of ``all_logprobs`` are consumed consecutively: the request at
    position ``i`` owns the next
    ``logits_metadata.extend_logprob_pruned_lens_cpu[i]`` rows, and its
    requested ids come from ``logits_metadata.token_ids_logprobs[i]``.

    Args:
        all_logprobs: 2-D tensor indexed as ``[position, token_id]``.
        logits_metadata: Object providing ``token_ids_logprobs`` and
            ``extend_logprob_pruned_lens_cpu`` (iterated in lockstep).

    Returns:
        A pair ``(values, indices)`` with one entry per request:
        ``values[i][j]`` holds the logprobs of the requested ids at the
        j-th position of request ``i`` and ``indices[i][j]`` is the
        requested id list itself. Requests with a non-positive pruned
        length get ``[]`` in both lists and consume no rows.
    """
    input_token_ids_logprobs_val, input_token_ids_logprobs_idx = [], []
    pt = 0
    for token_ids, pruned_len in zip(
        logits_metadata.token_ids_logprobs,
        logits_metadata.extend_logprob_pruned_lens_cpu,
    ):
        if pruned_len <= 0:
            input_token_ids_logprobs_val.append([])
            input_token_ids_logprobs_idx.append([])
            continue

        # Slice all positions of this request at once so that only one
        # device-to-host copy (.tolist()) happens per request instead of
        # one per position.
        input_token_ids_logprobs_val.append(
            all_logprobs[pt : pt + pruned_len, token_ids].tolist()
        )
        input_token_ids_logprobs_idx.append([token_ids for _ in range(pruned_len)])
        pt += pruned_len

    return input_token_ids_logprobs_val, input_token_ids_logprobs_idx
|
||||
|
||||
@staticmethod
def compute_temp_top_p_normalized_logprobs(
    last_logits: torch.Tensor, logits_metadata: "LogitsMetadata"
) -> torch.Tensor:
    """Compute logprobs from the given logits.

    Temperature scaling is applied when
    ``logits_metadata.temp_scaled_logprobs`` is set, and top-p
    normalization when ``logits_metadata.top_p_normalized_logprobs`` is set
    and at least one ``top_p`` value differs from 1.0.

    Args:
        last_logits: Logits with the vocabulary on the last dimension.
        logits_metadata: Object providing ``temp_scaled_logprobs``,
            ``temperature``, ``top_p_normalized_logprobs`` and ``top_p``.
            ``temperature`` must broadcast against ``last_logits``.

    Returns:
        torch.Tensor: logprobs with the same shape as ``last_logits``.
    """
    # Scale logits if temperature scaling is enabled.
    if logits_metadata.temp_scaled_logprobs:
        last_logits = last_logits / logits_metadata.temperature

    # Normalize logprobs if top_p normalization is enabled.
    # NOTE: only normalize logprobs when top_p is set and not equal to 1.0
    if (
        logits_metadata.top_p_normalized_logprobs
        and (logits_metadata.top_p != 1.0).any()
    ):
        # Imported lazily, exactly as in the original block.
        from sglang.srt.layers.sampler import top_p_normalize_probs_torch

        probs = torch.softmax(last_logits, dim=-1)
        # Drop the local reference to the logits before renormalizing.
        del last_logits
        probs = top_p_normalize_probs_torch(probs, logits_metadata.top_p)
        return torch.log(probs)

    return torch.nn.functional.log_softmax(last_logits, dim=-1)
|
||||
|
||||
|
||||
@triton.jit
|
||||
|
||||
@@ -144,6 +144,73 @@ def silu_and_mul_triton_kernel(
|
||||
tl.store(down_input_ptr + offset, silu_mul_output, mask=mask)
|
||||
|
||||
|
||||
@triton.jit
def tanh(x):
    # Hyperbolic tangent expressed through the sigmoid primitive, using the
    # identity tanh(x) = 2 * sigmoid(2 * x) - 1.
    return 2 * tl.sigmoid(2 * x) - 1
|
||||
|
||||
|
||||
@triton.jit
def gelu_and_mul_triton_kernel(
    gateup_output,
    down_input,
    hidden_size,
    reorder_topk_ids,
    scales,
    start_expert_id,
    end_expert_id,
    BLOCK_SIZE: tl.constexpr,
):
    # Fused "GELU(gate) * up * scale" for one row per program instance.
    #
    # Program `pid` reads row `pid` of `gateup_output` (row width
    # `hidden_size`): the first half is the gate, the second half is the up
    # projection. The result (width `hidden_size // 2`) is written to row
    # `pid` of `down_input`. Rows whose expert id (from `reorder_topk_ids`)
    # is outside [start_expert_id, end_expert_id] are left untouched.
    # When `scales` is not None, the output is multiplied by the reciprocal
    # of scales[expert_id - start_expert_id].
    InDtype = gateup_output.dtype.element_ty
    OutDtype = down_input.dtype.element_ty

    half_hidden_size = hidden_size // 2

    pid = tl.program_id(0)
    expert_id = tl.load(reorder_topk_ids + pid)
    if expert_id >= start_expert_id and expert_id <= end_expert_id:
        gateup_output_ptr = gateup_output + pid * hidden_size
        gate_output_ptr = gateup_output_ptr
        up_output_ptr = gateup_output_ptr + half_hidden_size
        down_input_ptr = down_input + pid * half_hidden_size

        if scales is not None:
            scale = tl.load(scales + expert_id - start_expert_id)
            scale = (1 / scale).to(InDtype)
        else:
            scale = 1

        # Walk the half-row in BLOCK_SIZE chunks; `mask` guards the tail.
        for start_offset in tl.range(0, half_hidden_size, BLOCK_SIZE):
            offset = start_offset + tl.arange(0, BLOCK_SIZE)
            mask = offset < half_hidden_size

            # The gate is promoted to float32 for the activation math.
            gate_output = tl.load(gate_output_ptr + offset, mask=mask).to(tl.float32)
            up_output = tl.load(up_output_ptr + offset, mask=mask)

            # gelu & mul & quantize
            # https://pytorch.org/docs/stable/generated/torch.nn.GELU.html
            # sqrt(2/pi)
            kAlpha = 0.7978845608028654
            gate_output = (
                0.5
                * gate_output
                * (
                    1
                    + tanh(
                        kAlpha
                        * (
                            gate_output
                            + 0.044715 * gate_output * gate_output * gate_output
                        )
                    )
                )
            )
            gate_output = gate_output.to(InDtype)

            gelu_mul_output = gate_output * up_output * scale
            gelu_mul_output = gelu_mul_output.to(OutDtype)
            tl.store(down_input_ptr + offset, gelu_mul_output, mask=mask)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def post_reorder_triton_kernel(
|
||||
down_output_ptr,
|
||||
|
||||
@@ -11,6 +11,7 @@ from sglang.srt.distributed import (
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from sglang.srt.layers.moe.ep_moe.kernels import (
|
||||
gelu_and_mul_triton_kernel,
|
||||
grouped_gemm_triton,
|
||||
post_reorder_triton_kernel,
|
||||
pre_reorder_triton_kernel,
|
||||
@@ -296,6 +297,17 @@ class EPMoE(torch.nn.Module):
|
||||
self.end_expert_id,
|
||||
BLOCK_SIZE=512,
|
||||
)
|
||||
elif self.activation == "gelu":
|
||||
gelu_and_mul_triton_kernel[(gateup_output.shape[0],)](
|
||||
gateup_output,
|
||||
down_input,
|
||||
gateup_output.shape[1],
|
||||
reorder_topk_ids,
|
||||
self.w2_input_scale,
|
||||
self.start_expert_id,
|
||||
self.end_expert_id,
|
||||
BLOCK_SIZE=512,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported activation: {self.activation=}")
|
||||
|
||||
|
||||
@@ -24,6 +24,8 @@ def fused_moe_forward_native(
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
correction_bias: Optional[torch.Tensor] = None,
|
||||
activation: str = "silu",
|
||||
inplace: bool = True,
|
||||
no_combine: bool = False,
|
||||
) -> torch.Tensor:
|
||||
topk_weights, topk_ids = select_experts(
|
||||
hidden_states=x,
|
||||
|
||||
@@ -23,7 +23,7 @@ from sglang.srt.utils import (
|
||||
is_hip,
|
||||
)
|
||||
|
||||
is_hip_flag = is_hip()
|
||||
is_hip_ = is_hip()
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -487,6 +487,7 @@ def invoke_fused_moe_kernel(
|
||||
use_int8_w8a8: bool,
|
||||
use_int8_w8a16: bool,
|
||||
block_shape: Optional[List[int]] = None,
|
||||
no_combine: bool = False,
|
||||
) -> None:
|
||||
assert topk_weights.stride(1) == 1
|
||||
assert sorted_token_ids.stride(0) == 1
|
||||
@@ -646,7 +647,7 @@ def get_default_config(
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2 if is_hip_flag else 4,
|
||||
"num_stages": 2 if is_hip_ else 4,
|
||||
}
|
||||
if M <= E:
|
||||
config = {
|
||||
@@ -655,7 +656,7 @@ def get_default_config(
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2 if is_hip_flag else 4,
|
||||
"num_stages": 2 if is_hip_ else 4,
|
||||
}
|
||||
else:
|
||||
# Block-wise quant: BLOCK_SIZE_K must be divisible by block_shape[1]
|
||||
@@ -665,7 +666,7 @@ def get_default_config(
|
||||
"BLOCK_SIZE_K": block_shape[1],
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2 if is_hip_flag else 3,
|
||||
"num_stages": 2 if is_hip_ else 3,
|
||||
}
|
||||
else:
|
||||
config = {
|
||||
@@ -814,6 +815,7 @@ def outplace_fused_experts(
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[List[int]] = None,
|
||||
no_combine: bool = False,
|
||||
) -> torch.Tensor:
|
||||
return fused_experts_impl(
|
||||
hidden_states,
|
||||
@@ -831,6 +833,7 @@ def outplace_fused_experts(
|
||||
a1_scale,
|
||||
a2_scale,
|
||||
block_shape,
|
||||
no_combine=no_combine,
|
||||
)
|
||||
|
||||
|
||||
@@ -849,6 +852,7 @@ def outplace_fused_experts_fake(
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[List[int]] = None,
|
||||
no_combine: bool = False,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty_like(hidden_states)
|
||||
|
||||
@@ -877,8 +881,10 @@ def fused_experts(
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[List[int]] = None,
|
||||
no_combine: bool = False,
|
||||
):
|
||||
if inplace:
|
||||
assert not no_combine, "no combine + inplace makes no sense"
|
||||
torch.ops.sglang.inplace_fused_experts(
|
||||
hidden_states,
|
||||
w1,
|
||||
@@ -912,6 +918,7 @@ def fused_experts(
|
||||
a1_scale,
|
||||
a2_scale,
|
||||
block_shape,
|
||||
no_combine=no_combine,
|
||||
)
|
||||
|
||||
|
||||
@@ -931,6 +938,7 @@ def fused_experts_impl(
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[List[int]] = None,
|
||||
no_combine: bool = False,
|
||||
):
|
||||
padded_size = padding_size
|
||||
if not use_fp8_w8a8 or not use_int8_w8a8 or block_shape is not None:
|
||||
@@ -987,7 +995,14 @@ def fused_experts_impl(
|
||||
|
||||
compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16
|
||||
|
||||
if inplace:
|
||||
if no_combine:
|
||||
assert not inplace
|
||||
out_hidden_states = torch.empty(
|
||||
(num_tokens, topk_ids.shape[1], w2.shape[1]),
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype,
|
||||
)
|
||||
elif inplace:
|
||||
out_hidden_states = hidden_states
|
||||
else:
|
||||
out_hidden_states = torch.empty_like(hidden_states)
|
||||
@@ -1057,7 +1072,11 @@ def fused_experts_impl(
|
||||
invoke_fused_moe_kernel(
|
||||
intermediate_cache2,
|
||||
w2,
|
||||
intermediate_cache3,
|
||||
(
|
||||
intermediate_cache3
|
||||
if not no_combine and topk_ids.shape[1] != 1
|
||||
else out_hidden_states[begin_chunk_idx:end_chunk_idx]
|
||||
),
|
||||
a2_scale,
|
||||
w2_scale,
|
||||
curr_topk_weights,
|
||||
@@ -1075,16 +1094,16 @@ def fused_experts_impl(
|
||||
block_shape=block_shape,
|
||||
)
|
||||
|
||||
if is_hip_flag:
|
||||
if no_combine:
|
||||
pass
|
||||
elif is_hip_:
|
||||
ops.moe_sum(
|
||||
intermediate_cache3.view(*intermediate_cache3.shape),
|
||||
out_hidden_states[begin_chunk_idx:end_chunk_idx],
|
||||
)
|
||||
else:
|
||||
if topk_ids.shape[1] == 1:
|
||||
out_hidden_states[begin_chunk_idx:end_chunk_idx].copy_(
|
||||
intermediate_cache3[:, 0]
|
||||
)
|
||||
pass # we write directly into out_hidden_states
|
||||
elif topk_ids.shape[1] == 2:
|
||||
torch.add(
|
||||
intermediate_cache3[:, 0],
|
||||
@@ -1122,6 +1141,7 @@ def fused_moe(
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[List[int]] = None,
|
||||
no_combine: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
This function computes a Mixture of Experts (MoE) layer using two sets of
|
||||
@@ -1191,4 +1211,5 @@ def fused_moe(
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
block_shape=block_shape,
|
||||
no_combine=no_combine,
|
||||
)
|
||||
|
||||
@@ -125,6 +125,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
correction_bias: Optional[torch.Tensor] = None,
|
||||
activation: str = "silu",
|
||||
inplace: bool = True,
|
||||
no_combine: bool = False,
|
||||
) -> torch.Tensor:
|
||||
return self.forward(
|
||||
x=x,
|
||||
@@ -138,6 +140,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
custom_routing_function=custom_routing_function,
|
||||
correction_bias=correction_bias,
|
||||
activation=activation,
|
||||
inplace=inplace,
|
||||
no_combine=no_combine,
|
||||
)
|
||||
|
||||
def forward_cuda(
|
||||
@@ -153,6 +157,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
correction_bias: Optional[torch.Tensor] = None,
|
||||
activation: str = "silu",
|
||||
inplace: bool = True,
|
||||
no_combine: bool = False,
|
||||
) -> torch.Tensor:
|
||||
topk_weights, topk_ids = select_experts(
|
||||
hidden_states=x,
|
||||
@@ -171,6 +177,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
from aiter.fused_moe import fused_experts_ck
|
||||
|
||||
assert activation == "silu", f"{activation=} is not supported."
|
||||
assert not no_combine, "unsupported"
|
||||
|
||||
return fused_experts_ck(
|
||||
hidden_states=x,
|
||||
@@ -186,8 +193,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
inplace=inplace and not no_combine,
|
||||
activation=activation,
|
||||
no_combine=no_combine,
|
||||
)
|
||||
|
||||
def forward_cpu(
|
||||
@@ -202,6 +210,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
num_expert_group: Optional[int] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
correction_bias: Optional[torch.Tensor] = None,
|
||||
inplace: bool = True,
|
||||
) -> torch.Tensor:
|
||||
return moe_forward_native(
|
||||
layer,
|
||||
@@ -241,6 +250,7 @@ class FusedMoE(torch.nn.Module):
|
||||
reduce_results: Whether to all_reduce the output of the layer
|
||||
renormalize: Whether to renormalize the logits in the fused_moe kernel
|
||||
quant_config: Quantization configure.
|
||||
inplace: suggestion to compute inplace (modify input activation).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -262,6 +272,8 @@ class FusedMoE(torch.nn.Module):
|
||||
correction_bias: Optional[torch.Tensor] = None,
|
||||
activation: str = "silu",
|
||||
use_presharded_weights: bool = False,
|
||||
inplace: bool = True,
|
||||
no_combine: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -285,6 +297,9 @@ class FusedMoE(torch.nn.Module):
|
||||
self.custom_routing_function = custom_routing_function
|
||||
self.correction_bias = correction_bias
|
||||
self.activation = activation
|
||||
self.use_presharded_weights = use_presharded_weights
|
||||
self.inplace = inplace
|
||||
self.no_combine = no_combine
|
||||
|
||||
if quant_config is None:
|
||||
self.quant_method: Optional[QuantizeMethodBase] = (
|
||||
@@ -304,7 +319,6 @@ class FusedMoE(torch.nn.Module):
|
||||
params_dtype=params_dtype,
|
||||
weight_loader=self.weight_loader,
|
||||
)
|
||||
self.use_presharded_weights = use_presharded_weights
|
||||
|
||||
def _load_per_tensor_weight_scale(
|
||||
self,
|
||||
@@ -598,6 +612,8 @@ class FusedMoE(torch.nn.Module):
|
||||
custom_routing_function=self.custom_routing_function,
|
||||
correction_bias=self.correction_bias,
|
||||
activation=self.activation,
|
||||
inplace=self.inplace,
|
||||
no_combine=self.no_combine,
|
||||
)
|
||||
|
||||
if self.reduce_results and self.tp_size > 1:
|
||||
|
||||
@@ -771,6 +771,8 @@ class Fp8MoEMethod:
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
correction_bias: Optional[torch.Tensor] = None,
|
||||
activation: str = "silu",
|
||||
inplace: bool = True,
|
||||
no_combine: bool = False,
|
||||
) -> torch.Tensor:
|
||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
|
||||
from sglang.srt.layers.moe.topk import select_experts
|
||||
@@ -793,6 +795,7 @@ class Fp8MoEMethod:
|
||||
from aiter.fused_moe import fused_experts_ck
|
||||
|
||||
assert activation == "silu", f"{activation=} is not supported."
|
||||
assert not no_combine, f"{no_combine=} is not supported."
|
||||
|
||||
return fused_experts_ck(
|
||||
x,
|
||||
@@ -823,7 +826,7 @@ class Fp8MoEMethod:
|
||||
layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
inplace=inplace and not no_combine,
|
||||
activation=activation,
|
||||
use_fp8_w8a8=True,
|
||||
w1_scale=(
|
||||
@@ -839,6 +842,7 @@ class Fp8MoEMethod:
|
||||
a1_scale=layer.w13_input_scale,
|
||||
a2_scale=layer.w2_input_scale,
|
||||
block_shape=self.quant_config.weight_block_size,
|
||||
no_combine=no_combine,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -707,7 +707,6 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
|
||||
cos = freqs.cos() * self.mscale
|
||||
sin = freqs.sin() * self.mscale
|
||||
cache = torch.cat((cos, sin), dim=-1)
|
||||
print("Cache shape", cache.shape)
|
||||
return cache
|
||||
|
||||
def forward(
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import logging
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
@@ -41,7 +41,21 @@ class Sampler(nn.Module):
|
||||
sampling_info: SamplingBatchInfo,
|
||||
return_logprob: bool,
|
||||
top_logprobs_nums: List[int],
|
||||
token_ids_logprobs: List[List[int]],
|
||||
batch_next_token_ids: Optional[torch.Tensor] = None,
|
||||
):
|
||||
"""Run a sampler & compute logprobs and update logits_output accordingly.
|
||||
|
||||
Args:
|
||||
logits_output: The logits from the model forward
|
||||
sampling_info: Metadata for sampling
|
||||
return_logprob: If set, store the output logprob information to
|
||||
logits_output
|
||||
top_logprobs_nums: Number of top logprobs per sequence in a batch
|
||||
batch_next_token_ids: next token IDs. If set, skip sampling and only
|
||||
compute output logprobs. It is used for speculative decoding, which
|
||||
performs sampling in draft workers.
|
||||
"""
|
||||
logits = logits_output.next_token_logits
|
||||
|
||||
# Apply the custom logit processors if registered in the sampling info.
|
||||
@@ -58,13 +72,15 @@ class Sampler(nn.Module):
|
||||
|
||||
if sampling_info.is_all_greedy:
|
||||
# Use torch.argmax if all requests use greedy sampling
|
||||
batch_next_token_ids = torch.argmax(logits, -1)
|
||||
if batch_next_token_ids is None:
|
||||
batch_next_token_ids = torch.argmax(logits, -1)
|
||||
if return_logprob:
|
||||
logprobs = torch.nn.functional.log_softmax(logits, dim=-1)
|
||||
else:
|
||||
# Post process logits
|
||||
logits.div_(sampling_info.temperatures)
|
||||
probs = torch.softmax(logits, dim=-1)
|
||||
logits[:] = torch.softmax(logits, dim=-1)
|
||||
probs = logits
|
||||
del logits
|
||||
|
||||
if global_server_args_dict["sampling_backend"] == "flashinfer":
|
||||
@@ -78,38 +94,43 @@ class Sampler(nn.Module):
|
||||
top_p_normalize_probs_torch(probs, sampling_info.top_ps)
|
||||
).clamp(min=torch.finfo(probs.dtype).min)
|
||||
|
||||
max_top_k_round, batch_size = 32, probs.shape[0]
|
||||
uniform_samples = torch.rand(
|
||||
(max_top_k_round, batch_size), device=probs.device
|
||||
)
|
||||
if sampling_info.need_min_p_sampling:
|
||||
probs = top_k_renorm_prob(probs, sampling_info.top_ks)
|
||||
probs = top_p_renorm_prob(probs, sampling_info.top_ps)
|
||||
batch_next_token_ids = min_p_sampling_from_probs(
|
||||
probs, uniform_samples, sampling_info.min_ps
|
||||
)
|
||||
else:
|
||||
batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
|
||||
probs,
|
||||
uniform_samples,
|
||||
sampling_info.top_ks,
|
||||
sampling_info.top_ps,
|
||||
filter_apply_order="joint",
|
||||
if batch_next_token_ids is None:
|
||||
max_top_k_round, batch_size = 32, probs.shape[0]
|
||||
uniform_samples = torch.rand(
|
||||
(max_top_k_round, batch_size), device=probs.device
|
||||
)
|
||||
if sampling_info.need_min_p_sampling:
|
||||
probs = top_k_renorm_prob(probs, sampling_info.top_ks)
|
||||
probs = top_p_renorm_prob(probs, sampling_info.top_ps)
|
||||
batch_next_token_ids = min_p_sampling_from_probs(
|
||||
probs, uniform_samples, sampling_info.min_ps
|
||||
)
|
||||
else:
|
||||
batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
|
||||
probs,
|
||||
uniform_samples,
|
||||
sampling_info.top_ks,
|
||||
sampling_info.top_ps,
|
||||
filter_apply_order="joint",
|
||||
)
|
||||
|
||||
if self.use_nan_detection and not torch.all(success):
|
||||
logger.warning("Detected errors during sampling!")
|
||||
batch_next_token_ids = torch.zeros_like(batch_next_token_ids)
|
||||
if self.use_nan_detection and not torch.all(success):
|
||||
logger.warning("Detected errors during sampling!")
|
||||
batch_next_token_ids = torch.zeros_like(
|
||||
batch_next_token_ids
|
||||
)
|
||||
|
||||
elif global_server_args_dict["sampling_backend"] == "pytorch":
|
||||
# A slower fallback implementation with torch native operations.
|
||||
batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
|
||||
probs,
|
||||
sampling_info.top_ks,
|
||||
sampling_info.top_ps,
|
||||
sampling_info.min_ps,
|
||||
sampling_info.need_min_p_sampling,
|
||||
)
|
||||
if batch_next_token_ids is None:
|
||||
# A slower fallback implementation with torch native operations.
|
||||
batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
|
||||
probs,
|
||||
sampling_info.top_ks,
|
||||
sampling_info.top_ps,
|
||||
sampling_info.min_ps,
|
||||
sampling_info.need_min_p_sampling,
|
||||
)
|
||||
|
||||
if return_logprob:
|
||||
# clamp to avoid -inf
|
||||
logprobs = torch.log(
|
||||
@@ -128,6 +149,12 @@ class Sampler(nn.Module):
|
||||
logits_output.next_token_top_logprobs_idx,
|
||||
) = get_top_logprobs(logprobs, top_logprobs_nums)
|
||||
|
||||
if any(x is not None for x in token_ids_logprobs):
|
||||
(
|
||||
logits_output.next_token_token_ids_logprobs_val,
|
||||
logits_output.next_token_token_ids_logprobs_idx,
|
||||
) = get_token_ids_logprobs(logprobs, token_ids_logprobs)
|
||||
|
||||
logits_output.next_token_logprobs = logprobs[
|
||||
torch.arange(len(batch_next_token_ids), device=sampling_info.device),
|
||||
batch_next_token_ids,
|
||||
@@ -223,6 +250,10 @@ def top_p_normalize_probs_torch(
|
||||
|
||||
|
||||
def get_top_logprobs(logprobs: torch.Tensor, top_logprobs_nums: List[int]):
|
||||
assert len(top_logprobs_nums) == logprobs.shape[0], (
|
||||
len(top_logprobs_nums),
|
||||
logprobs.shape[0],
|
||||
)
|
||||
max_k = max(top_logprobs_nums)
|
||||
ret = logprobs.topk(max_k, dim=1)
|
||||
values = ret.values.tolist()
|
||||
@@ -234,3 +265,17 @@ def get_top_logprobs(logprobs: torch.Tensor, top_logprobs_nums: List[int]):
|
||||
output_top_logprobs_val.append(values[i][:k])
|
||||
output_top_logprobs_idx.append(indices[i][:k])
|
||||
return output_top_logprobs_val, output_top_logprobs_idx
|
||||
|
||||
|
||||
def get_token_ids_logprobs(
    logprobs: torch.Tensor, token_ids_logprobs: List[Optional[List[int]]]
):
    """Pick the logprobs of the requested token ids for each batch row.

    Args:
        logprobs: 2-D tensor indexed as ``[row, token_id]``.
        token_ids_logprobs: One entry per row: the token ids whose logprobs
            should be returned, or ``None`` when nothing was requested for
            that row.

    Returns:
        A pair ``(values, indices)`` of lists with one entry per row.
        ``values[i]`` holds the logprobs of ``token_ids_logprobs[i]`` and
        ``indices[i]`` is that id list; both are ``[]`` for ``None`` rows.
    """
    output_token_ids_logprobs_val = []
    output_token_ids_logprobs_idx = []
    for i, token_ids in enumerate(token_ids_logprobs):
        if token_ids is not None:
            output_token_ids_logprobs_val.append(logprobs[i, token_ids].tolist())
            output_token_ids_logprobs_idx.append(token_ids)
        else:
            # Keep the outputs aligned with the batch rows.
            output_token_ids_logprobs_val.append([])
            output_token_ids_logprobs_idx.append([])

    return output_token_ids_logprobs_val, output_token_ids_logprobs_idx
|
||||
|
||||
@@ -457,7 +457,7 @@ class VocabParallelEmbedding(torch.nn.Module):
|
||||
assert loaded_weight.shape[output_dim] == (
|
||||
self.org_vocab_size
|
||||
// (self.tp_size if self.use_presharded_weights else 1)
|
||||
)
|
||||
), f"{self.org_vocab_size=} {self.use_presharded_weights=} {loaded_weight.shape[output_dim]=}"
|
||||
|
||||
# Copy the data.
|
||||
if not self.use_presharded_weights:
|
||||
|
||||
@@ -28,6 +28,7 @@ if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--url", type=str, default="http://localhost:30000")
|
||||
parser.add_argument("--log-requests", action="store_true")
|
||||
parser.add_argument("--log-requests-level", type=int, default=2)
|
||||
parser.add_argument(
|
||||
"--dump-requests-folder", type=str, default="/tmp/sglang_request_dump"
|
||||
)
|
||||
@@ -38,7 +39,7 @@ if __name__ == "__main__":
|
||||
args.url + "/configure_logging",
|
||||
json={
|
||||
"log_requests": args.log_requests,
|
||||
"log_requests_level": 1, # Log full requests
|
||||
"log_requests_level": args.log_requests_level, # Log full requests
|
||||
"dump_requests_folder": args.dump_requests_folder,
|
||||
"dump_requests_threshold": args.dump_requests_threshold,
|
||||
},
|
||||
|
||||
@@ -198,6 +198,8 @@ class DataParallelController:
|
||||
self.max_total_num_tokens = scheduler_info[0]["max_total_num_tokens"]
|
||||
self.max_req_input_len = scheduler_info[0]["max_req_input_len"]
|
||||
|
||||
print(f"{scheduler_info=}")
|
||||
|
||||
def round_robin_scheduler(self, req):
|
||||
self.workers[self.round_robin_counter].send_pyobj(req)
|
||||
self.round_robin_counter = (self.round_robin_counter + 1) % len(self.workers)
|
||||
@@ -220,6 +222,7 @@ class DataParallelController:
|
||||
TokenizedEmbeddingReqInput,
|
||||
),
|
||||
):
|
||||
logger.info("dispatching")
|
||||
self.dispatching(recv_req)
|
||||
else:
|
||||
# Send other control messages to first worker of tp group
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
"""DetokenizerManager is a process that detokenizes the token ids."""
|
||||
|
||||
import dataclasses
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import signal
|
||||
@@ -27,11 +28,16 @@ import zmq
|
||||
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||
from sglang.srt.managers.io_struct import (
|
||||
BatchEmbeddingOut,
|
||||
BatchMultimodalDecodeReq,
|
||||
BatchStrOut,
|
||||
BatchTokenIDOut,
|
||||
)
|
||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||
from sglang.srt.utils import configure_logger, get_zmq_socket
|
||||
from sglang.srt.utils import (
|
||||
configure_logger,
|
||||
get_zmq_socket,
|
||||
kill_itself_when_parent_died,
|
||||
)
|
||||
from sglang.utils import (
|
||||
TypeBasedDispatcher,
|
||||
find_printable_text,
|
||||
@@ -86,14 +92,23 @@ class DetokenizerManager:
|
||||
)
|
||||
|
||||
self.decode_status = LimitedCapacityDict(capacity=DETOKENIZER_MAX_STATES)
|
||||
self.is_dummy = server_args.load_format == "dummy"
|
||||
|
||||
self._request_dispatcher = TypeBasedDispatcher(
|
||||
[
|
||||
(BatchEmbeddingOut, self.handle_batch_embedding_out),
|
||||
(BatchTokenIDOut, self.handle_batch_token_id_out),
|
||||
(BatchMultimodalDecodeReq, self.handle_multimodal_decode_req),
|
||||
]
|
||||
)
|
||||
|
||||
def event_loop(self):
    """The event loop that handles requests.

    Runs forever: receives one object from ``self.recv_from_scheduler``,
    passes it through ``self._request_dispatcher`` and sends the result to
    ``self.send_to_tokenizer``. It only stops when one of those calls
    raises.
    """
    while True:
        recv_obj = self.recv_from_scheduler.recv_pyobj()
        output = self._request_dispatcher(recv_obj)
        self.send_to_tokenizer.send_pyobj(output)
|
||||
|
||||
def trim_matched_stop(
|
||||
self, output: Union[str, List[int]], finished_reason: Dict, no_stop_trim: bool
|
||||
):
|
||||
@@ -117,14 +132,6 @@ class DetokenizerManager:
|
||||
return output[:-1]
|
||||
return output
|
||||
|
||||
def event_loop(self):
|
||||
"""The event loop that handles requests"""
|
||||
|
||||
while True:
|
||||
recv_obj = self.recv_from_scheduler.recv_pyobj()
|
||||
output = self._request_dispatcher(recv_obj)
|
||||
self.send_to_tokenizer.send_pyobj(output)
|
||||
|
||||
def handle_batch_embedding_out(self, recv_obj: BatchEmbeddingOut):
|
||||
# If it is an embedding model, no detokenization is needed.
|
||||
return recv_obj
|
||||
@@ -173,7 +180,6 @@ class DetokenizerManager:
|
||||
|
||||
# Incremental decoding
|
||||
output_strs = []
|
||||
finished_reqs = []
|
||||
for i in range(bs):
|
||||
try:
|
||||
s = self.decode_status[recv_obj.rids[i]]
|
||||
@@ -196,8 +202,6 @@ class DetokenizerManager:
|
||||
new_text = ""
|
||||
else:
|
||||
new_text = find_printable_text(new_text)
|
||||
else:
|
||||
finished_reqs.append(recv_obj.rids[i])
|
||||
|
||||
output_strs.append(
|
||||
self.trim_matched_stop(
|
||||
@@ -207,7 +211,7 @@ class DetokenizerManager:
|
||||
)
|
||||
)
|
||||
|
||||
out = BatchStrOut(
|
||||
return BatchStrOut(
|
||||
rids=recv_obj.rids,
|
||||
finished_reasons=recv_obj.finished_reasons,
|
||||
output_strs=output_strs,
|
||||
@@ -223,14 +227,15 @@ class DetokenizerManager:
|
||||
input_top_logprobs_idx=recv_obj.input_top_logprobs_idx,
|
||||
output_top_logprobs_val=recv_obj.output_top_logprobs_val,
|
||||
output_top_logprobs_idx=recv_obj.output_top_logprobs_idx,
|
||||
input_token_ids_logprobs_val=recv_obj.input_token_ids_logprobs_val,
|
||||
input_token_ids_logprobs_idx=recv_obj.input_token_ids_logprobs_idx,
|
||||
output_token_ids_logprobs_val=recv_obj.output_token_ids_logprobs_val,
|
||||
output_token_ids_logprobs_idx=recv_obj.output_token_ids_logprobs_idx,
|
||||
output_hidden_states=recv_obj.output_hidden_states,
|
||||
)
|
||||
|
||||
# remove decodestatus for completed requests
|
||||
for rid in finished_reqs:
|
||||
self.decode_status.pop(rid)
|
||||
|
||||
return out
|
||||
def handle_multimodal_decode_req(self, recv_obj: BatchMultimodalDecodeReq):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class LimitedCapacityDict(OrderedDict):
|
||||
@@ -250,6 +255,7 @@ def run_detokenizer_process(
|
||||
server_args: ServerArgs,
|
||||
port_args: PortArgs,
|
||||
):
|
||||
kill_itself_when_parent_died()
|
||||
setproctitle.setproctitle("sglang::detokenizer")
|
||||
configure_logger(server_args)
|
||||
parent_process = psutil.Process().parent()
|
||||
|
||||
@@ -16,10 +16,11 @@ The definition of objects transferred between different
|
||||
processes (TokenizerManager, DetokenizerManager, Controller).
|
||||
"""
|
||||
|
||||
import copy
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Dict, List, Optional, Union
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from sglang.srt.managers.schedule_batch import BaseFinishReason
|
||||
from sglang.srt.sampling.sampling_params import SamplingParams
|
||||
@@ -55,6 +56,8 @@ class GenerateReqInput:
|
||||
logprob_start_len: Optional[Union[List[int], int]] = None
|
||||
# If return logprobs, the number of top logprobs to return at each position.
|
||||
top_logprobs_num: Optional[Union[List[int], int]] = None
|
||||
# If return logprobs, the token ids to return logprob for.
|
||||
token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None
|
||||
# Whether to detokenize tokens in text in the returned logprobs.
|
||||
return_text_in_logprobs: bool = False
|
||||
# Whether to stream output.
|
||||
@@ -146,6 +149,8 @@ class GenerateReqInput:
|
||||
self.logprob_start_len = -1
|
||||
if self.top_logprobs_num is None:
|
||||
self.top_logprobs_num = 0
|
||||
if not self.token_ids_logprob: # covers both None and []
|
||||
self.token_ids_logprob = None
|
||||
else:
|
||||
if self.parallel_sample_num == 1:
|
||||
num = self.batch_size
|
||||
@@ -191,6 +196,17 @@ class GenerateReqInput:
|
||||
else:
|
||||
assert self.parallel_sample_num == 1
|
||||
|
||||
if not self.token_ids_logprob: # covers both None and []
|
||||
self.token_ids_logprob = [None] * num
|
||||
elif not isinstance(self.token_ids_logprob, list):
|
||||
self.token_ids_logprob = [[self.token_ids_logprob] for _ in range(num)]
|
||||
elif not isinstance(self.token_ids_logprob[0], list):
|
||||
self.token_ids_logprob = [
|
||||
copy.deepcopy(self.token_ids_logprob) for _ in range(num)
|
||||
]
|
||||
else:
|
||||
assert self.parallel_sample_num == 1
|
||||
|
||||
if self.custom_logit_processor is None:
|
||||
self.custom_logit_processor = [None] * num
|
||||
elif not isinstance(self.custom_logit_processor, list):
|
||||
@@ -198,6 +214,12 @@ class GenerateReqInput:
|
||||
else:
|
||||
assert self.parallel_sample_num == 1
|
||||
|
||||
# Other checks
|
||||
if self.session_params is not None:
|
||||
assert isinstance(self.session_params, dict) or isinstance(
|
||||
self.session_params[0], dict
|
||||
)
|
||||
|
||||
def regenerate_rid(self):
|
||||
self.rid = uuid.uuid4().hex
|
||||
return self.rid
|
||||
@@ -212,6 +234,7 @@ class GenerateReqInput:
|
||||
return_logprob=self.return_logprob[i],
|
||||
logprob_start_len=self.logprob_start_len[i],
|
||||
top_logprobs_num=self.top_logprobs_num[i],
|
||||
token_ids_logprob=self.token_ids_logprob[i],
|
||||
return_text_in_logprobs=self.return_text_in_logprobs,
|
||||
stream=self.stream,
|
||||
log_metrics=self.log_metrics,
|
||||
@@ -244,6 +267,8 @@ class TokenizedGenerateReqInput:
|
||||
logprob_start_len: int
|
||||
# If return logprobs, the number of top logprobs to return at each position.
|
||||
top_logprobs_num: int
|
||||
# If return logprobs, the token ids to return logprob for.
|
||||
token_ids_logprob: List[int]
|
||||
# Whether to stream output
|
||||
stream: bool
|
||||
|
||||
@@ -378,10 +403,21 @@ class BatchTokenIDOut:
|
||||
input_top_logprobs_idx: List[List]
|
||||
output_top_logprobs_val: List[List]
|
||||
output_top_logprobs_idx: List[List]
|
||||
input_token_ids_logprobs_val: List[List]
|
||||
input_token_ids_logprobs_idx: List[List]
|
||||
output_token_ids_logprobs_val: List[List]
|
||||
output_token_ids_logprobs_idx: List[List]
|
||||
|
||||
# Hidden states
|
||||
output_hidden_states: List[List[float]]
|
||||
|
||||
|
||||
@dataclass
|
||||
class BatchMultimodalDecodeReq:
|
||||
# The request id
|
||||
rids: List[str]
|
||||
|
||||
|
||||
@dataclass
|
||||
class BatchStrOut:
|
||||
# The request id
|
||||
@@ -406,10 +442,21 @@ class BatchStrOut:
|
||||
input_top_logprobs_idx: List[List]
|
||||
output_top_logprobs_val: List[List]
|
||||
output_top_logprobs_idx: List[List]
|
||||
input_token_ids_logprobs_val: List[List]
|
||||
input_token_ids_logprobs_idx: List[List]
|
||||
output_token_ids_logprobs_val: List[List]
|
||||
output_token_ids_logprobs_idx: List[List]
|
||||
|
||||
# Hidden states
|
||||
output_hidden_states: List[List[float]]
|
||||
|
||||
|
||||
@dataclass
|
||||
class BatchMultimodalOut:
|
||||
# The request id
|
||||
rids: List[str]
|
||||
|
||||
|
||||
@dataclass
|
||||
class BatchEmbeddingOut:
|
||||
# The request id
|
||||
@@ -439,6 +486,8 @@ class UpdateWeightFromDiskReqInput:
|
||||
class UpdateWeightFromDiskReqOutput:
|
||||
success: bool
|
||||
message: str
|
||||
# Number of paused requests during weight sync.
|
||||
num_paused_requests: Optional[int] = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -526,11 +575,57 @@ class AbortReq:
|
||||
rid: str
|
||||
|
||||
|
||||
class ProfileReq(Enum):
|
||||
@dataclass
|
||||
class GetInternalStateReq:
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class GetInternalStateReqOutput:
|
||||
internal_state: Dict[Any, Any]
|
||||
|
||||
|
||||
@dataclass
|
||||
class SetInternalStateReq:
|
||||
server_args: Dict[str, Any]
|
||||
|
||||
|
||||
@dataclass
|
||||
class SetInternalStateReqOutput:
|
||||
updated: bool
|
||||
server_args: Dict[str, Any]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProfileReqInput:
|
||||
# The output directory
|
||||
output_dir: Optional[str] = None
|
||||
# If set, profiling runs for this number of steps.
|
||||
# Profiling is then stopped automatically after that many steps, and
|
||||
# the caller doesn't need to run stop_profile.
|
||||
num_steps: Optional[int] = None
|
||||
activities: Optional[List[str]] = None
|
||||
|
||||
|
||||
class ProfileReqType(Enum):
|
||||
START_PROFILE = 1
|
||||
STOP_PROFILE = 2
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProfileReq:
|
||||
type: ProfileReqType
|
||||
output_dir: Optional[str] = None
|
||||
num_steps: Optional[int] = None
|
||||
activities: Optional[List[str]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProfileReqOutput:
|
||||
success: bool
|
||||
message: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConfigureLoggingReq:
|
||||
log_requests: Optional[bool] = None
|
||||
@@ -556,6 +651,11 @@ class OpenSessionReqOutput:
|
||||
success: bool
|
||||
|
||||
|
||||
@dataclass
|
||||
class HealthCheckOutput:
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class Function:
|
||||
description: Optional[str] = None
|
||||
|
||||
@@ -29,6 +29,7 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
|
||||
It contains low-level tensor data. Most of the data consists of GPU tensors.
|
||||
"""
|
||||
|
||||
import copy
|
||||
import dataclasses
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union
|
||||
@@ -50,7 +51,10 @@ from sglang.srt.sampling.sampling_params import SamplingParams
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.speculative.spec_info import SpecInfo, SpeculativeAlgorithm
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
||||
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
||||
|
||||
|
||||
INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
|
||||
|
||||
@@ -65,6 +69,8 @@ global_server_args_dict = {
|
||||
"enable_dp_attention": ServerArgs.enable_dp_attention,
|
||||
"enable_ep_moe": ServerArgs.enable_ep_moe,
|
||||
"device": ServerArgs.device,
|
||||
"speculative_accept_threshold_single": ServerArgs.speculative_accept_threshold_single,
|
||||
"speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc,
|
||||
"enable_flashinfer_mla": ServerArgs.enable_flashinfer_mla,
|
||||
"disable_radix_cache": ServerArgs.disable_radix_cache,
|
||||
"flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
|
||||
@@ -230,6 +236,7 @@ class Req:
|
||||
sampling_params: SamplingParams,
|
||||
return_logprob: bool = False,
|
||||
top_logprobs_num: int = 0,
|
||||
token_ids_logprob: List[int] = None,
|
||||
stream: bool = False,
|
||||
origin_input_ids_unpadded: Optional[Tuple[int]] = None,
|
||||
lora_path: Optional[str] = None,
|
||||
@@ -256,17 +263,24 @@ class Req:
|
||||
self.input_embeds = input_embeds
|
||||
|
||||
# Sampling info
|
||||
if isinstance(sampling_params.custom_params, dict):
|
||||
sampling_params = copy.copy(sampling_params)
|
||||
sampling_params.custom_params = sampling_params.custom_params | {
|
||||
"__req__": self
|
||||
}
|
||||
self.sampling_params = sampling_params
|
||||
|
||||
self.custom_logit_processor = custom_logit_processor
|
||||
self.return_hidden_states = return_hidden_states
|
||||
|
||||
# Memory pool info
|
||||
self.req_pool_idx = None
|
||||
self.req_pool_idx: Optional[int] = None
|
||||
|
||||
# Check finish
|
||||
self.tokenizer = None
|
||||
self.finished_reason = None
|
||||
# If we want to abort the request in the middle of the event loop, set this to true
|
||||
# Note: We should never set finished_reason in the middle, the req will get filtered and never respond
|
||||
self.to_abort = False
|
||||
self.stream = stream
|
||||
self.eos_token_ids = eos_token_ids
|
||||
@@ -289,38 +303,56 @@ class Req:
|
||||
self.image_inputs: Optional[ImageInputs] = None
|
||||
|
||||
# Prefix info
|
||||
# The indices to kv cache for the shared prefix.
|
||||
self.prefix_indices = []
|
||||
# Tokens to run prefill. input_tokens - shared_prefix_tokens.
|
||||
# Updated if chunked.
|
||||
# Number of tokens to run prefill.
|
||||
self.extend_input_len = 0
|
||||
# The relative logprob_start_len in an extend batch
|
||||
self.extend_logprob_start_len = 0
|
||||
self.last_node = None
|
||||
|
||||
# Chunked prefill
|
||||
self.is_being_chunked = 0
|
||||
# Whether the request is chunked. The counter is incremented whenever
|
||||
# the request is chunked, and decremented whenever a chunked request is
|
||||
# processed.
|
||||
self.is_chunked = 0
|
||||
|
||||
# For retraction
|
||||
self.is_retracted = False
|
||||
|
||||
# Logprobs (arguments)
|
||||
self.return_logprob = return_logprob
|
||||
# Start index to compute logprob from.
|
||||
self.logprob_start_len = 0
|
||||
self.top_logprobs_num = top_logprobs_num
|
||||
self.token_ids_logprob = token_ids_logprob
|
||||
|
||||
# Logprobs (return values)
|
||||
self.input_token_logprobs_val: Optional[List[float]] = None
|
||||
self.input_token_logprobs_idx: Optional[List[int]] = None
|
||||
self.input_top_logprobs_val: Optional[List[float]] = None
|
||||
self.input_top_logprobs_idx: Optional[List[int]] = None
|
||||
self.input_token_ids_logprobs_val: Optional[List[float]] = None
|
||||
self.input_token_ids_logprobs_idx: Optional[List[int]] = None
|
||||
# Temporary holder to store input_token_logprobs.
|
||||
self.input_token_logprobs: Optional[List[Tuple[int]]] = None
|
||||
self.temp_input_top_logprobs_val: Optional[List[torch.Tensor]] = None
|
||||
self.temp_input_top_logprobs_idx: Optional[List[int]] = None
|
||||
self.temp_input_token_ids_logprobs_val: Optional[List[float]] = None
|
||||
self.temp_input_token_ids_logprobs_idx: Optional[List[int]] = None
|
||||
|
||||
if return_logprob:
|
||||
self.output_token_logprobs_val = []
|
||||
self.output_token_logprobs_idx = []
|
||||
self.output_top_logprobs_val = []
|
||||
self.output_top_logprobs_idx = []
|
||||
self.output_token_ids_logprobs_val = []
|
||||
self.output_token_ids_logprobs_idx = []
|
||||
else:
|
||||
self.output_token_logprobs_val = self.output_token_logprobs_idx = (
|
||||
self.output_top_logprobs_val
|
||||
) = self.output_top_logprobs_idx = None
|
||||
) = self.output_top_logprobs_idx = self.output_token_ids_logprobs_val = (
|
||||
self.output_token_ids_logprobs_idx
|
||||
) = None
|
||||
self.hidden_states = []
|
||||
|
||||
# Logprobs (internal values)
|
||||
@@ -345,6 +377,13 @@ class Req:
|
||||
self.spec_verify_ct = 0
|
||||
self.lora_path = lora_path
|
||||
|
||||
# This carries the error message for `.to_abort` and will be attached to the finished_reason at the end of the event loop
|
||||
self.to_abort_message: str = "Unknown error"
|
||||
|
||||
@property
|
||||
def seqlen(self):
|
||||
return len(self.origin_input_ids) + len(self.output_ids)
|
||||
|
||||
def extend_image_inputs(self, image_inputs):
|
||||
if self.image_inputs is None:
|
||||
self.image_inputs = image_inputs
|
||||
@@ -422,7 +461,9 @@ class Req:
|
||||
return
|
||||
|
||||
if self.to_abort:
|
||||
self.finished_reason = FINISH_ABORT()
|
||||
self.finished_reason = FINISH_ABORT(
|
||||
message=self.to_abort_message,
|
||||
)
|
||||
return
|
||||
|
||||
if len(self.output_ids) >= self.sampling_params.max_new_tokens:
|
||||
@@ -517,6 +558,8 @@ class Req:
|
||||
self.output_token_logprobs_idx = self.output_token_logprobs_idx[:k]
|
||||
self.output_top_logprobs_val = self.output_top_logprobs_val[:k]
|
||||
self.output_top_logprobs_idx = self.output_top_logprobs_idx[:k]
|
||||
self.output_token_ids_logprobs_val = self.output_token_ids_logprobs_val[:k]
|
||||
self.output_token_ids_logprobs_idx = self.output_token_ids_logprobs_idx[:k]
|
||||
self.logprob_start_len = prompt_tokens + k
|
||||
self.last_update_decode_tokens = len(self.output_ids) - k
|
||||
|
||||
@@ -527,16 +570,19 @@ class Req:
|
||||
self.last_node = None
|
||||
self.extend_input_len = 0
|
||||
self.is_retracted = True
|
||||
self.input_token_logprobs = None
|
||||
self.temp_input_top_logprobs_val = None
|
||||
self.temp_input_top_logprobs_idx = None
|
||||
self.extend_logprob_start_len = 0
|
||||
self.is_chunked = 0
|
||||
self.req_pool_idx = None
|
||||
|
||||
# For incremental logprobs
|
||||
# TODO: Fix the `logprob_start_len`
|
||||
self.last_update_decode_tokens = 0
|
||||
self.logprob_start_len = 10**9
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
f"rid(n={self.rid}, "
|
||||
f"input_ids={self.origin_input_ids}, output_ids={self.output_ids}"
|
||||
f"Req(rid={self.rid}, "
|
||||
f"input_ids={self.origin_input_ids}, output_ids={self.output_ids})"
|
||||
)
|
||||
|
||||
|
||||
@@ -576,11 +622,13 @@ class ScheduleBatch:
|
||||
|
||||
# For DP attention
|
||||
global_num_tokens: Optional[List[int]] = None
|
||||
global_num_tokens_for_logprob: Optional[List[int]] = None
|
||||
can_run_dp_cuda_graph: bool = False
|
||||
|
||||
# For processing logprobs
|
||||
return_logprob: bool = False
|
||||
top_logprobs_nums: Optional[List[int]] = None
|
||||
token_ids_logprobs: Optional[List[List[int]]] = None
|
||||
|
||||
# For extend and mixed chunked prefill
|
||||
prefix_lens: List[int] = None
|
||||
@@ -588,6 +636,8 @@ class ScheduleBatch:
|
||||
extend_num_tokens: int = None
|
||||
decoding_reqs: List[Req] = None
|
||||
extend_logprob_start_lens: List[int] = None
|
||||
# It is None if logprob is not required.
|
||||
extend_input_logprob_token_ids: Optional[torch.Tensor] = None
|
||||
|
||||
# For encoder-decoder
|
||||
encoder_cached: Optional[List[bool]] = None
|
||||
@@ -606,7 +656,7 @@ class ScheduleBatch:
|
||||
|
||||
# Speculative decoding
|
||||
spec_algorithm: SpeculativeAlgorithm = None
|
||||
spec_info: Optional[SpecInfo] = None
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None
|
||||
|
||||
# Enable custom logit processor
|
||||
enable_custom_logit_processor: bool = False
|
||||
@@ -653,8 +703,10 @@ class ScheduleBatch:
|
||||
req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
|
||||
if req_pool_indices is None:
|
||||
raise RuntimeError(
|
||||
"Out of memory. "
|
||||
"Please set a smaller number for `--max-running-requests`."
|
||||
"alloc_req_slots runs out of memory. "
|
||||
"Please set a smaller number for `--max-running-requests`. "
|
||||
f"{self.req_to_token_pool.available_size()=}, "
|
||||
f"{num_reqs=}, "
|
||||
)
|
||||
return req_pool_indices
|
||||
|
||||
@@ -765,6 +817,7 @@ class ScheduleBatch:
|
||||
out_cache_loc = self.alloc_token_slots(extend_num_tokens)
|
||||
|
||||
input_embeds = []
|
||||
extend_input_logprob_token_ids = []
|
||||
|
||||
pt = 0
|
||||
for i, req in enumerate(reqs):
|
||||
@@ -783,22 +836,64 @@ class ScheduleBatch:
|
||||
# If req.input_embeds is already a list, append its content directly
|
||||
input_embeds.extend(req.input_embeds) # Use extend to avoid nesting
|
||||
|
||||
if req.return_logprob:
|
||||
# Compute the relative logprob_start_len in an extend batch
|
||||
if req.logprob_start_len >= pre_len:
|
||||
extend_logprob_start_len = min(
|
||||
req.logprob_start_len - pre_len, req.extend_input_len - 1
|
||||
)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"This should never happen. {req.logprob_start_len=}, {pre_len=}"
|
||||
)
|
||||
req.extend_logprob_start_len = extend_logprob_start_len
|
||||
|
||||
req.cached_tokens += pre_len - req.already_computed
|
||||
req.already_computed = seq_len
|
||||
req.is_retracted = False
|
||||
pre_lens.append(pre_len)
|
||||
# Compute the relative logprob_start_len in an extend batch
|
||||
if req.logprob_start_len >= pre_len:
|
||||
req.extend_logprob_start_len = min(
|
||||
req.logprob_start_len - pre_len,
|
||||
req.extend_input_len,
|
||||
req.seqlen - 1,
|
||||
)
|
||||
else:
|
||||
req.extend_logprob_start_len = 0
|
||||
|
||||
if self.return_logprob:
|
||||
# Find input logprob token ids.
|
||||
# First, find a global index within origin_input_ids and slide it by 1
|
||||
# to compute input logprobs. It is because you need the next token
|
||||
# to compute input logprobs. E.g., (chunk size 2)
|
||||
#
|
||||
# input_logprobs = [1, 2, 3, 4]
|
||||
# fill_ids = [1, 2]
|
||||
# extend_input_logprob_token_id = [2, 3]
|
||||
#
|
||||
# Note that it can also overflow. In this case, we pad it with 0.
|
||||
# input_logprobs = [1, 2, 3, 4]
|
||||
# fill_ids = [3, 4]
|
||||
# extend_input_logprob_token_id = [4, 0]
|
||||
global_start_idx, global_end_idx = (
|
||||
len(req.prefix_indices),
|
||||
len(req.fill_ids),
|
||||
)
|
||||
# Apply logprob_start_len
|
||||
if global_start_idx < req.logprob_start_len:
|
||||
global_start_idx = req.logprob_start_len
|
||||
|
||||
logprob_token_ids = req.origin_input_ids[
|
||||
global_start_idx + 1 : global_end_idx + 1
|
||||
]
|
||||
extend_input_logprob_token_ids.extend(logprob_token_ids)
|
||||
|
||||
# We will need req.extend_input_len - req.extend_logprob_start_len number of
|
||||
# tokens, and logprob_token_ids is for input logprob, so pad the rest of them by 0.
|
||||
extend_input_logprob_token_ids.extend(
|
||||
[0]
|
||||
* (
|
||||
req.extend_input_len
|
||||
- req.extend_logprob_start_len
|
||||
- len(logprob_token_ids)
|
||||
)
|
||||
)
|
||||
|
||||
if self.return_logprob:
|
||||
extend_input_logprob_token_ids = torch.tensor(
|
||||
extend_input_logprob_token_ids
|
||||
)
|
||||
else:
|
||||
extend_input_logprob_token_ids = None
|
||||
|
||||
# Set fields
|
||||
self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
|
||||
@@ -821,10 +916,12 @@ class ScheduleBatch:
|
||||
self.seq_lens_sum = sum(seq_lens)
|
||||
if self.return_logprob:
|
||||
self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
|
||||
self.token_ids_logprobs = [r.token_ids_logprob for r in reqs]
|
||||
self.extend_num_tokens = extend_num_tokens
|
||||
self.prefix_lens = [len(r.prefix_indices) for r in reqs]
|
||||
self.extend_lens = [r.extend_input_len for r in reqs]
|
||||
self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
|
||||
self.extend_input_logprob_token_ids = extend_input_logprob_token_ids
|
||||
|
||||
# Write to req_to_token_pool
|
||||
pre_lens = torch.tensor(pre_lens, dtype=torch.int32).to(
|
||||
@@ -860,7 +957,6 @@ class ScheduleBatch:
|
||||
self.sampling_info = SamplingBatchInfo.from_schedule_batch(
|
||||
self,
|
||||
self.model_config.vocab_size,
|
||||
enable_overlap_schedule=self.enable_overlap,
|
||||
)
|
||||
|
||||
def mix_with_running(self, running_batch: "ScheduleBatch"):
|
||||
@@ -905,25 +1001,43 @@ class ScheduleBatch:
|
||||
|
||||
return False
|
||||
|
||||
def retract_decode(self):
|
||||
def retract_decode(self, server_args: ServerArgs):
|
||||
"""Retract the decoding requests when there is not enough memory."""
|
||||
sorted_indices = [i for i in range(len(self.reqs))]
|
||||
|
||||
# TODO(lsyin): improve retraction policy for radix cache
|
||||
sorted_indices.sort(
|
||||
key=lambda i: (
|
||||
len(self.reqs[i].output_ids),
|
||||
-len(self.reqs[i].origin_input_ids),
|
||||
),
|
||||
reverse=True,
|
||||
)
|
||||
# For spec decoding, filter_batch API can only filter
|
||||
# requests from the back, so we can only retract from the back.
|
||||
# TODO(sang): Clean up finish path and support better retract
|
||||
# policy.
|
||||
if not server_args.speculative_algorithm:
|
||||
sorted_indices.sort(
|
||||
key=lambda i: (
|
||||
len(self.reqs[i].output_ids),
|
||||
-len(self.reqs[i].origin_input_ids),
|
||||
),
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
def get_required_tokens(num_reqs: int):
|
||||
headroom_for_spec_decode = 0
|
||||
if server_args.speculative_algorithm:
|
||||
headroom_for_spec_decode += (
|
||||
num_reqs
|
||||
* server_args.speculative_eagle_topk
|
||||
* server_args.speculative_num_steps
|
||||
+ num_reqs * server_args.speculative_num_draft_tokens
|
||||
)
|
||||
return (
|
||||
num_reqs * global_config.retract_decode_steps + headroom_for_spec_decode
|
||||
)
|
||||
|
||||
retracted_reqs = []
|
||||
seq_lens_cpu = self.seq_lens.cpu().numpy()
|
||||
first_iter = True
|
||||
while (
|
||||
self.token_to_kv_pool.available_size()
|
||||
< len(sorted_indices) * global_config.retract_decode_steps
|
||||
< get_required_tokens(len(sorted_indices))
|
||||
or first_iter
|
||||
):
|
||||
if len(sorted_indices) == 1:
|
||||
@@ -1048,17 +1162,40 @@ class ScheduleBatch:
|
||||
self.sampling_info = SamplingBatchInfo.from_schedule_batch(
|
||||
self,
|
||||
self.model_config.vocab_size,
|
||||
enable_overlap_schedule=self.enable_overlap,
|
||||
)
|
||||
|
||||
def prepare_for_decode(self):
|
||||
self.forward_mode = ForwardMode.DECODE
|
||||
if self.spec_algorithm.is_eagle():
|
||||
# if spec decoding is used, the decode batch is prepared inside
|
||||
# `forward_batch_speculative_generation` after running draft models.
|
||||
return
|
||||
|
||||
if self.sampling_info.penalizer_orchestrator.is_required:
|
||||
if self.enable_overlap:
|
||||
# TODO: this can be slow, optimize this.
|
||||
delayed_output_ids = torch.tensor(
|
||||
[
|
||||
(
|
||||
req.output_ids[-1]
|
||||
if len(req.output_ids)
|
||||
else req.origin_input_ids[-1]
|
||||
)
|
||||
for req in self.reqs
|
||||
],
|
||||
dtype=torch.int64,
|
||||
device=self.device,
|
||||
)
|
||||
self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
|
||||
delayed_output_ids
|
||||
)
|
||||
else:
|
||||
self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
|
||||
self.output_ids.to(torch.int64)
|
||||
)
|
||||
|
||||
self.input_ids = self.output_ids
|
||||
self.output_ids = None
|
||||
self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(self.input_ids)
|
||||
|
||||
# Alloc mem
|
||||
bs = len(self.reqs)
|
||||
@@ -1086,14 +1223,15 @@ class ScheduleBatch:
|
||||
|
||||
def filter_batch(
|
||||
self,
|
||||
being_chunked_req: Optional[Req] = None,
|
||||
chunked_req_to_exclude: Optional[Req] = None,
|
||||
keep_indices: Optional[List[int]] = None,
|
||||
):
|
||||
if keep_indices is None:
|
||||
keep_indices = [
|
||||
i
|
||||
for i in range(len(self.reqs))
|
||||
if not self.reqs[i].finished() and self.reqs[i] is not being_chunked_req
|
||||
if not self.reqs[i].finished()
|
||||
and self.reqs[i] is not chunked_req_to_exclude
|
||||
]
|
||||
|
||||
if keep_indices is None or len(keep_indices) == 0:
|
||||
@@ -1105,31 +1243,34 @@ class ScheduleBatch:
|
||||
# No need to filter
|
||||
return
|
||||
|
||||
keep_indices_device = torch.tensor(keep_indices, dtype=torch.int64).to(
|
||||
self.device, non_blocking=True
|
||||
)
|
||||
|
||||
if self.model_config.is_encoder_decoder:
|
||||
self.encoder_lens = self.encoder_lens[keep_indices]
|
||||
self.encoder_lens = self.encoder_lens[keep_indices_device]
|
||||
self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices]
|
||||
|
||||
self.reqs = [self.reqs[i] for i in keep_indices]
|
||||
new_indices = torch.tensor(keep_indices, dtype=torch.int64).to(
|
||||
self.device, non_blocking=True
|
||||
)
|
||||
self.req_pool_indices = self.req_pool_indices[new_indices]
|
||||
self.seq_lens = self.seq_lens[new_indices]
|
||||
self.req_pool_indices = self.req_pool_indices[keep_indices_device]
|
||||
self.seq_lens = self.seq_lens[keep_indices_device]
|
||||
self.out_cache_loc = None
|
||||
self.seq_lens_sum = self.seq_lens.sum().item()
|
||||
self.output_ids = self.output_ids[new_indices]
|
||||
self.output_ids = self.output_ids[keep_indices_device]
|
||||
self.return_logprob = any(req.return_logprob for req in self.reqs)
|
||||
if self.return_logprob:
|
||||
self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in keep_indices]
|
||||
self.token_ids_logprobs = [self.token_ids_logprobs[i] for i in keep_indices]
|
||||
else:
|
||||
self.top_logprobs_nums = None
|
||||
self.token_ids_logprobs = None
|
||||
|
||||
self.has_stream = any(req.stream for req in self.reqs)
|
||||
self.has_grammar = any(req.grammar for req in self.reqs)
|
||||
|
||||
self.sampling_info.filter_batch(keep_indices, new_indices)
|
||||
self.sampling_info.filter_batch(keep_indices, keep_indices_device)
|
||||
if self.spec_info:
|
||||
self.spec_info.filter_batch(new_indices)
|
||||
self.spec_info.filter_batch(keep_indices_device)
|
||||
|
||||
def merge_batch(self, other: "ScheduleBatch"):
|
||||
# Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
|
||||
@@ -1152,10 +1293,13 @@ class ScheduleBatch:
|
||||
self.output_ids = torch.concat([self.output_ids, other.output_ids])
|
||||
if self.return_logprob and other.return_logprob:
|
||||
self.top_logprobs_nums.extend(other.top_logprobs_nums)
|
||||
self.token_ids_logprobs.extend(other.token_ids_logprobs)
|
||||
elif self.return_logprob:
|
||||
self.top_logprobs_nums.extend([0] * len(other.reqs))
|
||||
self.token_ids_logprobs.extend([None] * len(other.reqs))
|
||||
elif other.return_logprob:
|
||||
self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums
|
||||
self.token_ids_logprobs = [None] * len(self.reqs) + other.token_ids_logprobs
|
||||
self.reqs.extend(other.reqs)
|
||||
|
||||
self.return_logprob |= other.return_logprob
|
||||
@@ -1192,7 +1336,9 @@ class ScheduleBatch:
|
||||
seq_lens_sum=self.seq_lens_sum,
|
||||
return_logprob=self.return_logprob,
|
||||
top_logprobs_nums=self.top_logprobs_nums,
|
||||
token_ids_logprobs=self.token_ids_logprobs,
|
||||
global_num_tokens=self.global_num_tokens,
|
||||
global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
|
||||
can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
|
||||
extend_num_tokens=self.extend_num_tokens,
|
||||
extend_seq_lens=extend_seq_lens,
|
||||
@@ -1219,6 +1365,7 @@ class ScheduleBatch:
|
||||
else CaptureHiddenMode.NULL
|
||||
)
|
||||
),
|
||||
extend_input_logprob_token_ids=self.extend_input_logprob_token_ids,
|
||||
)
|
||||
|
||||
def copy(self):
|
||||
@@ -1262,9 +1409,11 @@ class ModelWorkerBatch:
|
||||
# For logprob
|
||||
return_logprob: bool
|
||||
top_logprobs_nums: Optional[List[int]]
|
||||
token_ids_logprobs: Optional[List[List[int]]]
|
||||
|
||||
# For DP attention
|
||||
global_num_tokens: Optional[List[int]]
|
||||
global_num_tokens_for_logprob: Optional[List[int]]
|
||||
can_run_dp_cuda_graph: bool
|
||||
|
||||
# For extend
|
||||
@@ -1272,6 +1421,7 @@ class ModelWorkerBatch:
|
||||
extend_seq_lens: Optional[List[int]]
|
||||
extend_prefix_lens: Optional[List[int]]
|
||||
extend_logprob_start_lens: Optional[List[int]]
|
||||
extend_input_logprob_token_ids: Optional[torch.Tensor]
|
||||
|
||||
# For multimodal
|
||||
image_inputs: Optional[List[ImageInputs]]
|
||||
@@ -1293,7 +1443,8 @@ class ModelWorkerBatch:
|
||||
|
||||
# Speculative decoding
|
||||
spec_algorithm: SpeculativeAlgorithm = None
|
||||
spec_info: Optional[SpecInfo] = None
|
||||
spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
|
||||
# If set, the output of the batch contains the hidden states of the run.
|
||||
capture_hidden_mode: CaptureHiddenMode = None
|
||||
|
||||
|
||||
|
||||
@@ -272,7 +272,7 @@ class PrefillAdder:
|
||||
|
||||
self.req_states = None
|
||||
self.can_run_list = []
|
||||
self.new_being_chunked_req = None
|
||||
self.new_chunked_req = None
|
||||
self.log_hit_tokens = 0
|
||||
self.log_input_tokens = 0
|
||||
|
||||
@@ -327,7 +327,7 @@ class PrefillAdder:
|
||||
self.log_hit_tokens += prefix_len
|
||||
self.log_input_tokens += extend_input_len
|
||||
|
||||
def add_being_chunked_req(self, req: Req):
|
||||
def add_chunked_req(self, req: Req):
|
||||
truncated = req.extend_input_len > self.rem_chunk_tokens
|
||||
req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
|
||||
req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
|
||||
@@ -354,7 +354,7 @@ class PrefillAdder:
|
||||
finally:
|
||||
self.tree_cache.dec_lock_ref(last_node)
|
||||
|
||||
def add_one_req_ignore_eos(self, req: Req):
|
||||
def add_one_req_ignore_eos(self, req: Req, has_chunked_req: bool):
|
||||
def add_req_state(r, insert_sort=False):
|
||||
new_token_ratio = (
|
||||
1.0 if r.sampling_params.ignore_eos else self.new_token_ratio
|
||||
@@ -403,6 +403,7 @@ class PrefillAdder:
|
||||
self.rem_chunk_tokens is None
|
||||
or req.extend_input_len <= self.rem_chunk_tokens
|
||||
):
|
||||
# Non-chunked prefill
|
||||
self.can_run_list.append(req)
|
||||
self._prefill_one_req(
|
||||
0,
|
||||
@@ -418,14 +419,14 @@ class PrefillAdder:
|
||||
req.extend_input_len = trunc_len
|
||||
req.fill_ids = req.fill_ids[:trunc_len]
|
||||
self.can_run_list.append(req)
|
||||
self.new_being_chunked_req = req
|
||||
self.new_chunked_req = req
|
||||
self._prefill_one_req(0, trunc_len, 0)
|
||||
|
||||
return self.budget_state()
|
||||
|
||||
def add_one_req(self, req: Req):
|
||||
def add_one_req(self, req: Req, has_chunked_req: bool):
|
||||
if req.sampling_params.ignore_eos and self.tree_cache.disable:
|
||||
return self.add_one_req_ignore_eos(req)
|
||||
return self.add_one_req_ignore_eos(req, has_chunked_req)
|
||||
|
||||
total_tokens = req.extend_input_len + min(
|
||||
req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION
|
||||
@@ -443,14 +444,7 @@ class PrefillAdder:
|
||||
if total_tokens > self.rem_total_tokens:
|
||||
return AddReqResult.NO_TOKEN
|
||||
|
||||
if (
|
||||
self.rem_chunk_tokens is None
|
||||
or input_tokens <= self.rem_chunk_tokens
|
||||
or (
|
||||
req.return_logprob
|
||||
and req.logprob_start_len != len(req.origin_input_ids) - 1
|
||||
)
|
||||
):
|
||||
if self.rem_chunk_tokens is None or input_tokens <= self.rem_chunk_tokens:
|
||||
# Non-chunked prefill
|
||||
self.can_run_list.append(req)
|
||||
self.tree_cache.inc_lock_ref(req.last_node)
|
||||
@@ -470,8 +464,9 @@ class PrefillAdder:
|
||||
|
||||
req.extend_input_len = trunc_len
|
||||
req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len]
|
||||
|
||||
self.can_run_list.append(req)
|
||||
self.new_being_chunked_req = req
|
||||
self.new_chunked_req = req
|
||||
self.tree_cache.inc_lock_ref(req.last_node)
|
||||
self._prefill_one_req(prefix_len, trunc_len, 0)
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -35,12 +35,12 @@ class SessionReqNode:
|
||||
for req_node in self.childs:
|
||||
req_node.clear(req_dict)
|
||||
|
||||
if self.req.finished_reason == None:
|
||||
if self.req.finished_reason is None:
|
||||
self.req.to_abort = True
|
||||
del req_dict[self.req.rid]
|
||||
|
||||
def abort(self):
|
||||
if self.req.finished_reason == None:
|
||||
if self.req.finished_reason is None:
|
||||
self.req.to_abort = True
|
||||
|
||||
def __str__(self):
|
||||
@@ -132,6 +132,10 @@ class Session:
|
||||
lora_path=req.lora_path,
|
||||
session_id=self.session_id,
|
||||
custom_logit_processor=req.custom_logit_processor,
|
||||
stream=req.stream,
|
||||
return_logprob=req.return_logprob,
|
||||
top_logprobs_num=req.top_logprobs_num,
|
||||
token_ids_logprob=req.token_ids_logprob,
|
||||
)
|
||||
if last_req is not None:
|
||||
new_req.image_inputs = last_req.image_inputs
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
import asyncio
|
||||
import copy
|
||||
import dataclasses
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import pickle
|
||||
@@ -24,9 +25,21 @@ import sys
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from collections import deque
|
||||
from datetime import datetime
|
||||
from http import HTTPStatus
|
||||
from typing import Any, Awaitable, Dict, Generic, List, Optional, Tuple, TypeVar, Union
|
||||
from typing import (
|
||||
Any,
|
||||
Awaitable,
|
||||
Deque,
|
||||
Dict,
|
||||
Generic,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
import fastapi
|
||||
import uvloop
|
||||
@@ -44,6 +57,7 @@ from sglang.srt.managers.image_processor import (
|
||||
from sglang.srt.managers.io_struct import (
|
||||
AbortReq,
|
||||
BatchEmbeddingOut,
|
||||
BatchMultimodalOut,
|
||||
BatchStrOut,
|
||||
BatchTokenIDOut,
|
||||
CloseSessionReqInput,
|
||||
@@ -51,18 +65,25 @@ from sglang.srt.managers.io_struct import (
|
||||
EmbeddingReqInput,
|
||||
FlushCacheReq,
|
||||
GenerateReqInput,
|
||||
GetInternalStateReq,
|
||||
GetInternalStateReqOutput,
|
||||
GetWeightsByNameReqInput,
|
||||
GetWeightsByNameReqOutput,
|
||||
HealthCheckOutput,
|
||||
InitWeightsUpdateGroupReqInput,
|
||||
InitWeightsUpdateGroupReqOutput,
|
||||
OpenSessionReqInput,
|
||||
OpenSessionReqOutput,
|
||||
ProfileReq,
|
||||
ProfileReqOutput,
|
||||
ProfileReqType,
|
||||
ReleaseMemoryOccupationReqInput,
|
||||
ReleaseMemoryOccupationReqOutput,
|
||||
ResumeMemoryOccupationReqInput,
|
||||
ResumeMemoryOccupationReqOutput,
|
||||
SessionParams,
|
||||
SetInternalStateReq,
|
||||
SetInternalStateReqOutput,
|
||||
TokenizedEmbeddingReqInput,
|
||||
TokenizedGenerateReqInput,
|
||||
UpdateWeightFromDiskReqInput,
|
||||
@@ -98,7 +119,10 @@ class ReqState:
|
||||
|
||||
# For metrics
|
||||
created_time: float
|
||||
first_token_time: Optional[float] = None
|
||||
finished_time: float = 0.0
|
||||
first_token_time: float = 0.0
|
||||
last_time: float = 0.0
|
||||
last_completion_tokens: int = 1
|
||||
|
||||
# For streaming output
|
||||
last_output_offset: int = 0
|
||||
@@ -113,11 +137,10 @@ class TokenizerManager:
|
||||
port_args: PortArgs,
|
||||
):
|
||||
# Parse args
|
||||
|
||||
self.server_args = server_args
|
||||
self.enable_metrics = server_args.enable_metrics
|
||||
self.log_requests = server_args.log_requests
|
||||
self.log_requests_level = 0
|
||||
self.log_requests_level = server_args.log_requests_level
|
||||
|
||||
# Init inter-process communication
|
||||
context = zmq.asyncio.Context(2)
|
||||
@@ -143,6 +166,7 @@ class TokenizerManager:
|
||||
)
|
||||
|
||||
self.is_generation = self.model_config.is_generation
|
||||
self.is_image_gen = self.model_config.is_image_gen
|
||||
self.context_len = self.model_config.context_len
|
||||
self.image_token_id = self.model_config.image_token_id
|
||||
|
||||
@@ -178,9 +202,12 @@ class TokenizerManager:
|
||||
# Store states
|
||||
self.no_create_loop = False
|
||||
self.rid_to_state: Dict[str, ReqState] = {}
|
||||
self.gracefully_exit = False
|
||||
self.last_receive_tstamp = 0
|
||||
self.dump_requests_folder = "" # By default do not dump
|
||||
self.dump_requests_threshold = 1000
|
||||
self.dump_request_list: List[Tuple] = []
|
||||
self.log_request_metadata = self.get_log_request_metadata()
|
||||
|
||||
# The event to notify the weight sync is finished.
|
||||
self.model_update_lock = RWLock()
|
||||
@@ -192,8 +219,19 @@ class TokenizerManager:
|
||||
# For session info
|
||||
self.session_futures = {} # session_id -> asyncio event
|
||||
|
||||
# Others
|
||||
self.gracefully_exit = False
|
||||
# Set after scheduler is initialized
|
||||
self.max_req_input_len = None
|
||||
|
||||
# Metrics
|
||||
if self.enable_metrics:
|
||||
self.metrics_collector = TokenizerMetricsCollector(
|
||||
labels={
|
||||
"model_name": self.server_args.served_model_name,
|
||||
# TODO: Add lora name/path in the future,
|
||||
},
|
||||
)
|
||||
|
||||
# Communicators
|
||||
self.init_weights_update_group_communicator = _Communicator(
|
||||
self.send_to_scheduler, server_args.dp_size
|
||||
)
|
||||
@@ -212,22 +250,26 @@ class TokenizerManager:
|
||||
self.resume_memory_occupation_communicator = _Communicator(
|
||||
self.send_to_scheduler, server_args.dp_size
|
||||
)
|
||||
# Set after scheduler is initialized
|
||||
self.max_req_input_len = None
|
||||
|
||||
# Metrics
|
||||
if self.enable_metrics:
|
||||
self.metrics_collector = TokenizerMetricsCollector(
|
||||
labels={
|
||||
"model_name": self.server_args.served_model_name,
|
||||
# TODO: Add lora name/path in the future,
|
||||
},
|
||||
)
|
||||
self.start_profile_communicator = _Communicator(
|
||||
self.send_to_scheduler, server_args.dp_size
|
||||
)
|
||||
self.health_check_communitcator = _Communicator(self.send_to_scheduler, 1)
|
||||
self.get_internal_state_communicator = _Communicator(
|
||||
self.send_to_scheduler, server_args.dp_size
|
||||
)
|
||||
self.set_internal_state_communicator = _Communicator(
|
||||
self.send_to_scheduler, server_args.dp_size
|
||||
)
|
||||
|
||||
self._result_dispatcher = TypeBasedDispatcher(
|
||||
[
|
||||
(
|
||||
(BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut),
|
||||
(
|
||||
BatchStrOut,
|
||||
BatchEmbeddingOut,
|
||||
BatchTokenIDOut,
|
||||
BatchMultimodalOut,
|
||||
),
|
||||
self._handle_batch_output,
|
||||
),
|
||||
(OpenSessionReqOutput, self._handle_open_session_req_output),
|
||||
@@ -259,6 +301,19 @@ class TokenizerManager:
|
||||
ResumeMemoryOccupationReqOutput,
|
||||
self.resume_memory_occupation_communicator.handle_recv,
|
||||
),
|
||||
(
|
||||
ProfileReqOutput,
|
||||
self.start_profile_communicator.handle_recv,
|
||||
),
|
||||
(
|
||||
GetInternalStateReqOutput,
|
||||
self.get_internal_state_communicator.handle_recv,
|
||||
),
|
||||
(
|
||||
SetInternalStateReqOutput,
|
||||
self.set_internal_state_communicator.handle_recv,
|
||||
),
|
||||
(HealthCheckOutput, lambda x: None),
|
||||
]
|
||||
)
|
||||
|
||||
@@ -280,9 +335,9 @@ class TokenizerManager:
|
||||
obj.normalize_batch_and_arguments()
|
||||
|
||||
if self.log_requests:
|
||||
max_length = 2048 if self.log_requests_level == 0 else 1 << 30
|
||||
max_length, skip_names, _ = self.log_request_metadata
|
||||
logger.info(
|
||||
f"Receive: obj={dataclass_to_string_truncated(obj, max_length)}"
|
||||
f"Receive: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}"
|
||||
)
|
||||
|
||||
async with self.model_update_lock.reader_lock:
|
||||
@@ -336,6 +391,7 @@ class TokenizerManager:
|
||||
return_logprob = obj.return_logprob
|
||||
logprob_start_len = obj.logprob_start_len
|
||||
top_logprobs_num = obj.top_logprobs_num
|
||||
token_ids_logprob = obj.token_ids_logprob
|
||||
session_params = (
|
||||
SessionParams(**obj.session_params) if obj.session_params else None
|
||||
)
|
||||
@@ -378,6 +434,7 @@ class TokenizerManager:
|
||||
return_logprob,
|
||||
logprob_start_len,
|
||||
top_logprobs_num,
|
||||
token_ids_logprob,
|
||||
obj.stream,
|
||||
lora_path=obj.lora_path,
|
||||
input_embeds=input_embeds,
|
||||
@@ -401,8 +458,7 @@ class TokenizerManager:
|
||||
tokenized_obj: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput],
|
||||
created_time: Optional[float] = None,
|
||||
):
|
||||
event = asyncio.Event()
|
||||
state = ReqState([], False, event, obj, created_time=created_time)
|
||||
state = ReqState([], False, asyncio.Event(), obj, created_time=created_time)
|
||||
self.rid_to_state[obj.rid] = state
|
||||
self.send_to_scheduler.send_pyobj(tokenized_obj)
|
||||
|
||||
@@ -420,7 +476,10 @@ class TokenizerManager:
|
||||
except asyncio.TimeoutError:
|
||||
if request is not None and await request.is_disconnected():
|
||||
self.abort_request(obj.rid)
|
||||
raise ValueError(f"Abort request {obj.rid}")
|
||||
raise ValueError(
|
||||
"Request is disconnected from the client side. "
|
||||
f"Abort request {obj.rid}"
|
||||
)
|
||||
continue
|
||||
|
||||
out = state.out_list[-1]
|
||||
@@ -428,8 +487,11 @@ class TokenizerManager:
|
||||
state.out_list = []
|
||||
if state.finished:
|
||||
if self.log_requests:
|
||||
max_length = 2048 if self.log_requests_level == 0 else 1 << 30
|
||||
msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length)}, out={dataclass_to_string_truncated(out, max_length)}"
|
||||
max_length, skip_names, out_skip_names = self.log_request_metadata
|
||||
if self.model_config.is_multimodal_gen:
|
||||
msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}"
|
||||
else:
|
||||
msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}, out={dataclass_to_string_truncated(out, max_length, skip_names=out_skip_names)}"
|
||||
logger.info(msg)
|
||||
del self.rid_to_state[obj.rid]
|
||||
|
||||
@@ -452,7 +514,10 @@ class TokenizerManager:
|
||||
else:
|
||||
if request is not None and await request.is_disconnected():
|
||||
self.abort_request(obj.rid)
|
||||
raise ValueError(f"Abort request {obj.rid}")
|
||||
raise ValueError(
|
||||
"Request is disconnected from the client side. "
|
||||
f"Abort request {obj.rid}"
|
||||
)
|
||||
|
||||
async def _handle_batch_request(
|
||||
self,
|
||||
@@ -543,12 +608,25 @@ class TokenizerManager:
|
||||
req = AbortReq(rid)
|
||||
self.send_to_scheduler.send_pyobj(req)
|
||||
|
||||
def start_profile(self):
|
||||
req = ProfileReq.START_PROFILE
|
||||
self.send_to_scheduler.send_pyobj(req)
|
||||
async def start_profile(
|
||||
self,
|
||||
output_dir: Optional[str] = None,
|
||||
num_steps: Optional[int] = None,
|
||||
activities: Optional[List[str]] = None,
|
||||
):
|
||||
req = ProfileReq(
|
||||
type=ProfileReqType.START_PROFILE,
|
||||
output_dir=output_dir,
|
||||
num_steps=num_steps,
|
||||
activities=activities,
|
||||
)
|
||||
result = (await self.start_profile_communicator(req))[0]
|
||||
if not result.success:
|
||||
raise RuntimeError(result.message)
|
||||
return result
|
||||
|
||||
def stop_profile(self):
|
||||
req = ProfileReq.STOP_PROFILE
|
||||
req = ProfileReq(type=ProfileReqType.STOP_PROFILE)
|
||||
self.send_to_scheduler.send_pyobj(req)
|
||||
|
||||
async def update_weights_from_disk(
|
||||
@@ -581,7 +659,7 @@ class TokenizerManager:
|
||||
self.server_args.model_path = obj.model_path
|
||||
self.server_args.load_format = obj.load_format
|
||||
self.model_path = obj.model_path
|
||||
return result.success, result.message
|
||||
return result.success, result.message, result.num_paused_requests
|
||||
else: # self.server_args.dp_size > 1
|
||||
self.model_update_tmp = []
|
||||
result = await self.model_update_result
|
||||
@@ -593,7 +671,8 @@ class TokenizerManager:
|
||||
self.model_path = obj.model_path
|
||||
all_message = [r.message for r in result]
|
||||
all_message = " | ".join(all_message)
|
||||
return all_success, all_message
|
||||
all_paused_requests = [r.num_paused_requests for r in result]
|
||||
return all_success, all_message, all_paused_requests
|
||||
|
||||
async def init_weights_update_group(
|
||||
self,
|
||||
@@ -688,6 +767,54 @@ class TokenizerManager:
|
||||
):
|
||||
await self.send_to_scheduler.send_pyobj(obj)
|
||||
|
||||
async def get_internal_state(self) -> Dict[Any, Any]:
|
||||
req = GetInternalStateReq()
|
||||
res: List[GetInternalStateReqOutput] = (
|
||||
await self.get_internal_state_communicator(req)
|
||||
)
|
||||
return res[0].internal_state
|
||||
|
||||
async def set_internal_state(
|
||||
self, obj: SetInternalStateReq
|
||||
) -> SetInternalStateReqOutput:
|
||||
res: List[SetInternalStateReqOutput] = (
|
||||
await self.set_internal_state_communicator(obj)
|
||||
)
|
||||
return res[0]
|
||||
|
||||
def get_log_request_metadata(self):
|
||||
max_length = None
|
||||
skip_names = None
|
||||
out_skip_names = None
|
||||
if self.log_requests:
|
||||
if self.log_requests_level == 0:
|
||||
max_length = 1 << 30
|
||||
skip_names = set(
|
||||
[
|
||||
"text",
|
||||
"input_ids",
|
||||
"input_embeds",
|
||||
"image_data",
|
||||
"audio_data",
|
||||
"lora_path",
|
||||
]
|
||||
)
|
||||
out_skip_names = set(
|
||||
[
|
||||
"text",
|
||||
"output_ids",
|
||||
]
|
||||
)
|
||||
elif self.log_requests_level == 1:
|
||||
max_length = 2048
|
||||
elif self.log_requests_level == 2:
|
||||
max_length = 1 << 30
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid --log-requests-level: {self.log_requests_level=}"
|
||||
)
|
||||
return max_length, skip_names, out_skip_names
|
||||
|
||||
def configure_logging(self, obj: ConfigureLoggingReq):
|
||||
if obj.log_requests is not None:
|
||||
self.log_requests = obj.log_requests
|
||||
@@ -698,6 +825,7 @@ class TokenizerManager:
|
||||
if obj.dump_requests_threshold is not None:
|
||||
self.dump_requests_threshold = obj.dump_requests_threshold
|
||||
logging.info(f"Config logging: {obj=}")
|
||||
self.log_request_metadata = self.get_log_request_metadata()
|
||||
|
||||
def create_abort_task(self, obj: GenerateReqInput):
|
||||
# Abort the request if the client is disconnected.
|
||||
@@ -762,15 +890,20 @@ class TokenizerManager:
|
||||
while True:
|
||||
recv_obj = await self.recv_from_detokenizer.recv_pyobj()
|
||||
self._result_dispatcher(recv_obj)
|
||||
self.last_receive_tstamp = time.time()
|
||||
|
||||
def _handle_batch_output(
|
||||
self, recv_obj: Union[BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut]
|
||||
self,
|
||||
recv_obj: Union[
|
||||
BatchStrOut, BatchEmbeddingOut, BatchMultimodalOut, BatchTokenIDOut
|
||||
],
|
||||
):
|
||||
for i, rid in enumerate(recv_obj.rids):
|
||||
state = self.rid_to_state.get(rid, None)
|
||||
if state is None:
|
||||
continue
|
||||
|
||||
# Build meta_info and return value
|
||||
meta_info = {
|
||||
"id": rid,
|
||||
"finish_reason": recv_obj.finished_reasons[i],
|
||||
@@ -781,14 +914,12 @@ class TokenizerManager:
|
||||
self.convert_logprob_style(
|
||||
meta_info,
|
||||
state.obj.top_logprobs_num,
|
||||
state.obj.token_ids_logprob,
|
||||
state.obj.return_text_in_logprobs,
|
||||
recv_obj,
|
||||
i,
|
||||
)
|
||||
|
||||
if self.server_args.speculative_algorithm:
|
||||
meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[i]
|
||||
|
||||
if not isinstance(recv_obj, BatchEmbeddingOut):
|
||||
meta_info.update(
|
||||
{
|
||||
@@ -806,10 +937,20 @@ class TokenizerManager:
|
||||
"meta_info": meta_info,
|
||||
}
|
||||
elif isinstance(recv_obj, BatchTokenIDOut):
|
||||
if self.server_args.stream_output and state.obj.stream:
|
||||
output_token_ids = recv_obj.output_ids[i][
|
||||
state.last_output_offset :
|
||||
]
|
||||
state.last_output_offset = len(recv_obj.output_ids[i])
|
||||
else:
|
||||
output_token_ids = recv_obj.output_ids[i]
|
||||
|
||||
out_dict = {
|
||||
"token_ids": recv_obj.output_ids[i],
|
||||
"output_ids": output_token_ids,
|
||||
"meta_info": meta_info,
|
||||
}
|
||||
elif isinstance(recv_obj, BatchMultimodalOut):
|
||||
raise NotImplementedError()
|
||||
else:
|
||||
assert isinstance(recv_obj, BatchEmbeddingOut)
|
||||
out_dict = {
|
||||
@@ -817,10 +958,17 @@ class TokenizerManager:
|
||||
"meta_info": meta_info,
|
||||
}
|
||||
|
||||
state.out_list.append(out_dict)
|
||||
state.finished = recv_obj.finished_reasons[i] is not None
|
||||
if state.finished:
|
||||
if self.server_args.speculative_algorithm:
|
||||
meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[i]
|
||||
state.finished_time = time.time()
|
||||
meta_info["e2e_latency"] = state.finished_time - state.created_time
|
||||
|
||||
state.out_list.append(out_dict)
|
||||
state.event.set()
|
||||
|
||||
# Log metrics and dump
|
||||
if self.enable_metrics and state.obj.log_metrics:
|
||||
self.collect_metrics(state, recv_obj, i)
|
||||
if self.dump_requests_folder and state.finished and state.obj.log_metrics:
|
||||
@@ -830,6 +978,7 @@ class TokenizerManager:
|
||||
self,
|
||||
meta_info: dict,
|
||||
top_logprobs_num: int,
|
||||
token_ids_logprob: List[int],
|
||||
return_text_in_logprobs: bool,
|
||||
recv_obj: BatchStrOut,
|
||||
recv_obj_index: int,
|
||||
@@ -857,6 +1006,20 @@ class TokenizerManager:
|
||||
return_text_in_logprobs,
|
||||
)
|
||||
|
||||
if token_ids_logprob is not None:
|
||||
meta_info["input_token_ids_logprobs"] = self.detokenize_top_logprobs_tokens(
|
||||
recv_obj.input_token_ids_logprobs_val[recv_obj_index],
|
||||
recv_obj.input_token_ids_logprobs_idx[recv_obj_index],
|
||||
return_text_in_logprobs,
|
||||
)
|
||||
meta_info["output_token_ids_logprobs"] = (
|
||||
self.detokenize_top_logprobs_tokens(
|
||||
recv_obj.output_token_ids_logprobs_val[recv_obj_index],
|
||||
recv_obj.output_token_ids_logprobs_idx[recv_obj_index],
|
||||
return_text_in_logprobs,
|
||||
)
|
||||
)
|
||||
|
||||
def detokenize_logprob_tokens(
|
||||
self,
|
||||
token_logprobs_val: List[float],
|
||||
@@ -900,34 +1063,30 @@ class TokenizerManager:
|
||||
else 0
|
||||
)
|
||||
|
||||
if state.first_token_time is None:
|
||||
state.first_token_time = time.time()
|
||||
if state.first_token_time == 0.0:
|
||||
state.first_token_time = state.last_time = time.time()
|
||||
state.last_completion_tokens = completion_tokens
|
||||
self.metrics_collector.observe_time_to_first_token(
|
||||
state.first_token_time - state.created_time
|
||||
)
|
||||
else:
|
||||
if completion_tokens >= 2:
|
||||
# Compute time_per_output_token for the streaming case
|
||||
self.metrics_collector.observe_time_per_output_token(
|
||||
(time.time() - state.first_token_time) / (completion_tokens - 1)
|
||||
num_new_tokens = completion_tokens - state.last_completion_tokens
|
||||
if num_new_tokens:
|
||||
new_time = time.time()
|
||||
interval = new_time - state.last_time
|
||||
self.metrics_collector.observe_inter_token_latency(
|
||||
interval,
|
||||
num_new_tokens,
|
||||
)
|
||||
state.last_time = new_time
|
||||
state.last_completion_tokens = completion_tokens
|
||||
|
||||
if state.finished:
|
||||
self.metrics_collector.observe_one_finished_request(
|
||||
recv_obj.prompt_tokens[i], completion_tokens
|
||||
recv_obj.prompt_tokens[i],
|
||||
completion_tokens,
|
||||
state.finished_time - state.created_time,
|
||||
)
|
||||
self.metrics_collector.observe_e2e_request_latency(
|
||||
time.time() - state.created_time
|
||||
)
|
||||
# Compute time_per_output_token for the non-streaming case
|
||||
if (
|
||||
hasattr(state.obj, "stream")
|
||||
and not state.obj.stream
|
||||
and completion_tokens >= 1
|
||||
):
|
||||
self.metrics_collector.observe_time_per_output_token(
|
||||
(time.time() - state.created_time) / completion_tokens
|
||||
)
|
||||
|
||||
def dump_requests(self, state: ReqState, out_dict: dict):
|
||||
self.dump_request_list.append(
|
||||
@@ -996,22 +1155,38 @@ T = TypeVar("T")
|
||||
|
||||
|
||||
class _Communicator(Generic[T]):
|
||||
"""Note: The communicator now only run up to 1 in-flight request at any time."""
|
||||
|
||||
def __init__(self, sender, fan_out: int):
|
||||
self._sender = sender
|
||||
self._fan_out = fan_out
|
||||
self._result_future: Optional[asyncio.Future] = None
|
||||
self._result_event: Optional[asyncio.Event] = None
|
||||
self._result_values: Optional[List[T]] = None
|
||||
self._ready_queue: Deque[asyncio.Future] = deque()
|
||||
|
||||
async def __call__(self, obj):
|
||||
self._sender.send_pyobj(obj)
|
||||
self._result_future = asyncio.Future()
|
||||
ready_event = asyncio.Event()
|
||||
if self._result_event is not None or len(self._ready_queue) > 0:
|
||||
self._ready_queue.append(ready_event)
|
||||
await ready_event.wait()
|
||||
assert self._result_event is None
|
||||
assert self._result_values is None
|
||||
|
||||
if obj:
|
||||
self._sender.send_pyobj(obj)
|
||||
|
||||
self._result_event = asyncio.Event()
|
||||
self._result_values = []
|
||||
await self._result_future
|
||||
await self._result_event.wait()
|
||||
result_values = self._result_values
|
||||
self._result_future = self._result_values = None
|
||||
self._result_event = self._result_values = None
|
||||
|
||||
if len(self._ready_queue) > 0:
|
||||
self._ready_queue.popleft().set()
|
||||
|
||||
return result_values
|
||||
|
||||
def handle_recv(self, recv_obj: T):
|
||||
self._result_values.append(recv_obj)
|
||||
if len(self._result_values) == self._fan_out:
|
||||
self._result_future.set_result(None)
|
||||
self._result_event.set()
|
||||
|
||||
@@ -15,10 +15,13 @@
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from typing import Optional
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.configs.model_config import ModelConfig
|
||||
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||
from sglang.srt.managers.io_struct import (
|
||||
GetWeightsByNameReqInput,
|
||||
InitWeightsUpdateGroupReqInput,
|
||||
@@ -159,7 +162,7 @@ class TpModelWorker:
|
||||
model_worker_batch: ModelWorkerBatch,
|
||||
launch_done: Optional[threading.Event] = None,
|
||||
skip_sample: bool = False,
|
||||
):
|
||||
) -> Tuple[LogitsProcessorOutput, Optional[torch.Tensor]]:
|
||||
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
|
||||
logits_output = self.model_runner.forward(forward_batch)
|
||||
if launch_done:
|
||||
|
||||
@@ -175,7 +175,7 @@ class TpModelWorkerClient:
|
||||
logits_output.next_token_logprobs.tolist()
|
||||
)
|
||||
if logits_output.input_token_logprobs is not None:
|
||||
logits_output.input_token_logprobs = (
|
||||
logits_output.input_token_logprobs = tuple(
|
||||
logits_output.input_token_logprobs.tolist()
|
||||
)
|
||||
next_token_ids = next_token_ids.tolist()
|
||||
@@ -188,8 +188,7 @@ class TpModelWorkerClient:
|
||||
model_worker_batch.sampling_info = self.cur_sampling_info = dataclasses.replace(
|
||||
sampling_info,
|
||||
sampling_info_done=threading.Event(),
|
||||
scaling_penalties=sampling_info.scaling_penalties,
|
||||
linear_penalties=sampling_info.linear_penalties,
|
||||
penalizer_orchestrator=None,
|
||||
)
|
||||
|
||||
# A cuda stream sync here to avoid the cuda illegal memory access error.
|
||||
|
||||
@@ -2,7 +2,9 @@ from __future__ import annotations
|
||||
|
||||
"""Cache for chunked prefill, used when RadixCache is disabled."""
|
||||
|
||||
from typing import TYPE_CHECKING, Callable, List, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
|
||||
@@ -12,7 +14,7 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
class ChunkCacheEntry:
|
||||
def __init__(self, rid, value):
|
||||
def __init__(self, rid: str, value: torch.Tensor):
|
||||
self.rid = rid
|
||||
self.value = value
|
||||
|
||||
@@ -24,6 +26,7 @@ class ChunkCache(BasePrefixCache):
|
||||
self.disable = True
|
||||
self.req_to_token_pool = req_to_token_pool
|
||||
self.token_to_kv_pool = token_to_kv_pool
|
||||
self.entries: Dict[str, ChunkCacheEntry] = {}
|
||||
|
||||
self.reset()
|
||||
|
||||
@@ -53,11 +56,8 @@ class ChunkCache(BasePrefixCache):
|
||||
if req.rid in self.entries:
|
||||
del self.entries[req.rid]
|
||||
|
||||
def cache_unfinished_req(self, req: Req, token_ids: Optional[List[int]] = None):
|
||||
if token_ids is None:
|
||||
token_id_len = len(req.fill_ids)
|
||||
else:
|
||||
token_id_len = len(token_ids)
|
||||
def cache_unfinished_req(self, req: Req):
|
||||
token_id_len = len(req.fill_ids)
|
||||
|
||||
kv_indices = self.req_to_token_pool.req_to_token[
|
||||
req.req_pool_idx, :token_id_len
|
||||
@@ -86,5 +86,8 @@ class ChunkCache(BasePrefixCache):
|
||||
def evictable_size(self):
|
||||
return 0
|
||||
|
||||
def pretty_print(self):
|
||||
return ""
|
||||
|
||||
def protected_size(self):
|
||||
return 0
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
# ==============================================================================
|
||||
"""Utilities for Prometheus Metrics Collection."""
|
||||
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Union
|
||||
|
||||
@@ -35,19 +36,20 @@ class SchedulerMetricsCollector:
|
||||
from prometheus_client import Gauge
|
||||
|
||||
self.labels = labels
|
||||
self.last_log_time = time.time()
|
||||
|
||||
self.num_running_reqs = Gauge(
|
||||
name="sglang:num_running_reqs",
|
||||
documentation="The number of running requests.",
|
||||
labelnames=labels.keys(),
|
||||
multiprocess_mode="sum",
|
||||
multiprocess_mode="mostrecent",
|
||||
)
|
||||
|
||||
self.num_used_tokens = Gauge(
|
||||
name="sglang:num_used_tokens",
|
||||
documentation="The number of used tokens.",
|
||||
labelnames=labels.keys(),
|
||||
multiprocess_mode="sum",
|
||||
multiprocess_mode="mostrecent",
|
||||
)
|
||||
|
||||
self.token_usage = Gauge(
|
||||
@@ -61,14 +63,14 @@ class SchedulerMetricsCollector:
|
||||
name="sglang:gen_throughput",
|
||||
documentation="The generation throughput (token/s).",
|
||||
labelnames=labels.keys(),
|
||||
multiprocess_mode="sum",
|
||||
multiprocess_mode="mostrecent",
|
||||
)
|
||||
|
||||
self.num_queue_reqs = Gauge(
|
||||
name="sglang:num_queue_reqs",
|
||||
documentation="The number of requests in the waiting queue.",
|
||||
labelnames=labels.keys(),
|
||||
multiprocess_mode="sum",
|
||||
multiprocess_mode="mostrecent",
|
||||
)
|
||||
|
||||
self.cache_hit_rate = Gauge(
|
||||
@@ -97,6 +99,7 @@ class SchedulerMetricsCollector:
|
||||
self._log_gauge(self.num_queue_reqs, stats.num_queue_reqs)
|
||||
self._log_gauge(self.cache_hit_rate, stats.cache_hit_rate)
|
||||
self._log_gauge(self.spec_accept_length, stats.spec_accept_length)
|
||||
self.last_log_time = time.time()
|
||||
|
||||
|
||||
class TokenizerMetricsCollector:
|
||||
@@ -130,12 +133,15 @@ class TokenizerMetricsCollector:
|
||||
labelnames=labels.keys(),
|
||||
buckets=[
|
||||
0.1,
|
||||
0.25,
|
||||
0.3,
|
||||
0.5,
|
||||
0.75,
|
||||
0.7,
|
||||
0.9,
|
||||
1,
|
||||
2,
|
||||
5,
|
||||
4,
|
||||
6,
|
||||
8,
|
||||
10,
|
||||
20,
|
||||
40,
|
||||
@@ -151,24 +157,56 @@ class TokenizerMetricsCollector:
|
||||
documentation="Histogram of time per output token in seconds.",
|
||||
labelnames=labels.keys(),
|
||||
buckets=[
|
||||
0.002,
|
||||
0.005,
|
||||
0.01,
|
||||
0.010,
|
||||
0.020,
|
||||
0.030,
|
||||
0.040,
|
||||
0.050,
|
||||
0.060,
|
||||
0.070,
|
||||
0.080,
|
||||
0.090,
|
||||
0.100,
|
||||
0.150,
|
||||
0.200,
|
||||
0.300,
|
||||
0.400,
|
||||
0.600,
|
||||
0.800,
|
||||
1.000,
|
||||
2.000,
|
||||
],
|
||||
)
|
||||
|
||||
self.histogram_inter_token_latency_seconds = Histogram(
|
||||
name="sglang:inter_token_latency_seconds",
|
||||
documentation="Histogram of inter-token latency in seconds.",
|
||||
labelnames=labels.keys(),
|
||||
buckets=[
|
||||
0.002,
|
||||
0.004,
|
||||
0.006,
|
||||
0.008,
|
||||
0.010,
|
||||
0.015,
|
||||
0.02,
|
||||
0.020,
|
||||
0.025,
|
||||
0.03,
|
||||
0.04,
|
||||
0.05,
|
||||
0.030,
|
||||
0.035,
|
||||
0.040,
|
||||
0.050,
|
||||
0.075,
|
||||
0.1,
|
||||
0.15,
|
||||
0.2,
|
||||
0.3,
|
||||
0.4,
|
||||
0.5,
|
||||
0.75,
|
||||
1.0,
|
||||
2.5,
|
||||
0.100,
|
||||
0.150,
|
||||
0.200,
|
||||
0.300,
|
||||
0.400,
|
||||
0.500,
|
||||
0.750,
|
||||
1.000,
|
||||
2.000,
|
||||
],
|
||||
)
|
||||
|
||||
@@ -178,8 +216,9 @@ class TokenizerMetricsCollector:
|
||||
labelnames=labels.keys(),
|
||||
buckets=[
|
||||
0.1,
|
||||
0.25,
|
||||
0.5,
|
||||
0.2,
|
||||
0.4,
|
||||
0.8,
|
||||
1,
|
||||
2,
|
||||
5,
|
||||
@@ -188,28 +227,161 @@ class TokenizerMetricsCollector:
|
||||
40,
|
||||
60,
|
||||
80,
|
||||
100,
|
||||
150,
|
||||
200,
|
||||
250,
|
||||
300,
|
||||
350,
|
||||
500,
|
||||
1000,
|
||||
],
|
||||
)
|
||||
|
||||
self.histogram_prefill_prealloc_duration = Histogram(
|
||||
name="sglang:prefill_prealloc_duration_seconds",
|
||||
documentation="Histogram of prefill prealloc duration in seconds.",
|
||||
labelnames=labels.keys(),
|
||||
buckets=[
|
||||
0.1,
|
||||
0.3,
|
||||
0.5,
|
||||
0.7,
|
||||
0.9,
|
||||
1,
|
||||
2,
|
||||
4,
|
||||
6,
|
||||
8,
|
||||
10,
|
||||
20,
|
||||
40,
|
||||
60,
|
||||
80,
|
||||
120,
|
||||
160,
|
||||
],
|
||||
)
|
||||
|
||||
self.histogram_prefill_queue_duration = Histogram(
|
||||
name="sglang:prefill_queue_duration_seconds",
|
||||
documentation="Histogram of prefill queue duration in seconds.",
|
||||
labelnames=labels.keys(),
|
||||
buckets=[
|
||||
0.1,
|
||||
0.3,
|
||||
0.5,
|
||||
0.7,
|
||||
0.9,
|
||||
2,
|
||||
4,
|
||||
8,
|
||||
16,
|
||||
64,
|
||||
],
|
||||
)
|
||||
|
||||
self.histogram_prefill_forward_duration = Histogram(
|
||||
name="sglang:prefill_forward_duration_seconds",
|
||||
documentation="Histogram of prefill forward duration in seconds.",
|
||||
labelnames=labels.keys(),
|
||||
buckets=[
|
||||
0.1,
|
||||
0.3,
|
||||
0.5,
|
||||
0.7,
|
||||
0.9,
|
||||
2,
|
||||
4,
|
||||
8,
|
||||
16,
|
||||
64,
|
||||
],
|
||||
)
|
||||
|
||||
self.histogram_prefill_transfer_duration = Histogram(
|
||||
name="sglang:prefill_transfer_duration_seconds",
|
||||
documentation="Histogram of prefill transfer duration in seconds.",
|
||||
labelnames=labels.keys(),
|
||||
buckets=[
|
||||
0.050,
|
||||
0.100,
|
||||
0.150,
|
||||
0.200,
|
||||
0.300,
|
||||
0.400,
|
||||
0.500,
|
||||
1.000,
|
||||
2.000,
|
||||
],
|
||||
)
|
||||
|
||||
self.histogram_decode_prealloc_duration = Histogram(
|
||||
name="sglang:decode_prealloc_duration_seconds",
|
||||
documentation="Histogram of decode prealloc duration in seconds.",
|
||||
labelnames=labels.keys(),
|
||||
buckets=[
|
||||
0.1,
|
||||
0.3,
|
||||
0.5,
|
||||
0.7,
|
||||
0.9,
|
||||
2,
|
||||
4,
|
||||
8,
|
||||
16,
|
||||
64,
|
||||
],
|
||||
)
|
||||
|
||||
self.histogram_decode_queue_duration = Histogram(
|
||||
name="sglang:decode_queue_duration_seconds",
|
||||
documentation="Histogram of decode queue duration in seconds.",
|
||||
labelnames=labels.keys(),
|
||||
buckets=[
|
||||
0.1,
|
||||
0.3,
|
||||
0.5,
|
||||
0.7,
|
||||
0.9,
|
||||
2,
|
||||
4,
|
||||
8,
|
||||
16,
|
||||
64,
|
||||
],
|
||||
)
|
||||
|
||||
def _log_histogram(self, histogram, data: Union[int, float]) -> None:
|
||||
histogram.labels(**self.labels).observe(data)
|
||||
|
||||
def _log_counter(self, counter, data: Union[int, float]) -> None:
|
||||
# Convenience function for logging to counter.
|
||||
counter.labels(**self.labels).inc(data)
|
||||
|
||||
def observe_one_finished_request(self, prompt_tokens: int, generation_tokens: int):
|
||||
def observe_one_finished_request(
|
||||
self,
|
||||
prompt_tokens: int,
|
||||
generation_tokens: int,
|
||||
e2e_latency: float,
|
||||
):
|
||||
self.prompt_tokens_total.labels(**self.labels).inc(prompt_tokens)
|
||||
self.generation_tokens_total.labels(**self.labels).inc(generation_tokens)
|
||||
self.num_requests_total.labels(**self.labels).inc(1)
|
||||
self._log_histogram(self.histogram_e2e_request_latency, e2e_latency)
|
||||
if generation_tokens >= 1:
|
||||
self.histogram_time_per_output_token.labels(**self.labels).observe(
|
||||
e2e_latency / generation_tokens
|
||||
)
|
||||
|
||||
def observe_time_to_first_token(self, value: Union[float, int]):
|
||||
self._log_histogram(self.histogram_time_to_first_token, value)
|
||||
def observe_time_to_first_token(self, value: float):
|
||||
self.histogram_time_to_first_token.labels(**self.labels).observe(value)
|
||||
|
||||
def observe_time_per_output_token(self, value: Union[float, int]):
|
||||
self._log_histogram(self.histogram_time_per_output_token, value)
|
||||
def observe_inter_token_latency(self, internval: float, num_new_tokens: int):
|
||||
adjusted_interval = internval / num_new_tokens
|
||||
|
||||
def observe_e2e_request_latency(self, value: Union[float, int]):
|
||||
self._log_histogram(self.histogram_e2e_request_latency, value)
|
||||
# A faster version of the Histogram::observe which observes multiple values at the same time.
|
||||
# reference: https://github.com/prometheus/client_python/blob/v0.21.1/prometheus_client/metrics.py#L639
|
||||
his = self.histogram_inter_token_latency_seconds.labels(**self.labels)
|
||||
his._sum.inc(internval)
|
||||
|
||||
for i, bound in enumerate(his._upper_bounds):
|
||||
if adjusted_interval <= bound:
|
||||
his._buckets[i].inc(num_new_tokens)
|
||||
break
|
||||
|
||||
@@ -109,11 +109,15 @@ def set_torch_compile_config():
|
||||
def get_batch_sizes_to_capture(model_runner: ModelRunner):
|
||||
server_args = model_runner.server_args
|
||||
capture_bs = server_args.cuda_graph_bs
|
||||
|
||||
if capture_bs is None:
|
||||
if server_args.disable_cuda_graph_padding:
|
||||
capture_bs = list(range(1, 33)) + [64, 128]
|
||||
if server_args.speculative_algorithm is None:
|
||||
if server_args.disable_cuda_graph_padding:
|
||||
capture_bs = list(range(1, 33)) + [64, 96, 128, 160]
|
||||
else:
|
||||
capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
|
||||
else:
|
||||
capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
|
||||
capture_bs = list(range(1, 33))
|
||||
|
||||
if is_hip_:
|
||||
capture_bs += [i * 8 for i in range(21, 33)]
|
||||
@@ -130,6 +134,7 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
capture_bs = [
|
||||
bs
|
||||
for bs in capture_bs
|
||||
@@ -385,9 +390,6 @@ class CudaGraphRunner:
|
||||
|
||||
run_once()
|
||||
|
||||
torch.cuda.synchronize()
|
||||
self.model_runner.tp_group.barrier()
|
||||
|
||||
torch.cuda.synchronize()
|
||||
self.model_runner.tp_group.barrier()
|
||||
|
||||
@@ -401,12 +403,11 @@ class CudaGraphRunner:
|
||||
global_graph_memory_pool = graph.pool()
|
||||
return graph, out
|
||||
|
||||
def replay(self, forward_batch: ForwardBatch):
|
||||
assert forward_batch.out_cache_loc is not None
|
||||
def recapture_if_needed(self, forward_batch: ForwardBatch):
|
||||
# If the capture_hidden_mode changes, we need to recapture the graph
|
||||
hidden_mode_from_spec_info = getattr(
|
||||
forward_batch.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL
|
||||
)
|
||||
# If the capture_hidden_mode changes, we need to recapture the graph
|
||||
if (
|
||||
forward_batch.capture_hidden_mode == CaptureHiddenMode.FULL
|
||||
and self.capture_hidden_mode != CaptureHiddenMode.FULL
|
||||
@@ -420,6 +421,9 @@ class CudaGraphRunner:
|
||||
self.capture_hidden_mode = hidden_mode_from_spec_info
|
||||
self.capture()
|
||||
|
||||
def replay(self, forward_batch: ForwardBatch):
|
||||
self.recapture_if_needed(forward_batch)
|
||||
|
||||
raw_bs = forward_batch.batch_size
|
||||
raw_num_token = raw_bs * self.num_tokens_per_bs
|
||||
|
||||
|
||||
@@ -31,7 +31,7 @@ from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import IntEnum, auto
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
from typing import TYPE_CHECKING, List, Optional, Union
|
||||
|
||||
import torch
|
||||
import triton
|
||||
@@ -46,7 +46,8 @@ if TYPE_CHECKING:
|
||||
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
|
||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
||||
from sglang.srt.speculative.spec_info import SpecInfo, SpeculativeAlgorithm
|
||||
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
||||
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
||||
|
||||
|
||||
class ForwardMode(IntEnum):
|
||||
@@ -112,7 +113,9 @@ class ForwardMode(IntEnum):
|
||||
|
||||
class CaptureHiddenMode(IntEnum):
|
||||
NULL = auto()
|
||||
# Capture hidden states of all tokens.
|
||||
FULL = auto()
|
||||
# Capture a hidden state of the last token.
|
||||
LAST = auto()
|
||||
|
||||
def need_capture(self):
|
||||
@@ -148,6 +151,7 @@ class ForwardBatch:
|
||||
# For logprob
|
||||
return_logprob: bool = False
|
||||
top_logprobs_nums: Optional[List[int]] = None
|
||||
token_ids_logprobs: Optional[List[List[int]]] = None
|
||||
|
||||
# Position information
|
||||
positions: torch.Tensor = None
|
||||
@@ -160,6 +164,7 @@ class ForwardBatch:
|
||||
extend_prefix_lens_cpu: Optional[List[int]] = None
|
||||
extend_seq_lens_cpu: Optional[List[int]] = None
|
||||
extend_logprob_start_lens_cpu: Optional[List[int]] = None
|
||||
extend_input_logprob_token_ids_gpu: Optional[torch.Tensor] = None
|
||||
|
||||
# For multimodal
|
||||
image_inputs: Optional[List[ImageInputs]] = None
|
||||
@@ -190,10 +195,13 @@ class ForwardBatch:
|
||||
can_run_dp_cuda_graph: bool = False
|
||||
|
||||
# Speculative decoding
|
||||
spec_info: SpecInfo = None
|
||||
spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
|
||||
spec_algorithm: SpeculativeAlgorithm = None
|
||||
capture_hidden_mode: CaptureHiddenMode = None
|
||||
|
||||
# For padding
|
||||
padded_static_len: int = -1 # -1 if not padded
|
||||
|
||||
# For Qwen2-VL
|
||||
mrope_positions: torch.Tensor = None
|
||||
|
||||
@@ -203,8 +211,13 @@ class ForwardBatch:
|
||||
batch: ModelWorkerBatch,
|
||||
model_runner: ModelRunner,
|
||||
):
|
||||
|
||||
device = model_runner.device
|
||||
extend_input_logprob_token_ids_gpu = None
|
||||
if batch.extend_input_logprob_token_ids is not None:
|
||||
extend_input_logprob_token_ids_gpu = (
|
||||
batch.extend_input_logprob_token_ids.to(device, non_blocking=True)
|
||||
)
|
||||
|
||||
ret = cls(
|
||||
forward_mode=batch.forward_mode,
|
||||
batch_size=len(batch.seq_lens),
|
||||
@@ -220,6 +233,7 @@ class ForwardBatch:
|
||||
seq_lens_sum=batch.seq_lens_sum,
|
||||
return_logprob=batch.return_logprob,
|
||||
top_logprobs_nums=batch.top_logprobs_nums,
|
||||
token_ids_logprobs=batch.token_ids_logprobs,
|
||||
global_num_tokens=batch.global_num_tokens,
|
||||
can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
|
||||
lora_paths=batch.lora_paths,
|
||||
@@ -231,6 +245,7 @@ class ForwardBatch:
|
||||
spec_info=batch.spec_info,
|
||||
capture_hidden_mode=batch.capture_hidden_mode,
|
||||
input_embeds=batch.input_embeds,
|
||||
extend_input_logprob_token_ids_gpu=extend_input_logprob_token_ids_gpu,
|
||||
)
|
||||
|
||||
if ret.global_num_tokens is not None:
|
||||
@@ -341,6 +356,7 @@ class ForwardBatch:
|
||||
)
|
||||
batch.image_inputs[i].mrope_position_delta = mrope_position_delta
|
||||
mrope_positions_list[i] = mrope_positions
|
||||
|
||||
self.mrope_positions = torch.concat(
|
||||
[torch.tensor(pos, device=device) for pos in mrope_positions_list],
|
||||
axis=1,
|
||||
@@ -379,7 +395,7 @@ def compute_position_kernel(
|
||||
extend_seq_lens,
|
||||
):
|
||||
BLOCK_SIZE: tl.constexpr = 512
|
||||
pid = tl.program_id(0)
|
||||
pid = tl.program_id(0).to(tl.int64)
|
||||
|
||||
prefix_len = tl.load(extend_prefix_lens + pid)
|
||||
seq_len = tl.load(extend_seq_lens + pid)
|
||||
|
||||
@@ -13,9 +13,12 @@
|
||||
# ==============================================================================
|
||||
"""ModelRunner runs the forward passes of the models."""
|
||||
|
||||
import collections
|
||||
import datetime
|
||||
import gc
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple, Union
|
||||
@@ -58,6 +61,7 @@ from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader import get_model
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
||||
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
||||
@@ -73,10 +77,15 @@ from sglang.srt.utils import (
|
||||
set_cpu_offload_max_bytes,
|
||||
set_cuda_arch,
|
||||
)
|
||||
from sglang.utils import get_exception_traceback
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)
|
||||
UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300
|
||||
|
||||
|
||||
class ModelRunner:
|
||||
"""ModelRunner runs the forward passes of the models."""
|
||||
|
||||
@@ -180,9 +189,13 @@ class ModelRunner:
|
||||
"enable_dp_attention": server_args.enable_dp_attention,
|
||||
"enable_ep_moe": server_args.enable_ep_moe,
|
||||
"device": server_args.device,
|
||||
"speculative_accept_threshold_single": server_args.speculative_accept_threshold_single,
|
||||
"speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
|
||||
"enable_flashinfer_mla": server_args.enable_flashinfer_mla,
|
||||
"disable_radix_cache": server_args.disable_radix_cache,
|
||||
"flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
|
||||
"debug_tensor_dump_output_folder": server_args.debug_tensor_dump_output_folder,
|
||||
"debug_tensor_dump_inject": server_args.debug_tensor_dump_inject,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -199,6 +212,18 @@ class ModelRunner:
|
||||
self.sampler = Sampler()
|
||||
self.load_model()
|
||||
|
||||
# Handle the case where some of models don't finish loading.
|
||||
try:
|
||||
dist.monitored_barrier(
|
||||
group=get_tp_group().cpu_group,
|
||||
timeout=datetime.timedelta(seconds=UNBALANCED_MODEL_LOADING_TIMEOUT_S),
|
||||
wait_all_ranks=True,
|
||||
)
|
||||
except RuntimeError:
|
||||
raise ValueError(
|
||||
f"TP rank {self.tp_rank} could finish the model loading, but there are other ranks that didn't finish loading. It is likely due to unexpected failures (e.g., OOM) or a slow node."
|
||||
) from None
|
||||
|
||||
# Apply torchao quantization
|
||||
torchao_applied = getattr(self.model, "torchao_applied", False)
|
||||
# In layered loading, torchao may have been applied
|
||||
@@ -625,6 +650,9 @@ class ModelRunner:
|
||||
4096,
|
||||
)
|
||||
|
||||
if SGLANG_CI_SMALL_KV_SIZE:
|
||||
self.max_total_num_tokens = int(SGLANG_CI_SMALL_KV_SIZE)
|
||||
|
||||
if not self.spec_algorithm.is_none():
|
||||
if self.is_draft_worker:
|
||||
self.max_total_num_tokens = self.server_args.draft_runner_cache_size
|
||||
@@ -655,6 +683,7 @@ class ModelRunner:
|
||||
device=self.device,
|
||||
enable_memory_saver=self.server_args.enable_memory_saver,
|
||||
)
|
||||
|
||||
if (
|
||||
self.model_config.attention_arch == AttentionArch.MLA
|
||||
and not self.server_args.disable_mla
|
||||
@@ -758,9 +787,13 @@ class ModelRunner:
|
||||
return
|
||||
|
||||
tic = time.time()
|
||||
logger.info("Capture cuda graph begin. This can take up to several minutes.")
|
||||
logger.info(
|
||||
f"Capture cuda graph begin. This can take up to several minutes. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
|
||||
)
|
||||
self.cuda_graph_runner = CudaGraphRunner(self)
|
||||
logger.info(f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s")
|
||||
logger.info(
|
||||
f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
|
||||
)
|
||||
|
||||
def apply_torch_tp(self):
|
||||
logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
|
||||
@@ -820,11 +853,10 @@ class ModelRunner:
|
||||
else:
|
||||
raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")
|
||||
|
||||
def sample(
|
||||
self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
|
||||
) -> torch.Tensor:
|
||||
def _preprocess_logits(
|
||||
self, logits_output: LogitsProcessorOutput, sampling_info: SamplingBatchInfo
|
||||
):
|
||||
# Apply logit bias
|
||||
sampling_info = forward_batch.sampling_info
|
||||
if sampling_info.sampling_info_done:
|
||||
# Overlap mode: the function update_regex_vocab_mask was executed
|
||||
# in process_batch_result of the last batch.
|
||||
@@ -833,15 +865,77 @@ class ModelRunner:
|
||||
else:
|
||||
# Normal mode: Put CPU-heavy tasks here. They will be overlapped with the forward pass.
|
||||
sampling_info.update_regex_vocab_mask()
|
||||
sampling_info.update_penalties()
|
||||
sampling_info.apply_logits_bias(logits_output.next_token_logits)
|
||||
|
||||
def update_output_logprobs(
|
||||
self,
|
||||
logits_output: LogitsProcessorOutput,
|
||||
sampling_info: SamplingBatchInfo,
|
||||
top_logprobs_nums: List[int],
|
||||
token_ids_logprobs: List[int],
|
||||
next_token_ids: torch.Tensor,
|
||||
*,
|
||||
num_tokens_per_req: List[int],
|
||||
):
|
||||
"""Update the logits_output's output logprob based on next_token_ids
|
||||
|
||||
Args:
|
||||
logits_output: The logits output from the model forward
|
||||
sampling_info: Sampling info for logprob calculation
|
||||
top_logprobs_nums: Number of logprobs per request.
|
||||
next_token_ids: Next token ids.
|
||||
num_tokens_per_req: The number of tokens per request.
|
||||
|
||||
Returns:
|
||||
A list of next_token_ids
|
||||
"""
|
||||
self._preprocess_logits(logits_output, sampling_info)
|
||||
# We should repeat top_logprobs_nums to match num_tokens_per_req.
|
||||
top_logprobs_nums_repeat_interleaved = []
|
||||
token_ids_logprobs_repeat_interleaved = []
|
||||
for num, num_tokens in zip(top_logprobs_nums, num_tokens_per_req):
|
||||
top_logprobs_nums_repeat_interleaved.extend([num] * num_tokens)
|
||||
for token_ids, num_tokens in zip(token_ids_logprobs, num_tokens_per_req):
|
||||
token_ids_logprobs_repeat_interleaved.extend([token_ids] * num_tokens)
|
||||
self.sampler(
|
||||
logits_output,
|
||||
sampling_info,
|
||||
True,
|
||||
top_logprobs_nums_repeat_interleaved,
|
||||
token_ids_logprobs_repeat_interleaved,
|
||||
batch_next_token_ids=next_token_ids,
|
||||
)
|
||||
|
||||
def sample(
|
||||
self,
|
||||
logits_output: LogitsProcessorOutput,
|
||||
forward_batch: ForwardBatch,
|
||||
) -> torch.Tensor:
|
||||
"""Sample and compute logprobs and update logits_output.
|
||||
|
||||
Args:
|
||||
logits_output: The logits output from the model forward
|
||||
forward_batch: The forward batch that generates logits_output
|
||||
|
||||
Returns:
|
||||
A list of next_token_ids
|
||||
"""
|
||||
# For duplex models with multiple output streams.
|
||||
if isinstance(logits_output, tuple):
|
||||
return torch.stack(
|
||||
[self.sample(values, forward_batch) for values in logits_output],
|
||||
axis=-1,
|
||||
)
|
||||
|
||||
self._preprocess_logits(logits_output, forward_batch.sampling_info)
|
||||
|
||||
# Sample the next tokens
|
||||
next_token_ids = self.sampler(
|
||||
logits_output,
|
||||
sampling_info,
|
||||
forward_batch.sampling_info,
|
||||
forward_batch.return_logprob,
|
||||
forward_batch.top_logprobs_nums,
|
||||
forward_batch.token_ids_logprobs,
|
||||
)
|
||||
return next_token_ids
|
||||
|
||||
|
||||
@@ -25,10 +25,10 @@ import filelock
|
||||
import gguf
|
||||
import huggingface_hub.constants
|
||||
import numpy as np
|
||||
import safetensors.torch
|
||||
import torch
|
||||
from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download
|
||||
from pydantic import BaseModel, ConfigDict, ValidationInfo, model_validator
|
||||
from safetensors.torch import load_file, safe_open, save_file
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from sglang.srt.configs.load_config import LoadConfig
|
||||
@@ -62,7 +62,6 @@ enable_hf_transfer()
|
||||
|
||||
|
||||
class DisabledTqdm(tqdm):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs, disable=True)
|
||||
|
||||
@@ -121,7 +120,7 @@ def convert_bin_to_safetensor_file(
|
||||
)
|
||||
|
||||
# check if the tensors are the same
|
||||
reloaded = load_file(sf_filename)
|
||||
reloaded = safetensors.torch.load_file(sf_filename)
|
||||
for k in loaded:
|
||||
pt_tensor = loaded[k]
|
||||
sf_tensor = reloaded[k]
|
||||
@@ -133,7 +132,6 @@ def convert_bin_to_safetensor_file(
|
||||
def get_quant_config(
|
||||
model_config: ModelConfig, load_config: LoadConfig
|
||||
) -> QuantizationConfig:
|
||||
|
||||
quant_cls = get_quantization_config(model_config.quantization)
|
||||
|
||||
# GGUF doesn't have config file
|
||||
@@ -402,15 +400,34 @@ def np_cache_weights_iterator(
|
||||
yield name, torch.from_numpy(param)
|
||||
|
||||
|
||||
def decrypt(fn, key):
    """Placeholder for decrypting an encrypted weight file.

    Args:
        fn: The encrypted input (presumably a file name or path); unused by
            this stub.
        key: The decryption key; unused by this stub.

    Raises:
        NotImplementedError: Always. No decryption implementation is provided
            here.
    """
    raise NotImplementedError("Decrypting weight files is not implemented.")
|
||||
|
||||
|
||||
def safetensors_encrypted_weights_iterator(
    hf_weights_files: List[str],
    is_all_weights_sharded: bool = False,
    decryption_key: Optional[str] = None,
):
    """Placeholder for iterating over weights in encrypted safetensors files.

    Args:
        hf_weights_files: The safetensors files that would be read.
        is_all_weights_sharded: Unused by this stub.
        decryption_key: Key that would be used to decrypt the files.

    Raises:
        NotImplementedError: Always. Encrypted loading is not implemented.
    """
    # Intentionally a plain function rather than a generator (no `yield`), so
    # the error is raised at call time instead of on first iteration.
    raise NotImplementedError(
        "Loading encrypted safetensors weights is not implemented."
    )
|
||||
|
||||
|
||||
def safetensors_weights_iterator(
|
||||
hf_weights_files: List[str],
|
||||
is_all_weights_sharded: bool = False,
|
||||
decryption_key: Optional[str] = None,
|
||||
) -> Generator[Tuple[str, torch.Tensor], None, None]:
|
||||
"""Iterate over the weights in the model safetensor files.
|
||||
|
||||
If is_all_weights_sharded is True, it uses more optimize read by reading an
|
||||
entire file instead of reading each tensor one by one.
|
||||
"""
|
||||
if decryption_key:
|
||||
yield from safetensors_encrypted_weights_iterator(
|
||||
hf_weights_files, is_all_weights_sharded, decryption_key
|
||||
)
|
||||
return
|
||||
|
||||
enable_tqdm = (
|
||||
not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
|
||||
)
|
||||
@@ -420,15 +437,9 @@ def safetensors_weights_iterator(
|
||||
disable=not enable_tqdm,
|
||||
bar_format=_BAR_FORMAT,
|
||||
):
|
||||
if not is_all_weights_sharded:
|
||||
with safe_open(st_file, framework="pt") as f:
|
||||
for name in f.keys(): # noqa: SIM118
|
||||
param = f.get_tensor(name)
|
||||
yield name, param
|
||||
else:
|
||||
result = load_file(st_file, device="cpu")
|
||||
for name, param in result.items():
|
||||
yield name, param
|
||||
result = safetensors.torch.load_file(st_file, device="cpu")
|
||||
for name, param in result.items():
|
||||
yield name, param
|
||||
|
||||
|
||||
def pt_weights_iterator(
|
||||
|
||||
@@ -1,13 +1,11 @@
|
||||
from .orchestrator import BatchedPenalizerOrchestrator
|
||||
from .penalizers.frequency_penalty import BatchedFrequencyPenalizer
|
||||
from .penalizers.min_new_tokens import BatchedMinNewTokensPenalizer
|
||||
from .penalizers.presence_penalty import BatchedPresencePenalizer
|
||||
from .penalizers.repetition_penalty import BatchedRepetitionPenalizer
|
||||
from sglang.srt.sampling.penaltylib.frequency_penalty import BatchedFrequencyPenalizer
|
||||
from sglang.srt.sampling.penaltylib.min_new_tokens import BatchedMinNewTokensPenalizer
|
||||
from sglang.srt.sampling.penaltylib.orchestrator import BatchedPenalizerOrchestrator
|
||||
from sglang.srt.sampling.penaltylib.presence_penalty import BatchedPresencePenalizer
|
||||
|
||||
__all__ = [
|
||||
"BatchedFrequencyPenalizer",
|
||||
"BatchedMinNewTokensPenalizer",
|
||||
"BatchedPresencePenalizer",
|
||||
"BatchedRepetitionPenalizer",
|
||||
"BatchedPenalizerOrchestrator",
|
||||
]
|
||||
|
||||
66
python/sglang/srt/sampling/penaltylib/frequency_penalty.py
Normal file
66
python/sglang/srt/sampling/penaltylib/frequency_penalty.py
Normal file
@@ -0,0 +1,66 @@
|
||||
import torch
|
||||
|
||||
from sglang.srt.sampling.penaltylib.orchestrator import (
|
||||
BatchedPenalizerOrchestrator,
|
||||
_BatchedPenalizer,
|
||||
)
|
||||
|
||||
|
||||
class BatchedFrequencyPenalizer(_BatchedPenalizer):
    """
    Frequency penalizer penalizes tokens based on their frequency in the output.

    State (created in ``_prepare``):
        frequency_penalties: float32 tensor of shape (num_reqs, 1) holding
            each request's ``sampling_params.frequency_penalty``.
        cumulated_frequency_penalties: float32 tensor of shape
            (num_reqs, vocab_size); accumulated penalty per request and token
            id, subtracted from the logits in ``_apply``.
    """

    def __init__(self, orchestrator: BatchedPenalizerOrchestrator):
        self.orchestrator = orchestrator
        self._is_prepared = False

    def _is_required(self) -> bool:
        """Return True if any request has a non-zero frequency penalty."""
        return any(
            req.sampling_params.frequency_penalty != 0.0
            for req in self.orchestrator.reqs()
        )

    def _prepare(self):
        """Allocate the per-request penalty tensors on the orchestrator device."""
        self.cumulated_frequency_penalties = torch.zeros(
            (len(self.orchestrator.reqs()), self.orchestrator.vocab_size),
            dtype=torch.float32,
            device=self.orchestrator.device,
        )

        # Shape (num_reqs, 1) so it can be used directly as the scatter source.
        self.frequency_penalties = (
            torch.tensor(
                data=[
                    req.sampling_params.frequency_penalty
                    for req in self.orchestrator.reqs()
                ],
                dtype=torch.float32,
                device=self.orchestrator.device,
            )
        ).unsqueeze_(1)

    def _cumulate_output_tokens(self, output_ids: torch.Tensor):
        """Add each request's penalty at the column of its newly emitted token.

        Args:
            output_ids: Presumably a 1-D tensor with one token id per request
                (it is unsqueezed to (num_reqs, 1) and used as a scatter
                index) -- TODO confirm against the caller.
        """
        self.cumulated_frequency_penalties.scatter_add_(
            dim=1,
            index=output_ids.unsqueeze(1),
            src=self.frequency_penalties,
        )

    def _apply(self, logits: torch.Tensor) -> None:
        """Subtract the accumulated penalties from ``logits`` in place."""
        logits.sub_(self.cumulated_frequency_penalties)

    def _filter(self, keep_indices: torch.Tensor):
        """Keep only the rows (requests) selected by ``keep_indices``."""
        self.frequency_penalties = self.frequency_penalties[keep_indices]
        self.cumulated_frequency_penalties = self.cumulated_frequency_penalties[
            keep_indices
        ]

    def _merge(self, their: "BatchedFrequencyPenalizer"):
        """Append the rows of ``their`` penalizer state after this one's."""
        self.frequency_penalties = torch.cat(
            [self.frequency_penalties, their.frequency_penalties], dim=0
        )
        self.cumulated_frequency_penalties = torch.cat(
            [self.cumulated_frequency_penalties, their.cumulated_frequency_penalties],
            dim=0,
        )
|
||||
@@ -1,8 +1,9 @@
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs
|
||||
from sglang.srt.sampling.penaltylib.orchestrator import (
|
||||
BatchedPenalizerOrchestrator,
|
||||
_BatchedPenalizer,
|
||||
)
|
||||
|
||||
|
||||
class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
|
||||
@@ -10,9 +11,9 @@ class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
|
||||
Min new tokens penalizer penalizes tokens based on the length of the output.
|
||||
"""
|
||||
|
||||
min_new_tokens: torch.Tensor = None
|
||||
stop_token_penalties: torch.Tensor = None
|
||||
len_output_tokens: torch.Tensor = None
|
||||
def __init__(self, orchestrator: BatchedPenalizerOrchestrator):
|
||||
self.orchestrator = orchestrator
|
||||
self._is_prepared = False
|
||||
|
||||
def _is_required(self) -> bool:
|
||||
return any(
|
||||
@@ -47,7 +48,7 @@ class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
|
||||
padding_value=self.orchestrator.vocab_size,
|
||||
)
|
||||
self.stop_token_penalties = torch.zeros(
|
||||
size=(self.orchestrator.batch_size(), self.orchestrator.vocab_size + 1),
|
||||
size=(len(self.orchestrator.reqs()), self.orchestrator.vocab_size + 1),
|
||||
dtype=torch.float32,
|
||||
device=self.orchestrator.device,
|
||||
).scatter_add_(
|
||||
@@ -64,31 +65,22 @@ class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
|
||||
]
|
||||
|
||||
self.len_output_tokens = torch.zeros(
|
||||
size=(self.orchestrator.batch_size(), 1),
|
||||
size=(len(self.orchestrator.reqs()), 1),
|
||||
dtype=torch.int32,
|
||||
device=self.orchestrator.device,
|
||||
)
|
||||
|
||||
def _teardown(self):
|
||||
self.min_new_tokens = None
|
||||
self.stop_token_penalties = None
|
||||
self.len_output_tokens = None
|
||||
|
||||
def _cumulate_input_tokens(self, input_ids: _TokenIDs):
|
||||
pass
|
||||
|
||||
def _cumulate_output_tokens(self, output_ids: _TokenIDs):
|
||||
def _cumulate_output_tokens(self, output_ids: torch.Tensor):
|
||||
self.len_output_tokens += 1
|
||||
|
||||
def _apply(self, logits: torch.Tensor) -> torch.Tensor:
|
||||
def _apply(self, logits: torch.Tensor):
|
||||
mask = (self.len_output_tokens < self.min_new_tokens).expand_as(logits)
|
||||
logits[mask] += self.stop_token_penalties[mask]
|
||||
return logits
|
||||
|
||||
def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
|
||||
self.min_new_tokens = self.min_new_tokens[indices_tensor_to_keep]
|
||||
self.stop_token_penalties = self.stop_token_penalties[indices_tensor_to_keep]
|
||||
self.len_output_tokens = self.len_output_tokens[indices_tensor_to_keep]
|
||||
def _filter(self, keep_indices: torch.Tensor):
|
||||
self.min_new_tokens = self.min_new_tokens[keep_indices]
|
||||
self.stop_token_penalties = self.stop_token_penalties[keep_indices]
|
||||
self.len_output_tokens = self.len_output_tokens[keep_indices]
|
||||
|
||||
def _merge(self, their: "BatchedMinNewTokensPenalizer"):
|
||||
self.min_new_tokens = torch.cat(
|
||||
@@ -1,35 +1,25 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import dataclasses
|
||||
from typing import List, Set, Type, Union
|
||||
from typing import TYPE_CHECKING, Set, Type
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class _ReqLike:
|
||||
origin_input_ids: List[int]
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class _BatchLike:
|
||||
reqs: List[_ReqLike]
|
||||
|
||||
def batch_size(self):
|
||||
return len(self.reqs)
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.managers.schedule_batch import ScheduleBatch
|
||||
|
||||
|
||||
class BatchedPenalizerOrchestrator:
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: int,
|
||||
batch: _BatchLike,
|
||||
device: str,
|
||||
Penalizers: Set[Type["_BatchedPenalizer"]],
|
||||
batch: ScheduleBatch,
|
||||
penalizers: Set[Type["_BatchedPenalizer"]],
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
self.batch = batch
|
||||
self.device = device
|
||||
self.penalizers = {Penalizer: Penalizer(self) for Penalizer in Penalizers}
|
||||
self.device = batch.device
|
||||
self.penalizers = {Penalizer: Penalizer(self) for Penalizer in penalizers}
|
||||
|
||||
is_required = False
|
||||
for penalizer in self.penalizers.values():
|
||||
@@ -37,31 +27,9 @@ class BatchedPenalizerOrchestrator:
|
||||
is_required |= pen_is_required
|
||||
self.is_required = is_required
|
||||
|
||||
input_ids = [
|
||||
torch.tensor(req.origin_input_ids, dtype=torch.int64, device=self.device)
|
||||
for req in self.reqs()
|
||||
]
|
||||
if self.is_required:
|
||||
self.cumulate_input_tokens(input_ids=input_ids)
|
||||
|
||||
def reqs(self):
|
||||
return self.batch.reqs
|
||||
|
||||
def batch_size(self):
|
||||
return self.batch.batch_size()
|
||||
|
||||
def cumulate_input_tokens(self, input_ids: List[torch.Tensor]):
|
||||
"""
|
||||
Feed the input tokens to the penalizers.
|
||||
|
||||
Args:
|
||||
input_ids (List[torch.Tensor]): The input tokens.
|
||||
"""
|
||||
token_ids = _TokenIDs(orchestrator=self, token_ids=input_ids)
|
||||
|
||||
for penalizer in self.penalizers.values():
|
||||
penalizer.cumulate_input_tokens(input_ids=token_ids)
|
||||
|
||||
def cumulate_output_tokens(self, output_ids: torch.Tensor):
|
||||
"""
|
||||
Feed the output tokens to the penalizers.
|
||||
@@ -69,13 +37,8 @@ class BatchedPenalizerOrchestrator:
|
||||
Args:
|
||||
output_ids (torch.Tensor): The output tokens.
|
||||
"""
|
||||
if not self.is_required:
|
||||
return
|
||||
|
||||
token_ids = _TokenIDs(orchestrator=self, token_ids=output_ids)
|
||||
|
||||
for penalizer in self.penalizers.values():
|
||||
penalizer.cumulate_output_tokens(output_ids=token_ids)
|
||||
penalizer.cumulate_output_tokens(output_ids=output_ids)
|
||||
|
||||
def apply(self, logits: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
@@ -88,48 +51,33 @@ class BatchedPenalizerOrchestrator:
|
||||
Returns:
|
||||
torch.Tensor: The logits after applying the penalizers.
|
||||
"""
|
||||
if not self.is_required:
|
||||
return
|
||||
|
||||
for penalizer in self.penalizers.values():
|
||||
logits = penalizer.apply(logits)
|
||||
penalizer.apply(logits)
|
||||
|
||||
return logits
|
||||
|
||||
def filter(
|
||||
self,
|
||||
indices_to_keep: List[int],
|
||||
indices_tensor_to_keep: torch.Tensor = None,
|
||||
):
|
||||
def filter(self, keep_indices: torch.Tensor):
|
||||
"""
|
||||
Filter the penalizers based on the indices to keep in the batch.
|
||||
|
||||
Args:
|
||||
indices_to_keep (List[int]): List of indices to keep in the batch.
|
||||
indices_tensor_to_keep (torch.Tensor = None): Tensor of indices to keep in the batch. If not None, it will be used instead of converting indices_to_keep to a tensor.
|
||||
keep_indices (torch.Tensor): Tensor of indices to keep in the batch.
|
||||
"""
|
||||
if not self.is_required:
|
||||
return
|
||||
|
||||
empty_indices = len(indices_to_keep) == 0
|
||||
if len(keep_indices) == 0:
|
||||
self.is_required = False
|
||||
for penalizer in self.penalizers.values():
|
||||
penalizer.teardown()
|
||||
return
|
||||
|
||||
is_required = False
|
||||
for penalizer in self.penalizers.values():
|
||||
tmp_is_required = penalizer.is_required()
|
||||
is_required = is_required or tmp_is_required
|
||||
if not tmp_is_required or empty_indices:
|
||||
penalizer.teardown()
|
||||
is_required |= tmp_is_required
|
||||
if tmp_is_required:
|
||||
penalizer.filter(keep_indices=keep_indices)
|
||||
else:
|
||||
# create tensor index only when it's needed
|
||||
if indices_tensor_to_keep is None:
|
||||
indices_tensor_to_keep = torch.tensor(
|
||||
indices_to_keep, dtype=torch.int32, device=self.device
|
||||
)
|
||||
|
||||
penalizer.filter(
|
||||
indices_to_keep=indices_to_keep,
|
||||
indices_tensor_to_keep=indices_tensor_to_keep,
|
||||
)
|
||||
penalizer.teardown()
|
||||
self.is_required = is_required
|
||||
|
||||
def merge(self, their: "BatchedPenalizerOrchestrator"):
|
||||
@@ -146,75 +94,9 @@ class BatchedPenalizerOrchestrator:
|
||||
if not self.is_required and not their.is_required:
|
||||
return
|
||||
|
||||
self.is_required |= their.is_required
|
||||
for Penalizer, their_penalizer in their.penalizers.items():
|
||||
if Penalizer not in self.penalizers:
|
||||
raise ValueError(f"Penalizer {Penalizer} not found in self.penalizers")
|
||||
|
||||
self.penalizers[Penalizer].merge(their_penalizer)
|
||||
|
||||
|
||||
class _TokenIDs:
    """
    A class that wraps token IDs to provide additional utility functions to penalizers.

    Attributes:
        orchestrator (BatchedPenalizerOrchestrator): The orchestrator that this token IDs belong to.
        token_ids (Union[torch.Tensor, List[torch.Tensor]]): The token IDs.
        cached_counts (torch.Tensor): The cached occurrence count tensor.
    """

    def __init__(
        self,
        orchestrator: "BatchedPenalizerOrchestrator",
        token_ids: Union[torch.Tensor, List[torch.Tensor]],
    ):
        self.orchestrator = orchestrator
        self.token_ids = token_ids
        # Filled lazily by the first call to occurrence_count().
        self.cached_counts = None

    def occurrence_count(self) -> torch.Tensor:
        """
        Returns a tensor of shape (batch_size, vocab_size) where each element is the number of times the corresponding token appears in the batch.

        The result is computed once and cached; later calls return the same tensor object.

        Returns:
            torch.Tensor: The occurrence count tensor.
        """
        if self.cached_counts is not None:
            return self.cached_counts

        token_ids = self.token_ids
        vocab_size = self.orchestrator.vocab_size

        if isinstance(token_ids, list):
            # TODO: optimize this part
            # Pad ragged sequences with an out-of-vocabulary id (== vocab_size) so the
            # padding is counted in an extra column that is sliced away afterwards.
            padded_token_ids = torch.nn.utils.rnn.pad_sequence(
                sequences=token_ids,
                batch_first=True,
                padding_value=vocab_size,
            )
            self.cached_counts = torch.zeros(
                size=(self.orchestrator.batch_size(), vocab_size + 1),
                dtype=torch.int64,
                device=self.orchestrator.device,
            ).scatter_add_(
                dim=1,
                index=padded_token_ids,
                src=torch.ones_like(padded_token_ids),
            )[
                :, :vocab_size
            ]
        else:
            # TODO: optimize this part. We do not need to create this big tensor every time.
            # We can directly apply the results on the logits.
            # Tensor input is indexed as one token id per row, so each row gets a single 1.
            self.cached_counts = torch.zeros(
                size=(self.orchestrator.batch_size(), vocab_size),
                device=self.orchestrator.device,
            )
            self.cached_counts[
                torch.arange(len(token_ids), device=self.orchestrator.device), token_ids
            ] = 1

        return self.cached_counts
|
||||
self.is_required = True
|
||||
for penalizer, their_penalizer in their.penalizers.items():
|
||||
self.penalizers[penalizer].merge(their_penalizer)
|
||||
|
||||
|
||||
class _BatchedPenalizer(abc.ABC):
|
||||
@@ -222,10 +104,6 @@ class _BatchedPenalizer(abc.ABC):
|
||||
An abstract class for a batched penalizer.
|
||||
"""
|
||||
|
||||
def __init__(self, orchestrator: BatchedPenalizerOrchestrator):
|
||||
self.orchestrator = orchestrator
|
||||
self._is_prepared = False
|
||||
|
||||
def is_prepared(self) -> bool:
|
||||
return self._is_prepared
|
||||
|
||||
@@ -233,51 +111,40 @@ class _BatchedPenalizer(abc.ABC):
|
||||
return self._is_required()
|
||||
|
||||
def prepare(self):
|
||||
if not self.is_prepared():
|
||||
if not self._is_prepared:
|
||||
self._prepare()
|
||||
self._is_prepared = True
|
||||
|
||||
def prepare_if_required(self):
|
||||
if self.is_required():
|
||||
if self._is_required():
|
||||
self.prepare()
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def teardown(self):
|
||||
if self.is_prepared():
|
||||
self._teardown()
|
||||
self._is_prepared = False
|
||||
self._is_prepared = False
|
||||
|
||||
def cumulate_input_tokens(self, input_ids: _TokenIDs):
|
||||
if not self.is_prepared():
|
||||
return
|
||||
|
||||
self._cumulate_input_tokens(input_ids=input_ids)
|
||||
|
||||
def cumulate_output_tokens(self, output_ids: _TokenIDs):
|
||||
if not self.is_prepared():
|
||||
def cumulate_output_tokens(self, output_ids: torch.Tensor):
|
||||
if not self._is_prepared:
|
||||
return
|
||||
|
||||
self._cumulate_output_tokens(output_ids=output_ids)
|
||||
|
||||
def apply(self, logits: torch.Tensor) -> torch.Tensor:
|
||||
if not self.is_prepared():
|
||||
return logits
|
||||
|
||||
return self._apply(logits=logits)
|
||||
|
||||
def filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
|
||||
if not self.is_prepared():
|
||||
if not self._is_prepared:
|
||||
return
|
||||
|
||||
self._filter(
|
||||
indices_to_keep=indices_to_keep,
|
||||
indices_tensor_to_keep=indices_tensor_to_keep,
|
||||
)
|
||||
self._apply(logits=logits)
|
||||
|
||||
def filter(self, keep_indices: torch.Tensor):
|
||||
if not self._is_prepared:
|
||||
return
|
||||
|
||||
self._filter(keep_indices=keep_indices)
|
||||
|
||||
def merge(self, their: "_BatchedPenalizer"):
|
||||
if not self.is_prepared() and not their.is_prepared():
|
||||
if not self._is_prepared and not their._is_prepared:
|
||||
return
|
||||
|
||||
self.prepare()
|
||||
@@ -300,23 +167,7 @@ class _BatchedPenalizer(abc.ABC):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _teardown(self):
|
||||
"""
|
||||
Tear down the penalizer.
|
||||
Usually, this is where the penalizer frees its tensors.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _cumulate_input_tokens(self, input_ids: _TokenIDs):
|
||||
"""
|
||||
Cumulate the input tokens.
|
||||
Orchestrator will call this function to feed the input tokens to the penalizer.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _cumulate_output_tokens(self, output_ids: _TokenIDs):
|
||||
def _cumulate_output_tokens(self, output_ids: torch.Tensor):
|
||||
"""
|
||||
Cumulate the output tokens.
|
||||
Orchestrator will call this function to feed the output tokens to the penalizer.
|
||||
@@ -332,7 +183,7 @@ class _BatchedPenalizer(abc.ABC):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
|
||||
def _filter(self, keep_indices: torch.Tensor):
|
||||
"""
|
||||
Filter the penalizer (tensors or underlying data) based on the indices to keep in the batch.
|
||||
"""
|
||||
|
||||
@@ -1,75 +0,0 @@
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs
|
||||
|
||||
|
||||
class BatchedFrequencyPenalizer(_BatchedPenalizer):
    """
    Frequency penalizer penalizes tokens based on their frequency in the output.
    """

    # Per-request penalty values expanded to the accumulator's shape; None until prepared.
    frequency_penalties: torch.Tensor = None
    # Accumulated penalty per (request, token); None until prepared.
    cumulated_frequency_penalties: torch.Tensor = None

    def _is_required(self) -> bool:
        """Return True if any request in the batch sets a non-zero frequency penalty."""
        return any(
            req.sampling_params.frequency_penalty != 0.0
            for req in self.orchestrator.reqs()
        )

    def _prepare(self):
        """Allocate the zeroed accumulator and the per-request penalty tensor."""
        reqs = self.orchestrator.reqs()

        self.cumulated_frequency_penalties = torch.zeros(
            (len(reqs), self.orchestrator.vocab_size),
            dtype=torch.float32,
            device=self.orchestrator.device,
        )

        # One value per request, viewed (not copied) across the vocabulary dimension.
        self.frequency_penalties = (
            torch.tensor(
                data=[req.sampling_params.frequency_penalty for req in reqs],
                dtype=torch.float32,
                device=self.orchestrator.device,
            )
            .unsqueeze_(1)
            .expand_as(self.cumulated_frequency_penalties)
        )

    def _teardown(self):
        """Drop the tensors held by this penalizer."""
        self.frequency_penalties = None
        self.cumulated_frequency_penalties = None

    def _cumulate_input_tokens(self, input_ids: _TokenIDs):
        """Input tokens are ignored by this penalizer."""
        pass

    def _cumulate_output_tokens(self, output_ids: _TokenIDs):
        """Add each request's penalty once per occurrence of every output token."""
        self.cumulated_frequency_penalties += (
            self.frequency_penalties * output_ids.occurrence_count()
        )

    def _apply(self, logits: torch.Tensor) -> torch.Tensor:
        """Subtract the accumulated penalties from ``logits`` in place and return it."""
        logits -= self.cumulated_frequency_penalties
        return logits

    def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
        """Keep only the rows selected by ``indices_tensor_to_keep``."""
        self.frequency_penalties = self.frequency_penalties[indices_tensor_to_keep]
        self.cumulated_frequency_penalties = self.cumulated_frequency_penalties[
            indices_tensor_to_keep
        ]

    def _merge(self, their: "BatchedFrequencyPenalizer"):
        """Append the rows of ``their`` tensors after this penalizer's rows."""
        self.frequency_penalties = torch.cat(
            [self.frequency_penalties, their.frequency_penalties], dim=0
        )
        self.cumulated_frequency_penalties = torch.cat(
            [self.cumulated_frequency_penalties, their.cumulated_frequency_penalties],
            dim=0,
        )
|
||||
@@ -1,74 +0,0 @@
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs
|
||||
|
||||
|
||||
class BatchedPresencePenalizer(_BatchedPenalizer):
    """
    Presence penalizer penalizes tokens based on their presence in the output.
    """

    # Per-request penalty values expanded to the accumulator's shape; None until prepared.
    presence_penalties: torch.Tensor = None
    # Penalty currently in effect per (request, token); None until prepared.
    cumulated_presence_penalties: torch.Tensor = None

    def _is_required(self) -> bool:
        """Return True if any request in the batch sets a non-zero presence penalty."""
        return any(
            req.sampling_params.presence_penalty != 0.0
            for req in self.orchestrator.reqs()
        )

    def _prepare(self):
        """Allocate the zeroed accumulator and the per-request penalty tensor."""
        reqs = self.orchestrator.reqs()

        self.cumulated_presence_penalties = torch.zeros(
            (len(reqs), self.orchestrator.vocab_size),
            dtype=torch.float32,
            device=self.orchestrator.device,
        )

        # One value per request, viewed (not copied) across the vocabulary dimension
        # so it can be indexed with a (batch, vocab) boolean mask below.
        self.presence_penalties = (
            torch.tensor(
                data=[req.sampling_params.presence_penalty for req in reqs],
                dtype=torch.float32,
                device=self.orchestrator.device,
            )
            .unsqueeze_(1)
            .expand_as(self.cumulated_presence_penalties)
        )

    def _teardown(self):
        """Drop the tensors held by this penalizer."""
        self.presence_penalties = None
        self.cumulated_presence_penalties = None

    def _cumulate_input_tokens(self, input_ids: _TokenIDs):
        """Input tokens are ignored by this penalizer."""
        pass

    def _cumulate_output_tokens(self, output_ids: _TokenIDs):
        """Set the penalty for every token that occurs in the output (not additive)."""
        mask = output_ids.occurrence_count() > 0
        self.cumulated_presence_penalties[mask] = self.presence_penalties[mask]

    def _apply(self, logits: torch.Tensor) -> torch.Tensor:
        """Subtract the accumulated penalties from ``logits`` in place and return it."""
        logits -= self.cumulated_presence_penalties
        return logits

    def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
        """Keep only the rows selected by ``indices_tensor_to_keep``."""
        self.presence_penalties = self.presence_penalties[indices_tensor_to_keep]
        self.cumulated_presence_penalties = self.cumulated_presence_penalties[
            indices_tensor_to_keep
        ]

    def _merge(self, their: "BatchedPresencePenalizer"):
        """Append the rows of ``their`` tensors after this penalizer's rows."""
        self.presence_penalties = torch.cat(
            [self.presence_penalties, their.presence_penalties], dim=0
        )
        self.cumulated_presence_penalties = torch.cat(
            [self.cumulated_presence_penalties, their.cumulated_presence_penalties],
            dim=0,
        )
|
||||
@@ -1,85 +0,0 @@
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs
|
||||
from sglang.srt.utils import get_compiler_backend
|
||||
|
||||
|
||||
@torch.compile(dynamic=True, backend=get_compiler_backend())
def apply_scaling_penalties(logits, scaling_penalties):
    """
    Scale ``logits`` in place by ``scaling_penalties``.

    Positive logits are divided by the penalty and all other logits are
    multiplied by it; the result is written back into ``logits``.
    """
    logits[:] = torch.where(
        logits > 0,
        logits / scaling_penalties,
        logits * scaling_penalties,
    )
|
||||
|
||||
|
||||
class BatchedRepetitionPenalizer(_BatchedPenalizer):
    """
    Repetition penalizer penalizes tokens based on their repetition in the input and output.
    """

    # Per-request penalty values expanded to the accumulator's shape; None until prepared.
    repetition_penalties: torch.Tensor = None
    # Scaling factor currently in effect per (request, token); None until prepared.
    cumulated_repetition_penalties: torch.Tensor = None

    def _is_required(self) -> bool:
        """Return True if any request in the batch sets a repetition penalty other than 1.0."""
        return any(
            req.sampling_params.repetition_penalty != 1.0
            for req in self.orchestrator.reqs()
        )

    def _prepare(self):
        """Allocate the accumulator (all ones, i.e. no scaling) and the per-request penalties."""
        reqs = self.orchestrator.reqs()

        self.cumulated_repetition_penalties = torch.ones(
            (len(reqs), self.orchestrator.vocab_size),
            dtype=torch.float32,
            device=self.orchestrator.device,
        )

        # One value per request, viewed (not copied) across the vocabulary dimension
        # so it can be indexed with a (batch, vocab) boolean mask below.
        self.repetition_penalties = (
            torch.tensor(
                data=[req.sampling_params.repetition_penalty for req in reqs],
                dtype=torch.float32,
                device=self.orchestrator.device,
            )
            .unsqueeze_(1)
            .expand_as(self.cumulated_repetition_penalties)
        )

    def _teardown(self):
        """Drop the tensors held by this penalizer."""
        self.repetition_penalties = None
        self.cumulated_repetition_penalties = None

    def _cumulate_input_tokens(self, input_ids: _TokenIDs):
        """Set the scaling factor for every token that occurs in the input."""
        mask = input_ids.occurrence_count() > 0
        self.cumulated_repetition_penalties[mask] = self.repetition_penalties[mask]

    def _cumulate_output_tokens(self, output_ids: _TokenIDs):
        """Set the scaling factor for every token that occurs in the output."""
        mask = output_ids.occurrence_count() > 0
        self.cumulated_repetition_penalties[mask] = self.repetition_penalties[mask]

    def _apply(self, logits: torch.Tensor) -> torch.Tensor:
        """Scale ``logits`` in place by the accumulated factors and return it."""
        apply_scaling_penalties(logits, self.cumulated_repetition_penalties)
        return logits

    def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
        """Keep only the rows selected by ``indices_tensor_to_keep``."""
        self.repetition_penalties = self.repetition_penalties[indices_tensor_to_keep]
        self.cumulated_repetition_penalties = self.cumulated_repetition_penalties[
            indices_tensor_to_keep
        ]

    def _merge(self, their: "BatchedRepetitionPenalizer"):
        """Append the rows of ``their`` tensors after this penalizer's rows."""
        self.repetition_penalties = torch.cat(
            [self.repetition_penalties, their.repetition_penalties], dim=0
        )
        self.cumulated_repetition_penalties = torch.cat(
            [self.cumulated_repetition_penalties, their.cumulated_repetition_penalties],
            dim=0,
        )
|
||||
66
python/sglang/srt/sampling/penaltylib/presence_penalty.py
Normal file
66
python/sglang/srt/sampling/penaltylib/presence_penalty.py
Normal file
@@ -0,0 +1,66 @@
|
||||
import torch
|
||||
|
||||
from sglang.srt.sampling.penaltylib.orchestrator import (
|
||||
BatchedPenalizerOrchestrator,
|
||||
_BatchedPenalizer,
|
||||
)
|
||||
|
||||
|
||||
class BatchedPresencePenalizer(_BatchedPenalizer):
    """
    Presence penalizer penalizes tokens based on their presence in the output.
    """

    def __init__(self, orchestrator: BatchedPenalizerOrchestrator):
        self.orchestrator = orchestrator
        # Tensors are only allocated once _prepare() has run.
        self._is_prepared = False

    def _is_required(self) -> bool:
        """Return True if any request in the batch sets a non-zero presence penalty."""
        return any(
            req.sampling_params.presence_penalty != 0.0
            for req in self.orchestrator.reqs()
        )

    def _prepare(self):
        """Allocate the zeroed (num_reqs, vocab_size) accumulator and the (num_reqs, 1) penalties."""
        self.cumulated_presence_penalties = torch.zeros(
            (len(self.orchestrator.reqs()), self.orchestrator.vocab_size),
            dtype=torch.float32,
            device=self.orchestrator.device,
        )

        self.presence_penalties = (
            torch.tensor(
                data=[
                    req.sampling_params.presence_penalty
                    for req in self.orchestrator.reqs()
                ],
                dtype=torch.float32,
                device=self.orchestrator.device,
            )
        ).unsqueeze_(1)

    def _cumulate_output_tokens(self, output_ids: torch.Tensor):
        """
        Record the penalty for each newly generated token.

        The penalty is written (not added), so a token seen repeatedly is
        penalized the same as a token seen once.
        """
        # NOTE(review): assumes output_ids is 1-D with one token id per request — confirm with caller.
        self.cumulated_presence_penalties.scatter_(
            dim=1,
            index=output_ids.unsqueeze(1),
            src=self.presence_penalties,
        )

    def _apply(self, logits: torch.Tensor) -> torch.Tensor:
        """Subtract the accumulated penalties from ``logits`` in place and return it."""
        logits.sub_(self.cumulated_presence_penalties)
        return logits

    def _filter(self, keep_indices: torch.Tensor):
        """Keep only the rows selected by ``keep_indices``."""
        self.presence_penalties = self.presence_penalties[keep_indices]
        self.cumulated_presence_penalties = self.cumulated_presence_penalties[
            keep_indices
        ]

    def _merge(self, their: "BatchedPresencePenalizer"):
        """Append the rows of ``their`` tensors after this penalizer's rows."""
        self.presence_penalties = torch.cat(
            [self.presence_penalties, their.presence_penalties], dim=0
        )
        self.cumulated_presence_penalties = torch.cat(
            [self.cumulated_presence_penalties, their.cumulated_presence_penalties],
            dim=0,
        )
|
||||
@@ -9,9 +9,6 @@ import torch
|
||||
|
||||
import sglang.srt.sampling.penaltylib as penaltylib
|
||||
from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
|
||||
from sglang.srt.sampling.penaltylib.penalizers.repetition_penalty import (
|
||||
apply_scaling_penalties,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -22,49 +19,45 @@ if TYPE_CHECKING:
|
||||
|
||||
@dataclasses.dataclass
|
||||
class SamplingBatchInfo:
|
||||
# Batched sampling params
|
||||
# Basic batched sampling params
|
||||
temperatures: torch.Tensor
|
||||
top_ps: torch.Tensor
|
||||
top_ks: torch.Tensor
|
||||
min_ps: torch.Tensor
|
||||
|
||||
# All requests use greedy sampling
|
||||
# Whether all requests use greedy sampling
|
||||
is_all_greedy: bool
|
||||
|
||||
# Dispatch in CUDA graph
|
||||
# Whether any request needs min_p sampling
|
||||
need_min_p_sampling: bool
|
||||
|
||||
# Whether any request has custom logit processor
|
||||
has_custom_logit_processor: bool
|
||||
|
||||
# Bias Tensors
|
||||
# Masking tensors for grammar-guided structured outputs
|
||||
vocab_size: int
|
||||
grammars: Optional[List] = None
|
||||
sampling_info_done: Optional[threading.Event] = None
|
||||
logit_bias: torch.Tensor = None
|
||||
vocab_mask: Optional[torch.Tensor] = None
|
||||
apply_mask: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None
|
||||
apply_mask_func: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None
|
||||
|
||||
# An event used for overlap schedule
|
||||
sampling_info_done: Optional[threading.Event] = None
|
||||
|
||||
# Penalizer
|
||||
penalizer_orchestrator: Optional[penaltylib.BatchedPenalizerOrchestrator] = None
|
||||
linear_penalties: Optional[torch.Tensor] = None
|
||||
scaling_penalties: Optional[torch.Tensor] = None
|
||||
linear_penalty: torch.Tensor = None
|
||||
|
||||
# Device
|
||||
device: str = "cuda"
|
||||
|
||||
# Custom Parameters
|
||||
# Whether any request has custom logit processor
|
||||
has_custom_logit_processor: bool = False
|
||||
# Custom parameters
|
||||
custom_params: Optional[List[Optional[Dict[str, Any]]]] = None
|
||||
|
||||
# Custom Logit Processor
|
||||
# Custom logit processor
|
||||
custom_logit_processor: Optional[
|
||||
Dict[int, Tuple[CustomLogitProcessor, torch.Tensor]]
|
||||
] = None
|
||||
|
||||
# Device
|
||||
device: str = "cuda"
|
||||
|
||||
@classmethod
|
||||
def from_schedule_batch(
|
||||
cls, batch: ScheduleBatch, vocab_size: int, enable_overlap_schedule: bool
|
||||
):
|
||||
def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
|
||||
reqs = batch.reqs
|
||||
device = batch.device
|
||||
temperatures = (
|
||||
@@ -118,106 +111,60 @@ class SamplingBatchInfo:
|
||||
merged_custom_logit_processor = None
|
||||
custom_params = None
|
||||
|
||||
# Each penalizers will do nothing if they evaluate themselves as not required by looking at
|
||||
# the sampling_params of the requests (See {_is_required()} of each penalizers). So this
|
||||
# should not add hefty computation overhead other than simple checks.
|
||||
#
|
||||
# While we can choose not to even create the class instances if they are not required, this
|
||||
# could add additional complexity to the {ScheduleBatch} class, especially we need to
|
||||
# handle {filter_batch()} and {merge_batch()} cases as well.
|
||||
penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
|
||||
vocab_size=vocab_size,
|
||||
batch=batch,
|
||||
penalizers={
|
||||
penaltylib.BatchedFrequencyPenalizer,
|
||||
penaltylib.BatchedMinNewTokensPenalizer,
|
||||
penaltylib.BatchedPresencePenalizer,
|
||||
},
|
||||
)
|
||||
|
||||
ret = cls(
|
||||
temperatures=temperatures,
|
||||
top_ps=top_ps,
|
||||
top_ks=top_ks,
|
||||
min_ps=min_ps,
|
||||
need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
|
||||
is_all_greedy=all(r.sampling_params.top_k <= 1 for r in reqs),
|
||||
has_custom_logit_processor=has_custom_logit_processor,
|
||||
need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
|
||||
vocab_size=vocab_size,
|
||||
device=device,
|
||||
penalizer_orchestrator=penalizer_orchestrator,
|
||||
has_custom_logit_processor=has_custom_logit_processor,
|
||||
custom_params=custom_params,
|
||||
custom_logit_processor=merged_custom_logit_processor,
|
||||
device=device,
|
||||
)
|
||||
# TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge.
|
||||
|
||||
if enable_overlap_schedule:
|
||||
# TODO (lianmin): Some penalizers such as frequency and presence depend on model outputs,
|
||||
# so it is kind of tricky to make it work with overlap scheduler.
|
||||
# It requires correctly updating the penalty logits before the sampling and syncing the events.
|
||||
# We will support them later.
|
||||
penalizers = {
|
||||
penaltylib.BatchedMinNewTokensPenalizer,
|
||||
}
|
||||
if (
|
||||
any(req.sampling_params.frequency_penalty != 0.0 for req in reqs)
|
||||
or any(req.sampling_params.presence_penalty != 0.0 for req in reqs)
|
||||
or any(req.sampling_params.repetition_penalty != 1.0 for req in reqs)
|
||||
):
|
||||
logger.warning(
|
||||
"frequency_penalty, presence_penalty, and repetition_penalty are not supported "
|
||||
"when using the default overlap scheduler. They will be ignored. "
|
||||
"Please add `--disable-overlap` when launching the server if you need these features. "
|
||||
"The speed will be slower in that case."
|
||||
)
|
||||
else:
|
||||
penalizers = {
|
||||
penaltylib.BatchedFrequencyPenalizer,
|
||||
penaltylib.BatchedMinNewTokensPenalizer,
|
||||
penaltylib.BatchedPresencePenalizer,
|
||||
penaltylib.BatchedRepetitionPenalizer,
|
||||
}
|
||||
|
||||
# Each penalizers will do nothing if they evaluate themselves as not required by looking at
|
||||
# the sampling_params of the requests (See {_is_required()} of each penalizers). So this
|
||||
# should not add hefty computation overhead other than simple checks.
|
||||
#
|
||||
# While we choose not to even create the class instances if they are not required, this
|
||||
# could add additional complexity to the {ScheduleBatch} class, especially we need to
|
||||
# handle {filter_batch()} and {merge_batch()} cases as well.
|
||||
ret.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
|
||||
vocab_size=vocab_size,
|
||||
batch=batch,
|
||||
device=batch.device,
|
||||
Penalizers=penalizers,
|
||||
)
|
||||
|
||||
# Handle logit bias but only allocate when needed
|
||||
ret.logit_bias = None
|
||||
|
||||
return ret
|
||||
|
||||
def __len__(self):
|
||||
return len(self.temperatures)
|
||||
|
||||
def update_penalties(self):
|
||||
self.scaling_penalties = None
|
||||
self.linear_penalties = None
|
||||
|
||||
for penalizer in self.penalizer_orchestrator.penalizers.values():
|
||||
if not penalizer.is_prepared():
|
||||
continue
|
||||
|
||||
if isinstance(penalizer, penaltylib.BatchedRepetitionPenalizer):
|
||||
self.scaling_penalties = penalizer.cumulated_repetition_penalties
|
||||
else:
|
||||
if self.linear_penalties is None:
|
||||
bs = self.penalizer_orchestrator.batch.batch_size()
|
||||
self.linear_penalties = torch.zeros(
|
||||
(bs, self.vocab_size),
|
||||
dtype=torch.float32,
|
||||
device=self.device,
|
||||
)
|
||||
self.linear_penalties = penalizer.apply(self.linear_penalties)
|
||||
|
||||
def update_regex_vocab_mask(self):
|
||||
if not self.grammars:
|
||||
self.vocab_mask = None
|
||||
self.apply_mask = None
|
||||
self.apply_mask_func = None
|
||||
return
|
||||
|
||||
# find a grammar from the list
|
||||
# Find a grammar from the list
|
||||
first_grammar = next(grammar for grammar in self.grammars if grammar)
|
||||
|
||||
# maybe we can reuse the existing mask?
|
||||
# TODO(lianmin): Maybe we can reuse the existing mask?
|
||||
self.vocab_mask = first_grammar.allocate_vocab_mask(
|
||||
vocab_size=self.vocab_size,
|
||||
batch_size=len(self.temperatures),
|
||||
device=self.device,
|
||||
)
|
||||
self.apply_mask = first_grammar.apply_vocab_mask # force to use static method
|
||||
self.apply_mask_func = (
|
||||
first_grammar.apply_vocab_mask
|
||||
) # force to use static method
|
||||
|
||||
# Apply the mask
|
||||
for i, grammar in enumerate(self.grammars):
|
||||
@@ -227,35 +174,56 @@ class SamplingBatchInfo:
|
||||
# Move the mask to the device if needed
|
||||
self.vocab_mask = first_grammar.move_vocab_mask(self.vocab_mask, self.device)
|
||||
|
||||
def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor):
|
||||
self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
|
||||
def update_penalties(self):
|
||||
if self.penalizer_orchestrator.is_required:
|
||||
self.linear_penalty = torch.zeros(
|
||||
(len(self.temperatures), self.vocab_size),
|
||||
dtype=torch.float32,
|
||||
device=self.temperatures.device,
|
||||
)
|
||||
self.penalizer_orchestrator.apply(self.linear_penalty)
|
||||
else:
|
||||
self.linear_penalty = None
|
||||
|
||||
def apply_logits_bias(self, logits: torch.Tensor):
|
||||
if self.linear_penalty is not None:
|
||||
# Used in the overlap mode
|
||||
logits.add_(self.linear_penalty)
|
||||
|
||||
if self.penalizer_orchestrator and self.penalizer_orchestrator.is_required:
|
||||
# Used in the non-overlap mode
|
||||
self.penalizer_orchestrator.apply(logits)
|
||||
|
||||
if self.vocab_mask is not None:
|
||||
self.apply_mask_func(logits=logits, vocab_mask=self.vocab_mask)
|
||||
|
||||
def filter_batch(self, keep_indices: List[int], keep_indices_device: torch.Tensor):
|
||||
self.penalizer_orchestrator.filter(keep_indices_device)
|
||||
|
||||
if self.has_custom_logit_processor:
|
||||
self._filter_batch_custom_logit_processor(unfinished_indices, new_indices)
|
||||
self._filter_batch_custom_logit_processor(keep_indices, keep_indices_device)
|
||||
|
||||
for item in [
|
||||
"temperatures",
|
||||
"top_ps",
|
||||
"top_ks",
|
||||
"min_ps",
|
||||
"logit_bias",
|
||||
]:
|
||||
value = getattr(self, item, None)
|
||||
if value is not None: # logit_bias can be None
|
||||
setattr(self, item, value[new_indices])
|
||||
setattr(self, item, value[keep_indices_device])
|
||||
|
||||
def _filter_batch_custom_logit_processor(
|
||||
self, unfinished_indices: List[int], new_indices: torch.Tensor
|
||||
self, keep_indices: List[int], keep_indices_device: torch.Tensor
|
||||
):
|
||||
"""Filter the custom logit processor and custom params"""
|
||||
|
||||
self.custom_logit_processor = {
|
||||
k: (p, mask[new_indices])
|
||||
k: (p, mask[keep_indices_device])
|
||||
for k, (p, mask) in self.custom_logit_processor.items()
|
||||
if any(
|
||||
mask[new_indices]
|
||||
if torch.any(
|
||||
mask[keep_indices_device]
|
||||
) # ignore the custom logit processor whose mask is all False
|
||||
}
|
||||
self.custom_params = [self.custom_params[i] for i in unfinished_indices]
|
||||
self.custom_params = [self.custom_params[i] for i in keep_indices]
|
||||
|
||||
# If the custom logit processor is an empty dict, set the flag to False,
|
||||
# and set the custom logit processor and custom params to None.
|
||||
@@ -264,31 +232,6 @@ class SamplingBatchInfo:
|
||||
self.custom_params = None
|
||||
self.has_custom_logit_processor = False
|
||||
|
||||
@staticmethod
def merge_bias_tensor(
    lhs: torch.Tensor,
    rhs: torch.Tensor,
    bs1: int,
    bs2: int,
    device: str,
    default: int = 0,
):
    """
    Concatenate two optional per-batch bias tensors along dim 0.

    Args:
        lhs: Bias tensor of the first batch, or None.
        rhs: Bias tensor of the second batch, or None.
        bs1: Number of rows to synthesize for ``lhs`` when it is None.
        bs2: Number of rows to synthesize for ``rhs`` when it is None.
        device: Device on which a missing side is allocated.
        default: Fill value used for a missing side.

    Returns:
        ``torch.cat([lhs, rhs])`` with a missing side replaced by a
        ``default``-filled tensor, or None when both sides are None.
    """
    # bias tensor can be None
    if lhs is None and rhs is None:
        return None

    # The side that exists dictates the trailing shape and dtype of the other.
    template = lhs if lhs is not None else rhs
    shape, dtype = template.shape[1:], template.dtype

    # torch.dtype cannot be instantiated or used as a context manager, so the
    # dtype is passed to the allocation directly.
    if lhs is None:
        lhs = torch.empty((bs1, *shape), dtype=dtype, device=device).fill_(default)
    if rhs is None:
        rhs = torch.empty((bs2, *shape), dtype=dtype, device=device).fill_(default)
    return torch.cat([lhs, rhs])
|
||||
|
||||
@staticmethod
|
||||
def merge_custom_logit_processor(
|
||||
lhs: Optional[Dict[int, Tuple[CustomLogitProcessor, torch.Tensor]]],
|
||||
@@ -332,11 +275,6 @@ class SamplingBatchInfo:
|
||||
def merge_batch(self, other: "SamplingBatchInfo"):
|
||||
self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
|
||||
|
||||
# Merge the logit bias tensor
|
||||
self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
|
||||
self.logit_bias, other.logit_bias, len(self), len(other), self.device
|
||||
)
|
||||
|
||||
# Merge the custom logit processors and custom params lists
|
||||
if self.has_custom_logit_processor or other.has_custom_logit_processor:
|
||||
# Merge the custom logit processors
|
||||
@@ -370,22 +308,5 @@ class SamplingBatchInfo:
|
||||
other_val = getattr(other, item, None)
|
||||
setattr(self, item, torch.concat([self_val, other_val]))
|
||||
|
||||
self.is_all_greedy = self.is_all_greedy and other.is_all_greedy
|
||||
self.need_min_p_sampling = self.need_min_p_sampling or other.need_min_p_sampling
|
||||
|
||||
def apply_logits_bias(self, logits: torch.Tensor):
|
||||
# Apply logit_bias
|
||||
if self.logit_bias is not None:
|
||||
logits.add_(self.logit_bias)
|
||||
|
||||
# min-token, presence, frequency
|
||||
if self.linear_penalties is not None:
|
||||
logits.add_(self.linear_penalties)
|
||||
|
||||
# repetition
|
||||
if self.scaling_penalties is not None:
|
||||
apply_scaling_penalties(logits, self.scaling_penalties)
|
||||
|
||||
# Apply regex vocab_mask
|
||||
if self.vocab_mask is not None:
|
||||
self.apply_mask(logits=logits, vocab_mask=self.vocab_mask)
|
||||
self.is_all_greedy |= other.is_all_greedy
|
||||
self.need_min_p_sampling |= other.need_min_p_sampling
|
||||
|
||||
@@ -15,15 +15,21 @@
|
||||
|
||||
import argparse
|
||||
import dataclasses
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import subprocess
|
||||
import tempfile
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.hf_transformers_utils import check_gguf_file
|
||||
from sglang.srt.utils import (
|
||||
create_checksum,
|
||||
get_amdgpu_memory_capacity,
|
||||
get_hpu_memory_capacity,
|
||||
get_nvgpu_memory_capacity,
|
||||
@@ -43,12 +49,13 @@ class ServerArgs:
|
||||
model_path: str
|
||||
tokenizer_path: Optional[str] = None
|
||||
tokenizer_mode: str = "auto"
|
||||
skip_tokenizer_init: bool = False
|
||||
load_format: str = "auto"
|
||||
trust_remote_code: bool = True
|
||||
trust_remote_code: bool = False
|
||||
dtype: str = "auto"
|
||||
kv_cache_dtype: str = "auto"
|
||||
quantization_param_path: nullable_str = None
|
||||
quantization: Optional[str] = None
|
||||
quantization_param_path: nullable_str = None
|
||||
context_length: Optional[int] = None
|
||||
device: str = "cuda"
|
||||
served_model_name: Optional[str] = None
|
||||
@@ -67,7 +74,7 @@ class ServerArgs:
|
||||
max_total_tokens: Optional[int] = None
|
||||
chunked_prefill_size: Optional[int] = None
|
||||
max_prefill_tokens: int = 16384
|
||||
schedule_policy: str = "lpm"
|
||||
schedule_policy: str = "fcfs"
|
||||
schedule_conservativeness: float = 1.0
|
||||
cpu_offload_gb: int = 0
|
||||
prefill_only_one_req: bool = False
|
||||
@@ -88,6 +95,7 @@ class ServerArgs:
|
||||
log_level: str = "info"
|
||||
log_level_http: Optional[str] = None
|
||||
log_requests: bool = False
|
||||
log_requests_level: int = 0
|
||||
show_time_cost: bool = False
|
||||
enable_metrics: bool = False
|
||||
decode_log_interval: int = 40
|
||||
@@ -123,11 +131,13 @@ class ServerArgs:
|
||||
grammar_backend: Optional[str] = "outlines"
|
||||
|
||||
# Speculative decoding
|
||||
speculative_draft_model_path: Optional[str] = None
|
||||
speculative_algorithm: Optional[str] = None
|
||||
speculative_draft_model_path: Optional[str] = None
|
||||
speculative_num_steps: int = 5
|
||||
speculative_eagle_topk: int = 8
|
||||
speculative_num_draft_tokens: int = 64
|
||||
speculative_eagle_topk: int = 4
|
||||
speculative_num_draft_tokens: int = 8
|
||||
speculative_accept_threshold_single: float = 1.0
|
||||
speculative_accept_threshold_acc: float = 1.0
|
||||
speculative_token_map: Optional[str] = None
|
||||
|
||||
# Double Sparsity
|
||||
@@ -169,6 +179,12 @@ class ServerArgs:
|
||||
enable_hierarchical_cache: bool = False
|
||||
enable_flashinfer_mla: bool = False
|
||||
flashinfer_mla_disable_ragged: bool = False
|
||||
warmups: Optional[str] = None
|
||||
|
||||
# Debug tensor dumps
|
||||
debug_tensor_dump_output_folder: Optional[str] = None
|
||||
debug_tensor_dump_input_file: Optional[str] = None
|
||||
debug_tensor_dump_inject: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
# Set missing default values
|
||||
@@ -266,10 +282,10 @@ class ServerArgs:
|
||||
self.speculative_algorithm == "EAGLE"
|
||||
or self.speculative_algorithm == "NEXTN"
|
||||
):
|
||||
self.disable_overlap_schedule = True
|
||||
self.prefill_only_one_req = True
|
||||
self.disable_cuda_graph_padding = True
|
||||
self.disable_radix_cache = True
|
||||
self.disable_overlap_schedule = True
|
||||
self.chunked_prefill_size = -1
|
||||
logger.info(
|
||||
f"The radix cache, chunked prefill, and overlap scheduler are disabled because of using {self.speculative_algorithm} speculative decoding."
|
||||
@@ -377,15 +393,6 @@ class ServerArgs:
|
||||
choices=["auto", "fp8_e5m2", "fp8_e4m3"],
|
||||
help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" and "fp8_e4m3" is supported for CUDA 11.8+.',
|
||||
)
|
||||
parser.add_argument(
|
||||
"--quantization-param-path",
|
||||
type=nullable_str,
|
||||
default=None,
|
||||
help="Path to the JSON file containing the KV cache "
|
||||
"scaling factors. This should generally be supplied, when "
|
||||
"KV cache dtype is FP8. Otherwise, KV cache scaling factors "
|
||||
"default to 1.0, which may cause accuracy issues. ",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--quantization",
|
||||
type=str,
|
||||
@@ -404,6 +411,15 @@ class ServerArgs:
|
||||
],
|
||||
help="The quantization method.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--quantization-param-path",
|
||||
type=nullable_str,
|
||||
default=None,
|
||||
help="Path to the JSON file containing the KV cache "
|
||||
"scaling factors. This should generally be supplied, when "
|
||||
"KV cache dtype is FP8. Otherwise, KV cache scaling factors "
|
||||
"default to 1.0, which may cause accuracy issues. ",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--context-length",
|
||||
type=int,
|
||||
@@ -578,7 +594,14 @@ class ServerArgs:
|
||||
parser.add_argument(
|
||||
"--log-requests",
|
||||
action="store_true",
|
||||
help="Log the inputs and outputs of all requests.",
|
||||
help="Log metadata, inputs, outputs of all requests. The verbosity is decided by --log-requests-level",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--log-requests-level",
|
||||
type=int,
|
||||
default=0,
|
||||
help="0: Log metadata. 1. Log metadata and partial input/output. 2. Log every input/output.",
|
||||
choices=[0, 1, 2],
|
||||
)
|
||||
parser.add_argument(
|
||||
"--show-time-cost",
|
||||
@@ -742,16 +765,28 @@ class ServerArgs:
|
||||
parser.add_argument(
|
||||
"--speculative-eagle-topk",
|
||||
type=int,
|
||||
help="The number of token sampled from draft model in eagle2 each step.",
|
||||
help="The number of tokens sampled from the draft model in eagle2 each step.",
|
||||
choices=[1, 2, 4, 8],
|
||||
default=ServerArgs.speculative_eagle_topk,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--speculative-num-draft-tokens",
|
||||
type=int,
|
||||
help="The number of token sampled from draft model in Speculative Decoding.",
|
||||
help="The number of tokens sampled from the draft model in Speculative Decoding.",
|
||||
default=ServerArgs.speculative_num_draft_tokens,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--speculative-accept-threshold-single",
|
||||
type=float,
|
||||
help="Accept a draft token if its probability in the target model is greater than this threshold.",
|
||||
default=ServerArgs.speculative_accept_threshold_single,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--speculative-accept-threshold-acc",
|
||||
type=float,
|
||||
help="The accept probability of a draft token is raised from its target probability p to min(1, p / threshold_acc).",
|
||||
default=ServerArgs.speculative_accept_threshold_acc,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--speculative-token-map",
|
||||
type=str,
|
||||
@@ -949,6 +984,35 @@ class ServerArgs:
|
||||
help="Enable hierarchical cache",
|
||||
)
|
||||
|
||||
# Server warmups
|
||||
parser.add_argument(
|
||||
"--warmups",
|
||||
type=str,
|
||||
required=False,
|
||||
help="Specify custom warmup functions (csv) to run before server starts eg. --warmups=warmup_name1,warmup_name2 "
|
||||
"will run the functions `warmup_name1` and `warmup_name2` specified in warmup.py before the server starts listening for requests",
|
||||
)
|
||||
|
||||
# Debug tensor dumps
|
||||
parser.add_argument(
|
||||
"--debug-tensor-dump-output-folder",
|
||||
type=str,
|
||||
default=ServerArgs.debug_tensor_dump_output_folder,
|
||||
help="The output folder for dumping tensors.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--debug-tensor-dump-input-file",
|
||||
type=str,
|
||||
default=ServerArgs.debug_tensor_dump_input_file,
|
||||
help="The input filename for dumping tensors",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--debug-tensor-dump-inject",
|
||||
type=str,
|
||||
default=ServerArgs.debug_tensor_dump_inject,
|
||||
help="Inject the outputs from jax as the input of every layer.",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_cli_args(cls, args: argparse.Namespace):
|
||||
args.tp_size = args.tensor_parallel_size
|
||||
|
||||
@@ -32,13 +32,15 @@ import socket
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
import warnings
|
||||
from functools import lru_cache
|
||||
from importlib.metadata import PackageNotFoundError, version
|
||||
from io import BytesIO
|
||||
from multiprocessing import Pool
|
||||
from multiprocessing.reduction import ForkingPickler
|
||||
from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, Union
|
||||
from typing import Any, Callable, Dict, List, Optional, Protocol, Set, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import psutil
|
||||
@@ -480,6 +482,10 @@ def assert_pkg_version(pkg: str, min_version: str, message: str):
|
||||
|
||||
def kill_process_tree(parent_pid, include_parent: bool = True, skip_pid: int = None):
|
||||
"""Kill the process and all its child processes."""
|
||||
# Remove sigchld handler to avoid spammy logs.
|
||||
if threading.current_thread() is threading.main_thread():
|
||||
signal.signal(signal.SIGCHLD, signal.SIG_DFL)
|
||||
|
||||
if parent_pid is None:
|
||||
parent_pid = os.getpid()
|
||||
include_parent = False
|
||||
@@ -499,17 +505,14 @@ def kill_process_tree(parent_pid, include_parent: bool = True, skip_pid: int = N
|
||||
pass
|
||||
|
||||
if include_parent:
|
||||
if parent_pid == os.getpid():
|
||||
sys.exit(0)
|
||||
else:
|
||||
try:
|
||||
itself.kill()
|
||||
try:
|
||||
itself.kill()
|
||||
|
||||
# Sometime processes cannot be killed with SIGKILL (e.g, PID=1 launched by kubernetes),
|
||||
# so we send an additional signal to kill them.
|
||||
itself.send_signal(signal.SIGQUIT)
|
||||
except psutil.NoSuchProcess:
|
||||
pass
|
||||
# Sometime processes cannot be killed with SIGKILL (e.g, PID=1 launched by kubernetes),
|
||||
# so we send an additional signal to kill them.
|
||||
itself.send_signal(signal.SIGQUIT)
|
||||
except psutil.NoSuchProcess:
|
||||
pass
|
||||
|
||||
|
||||
def monkey_patch_p2p_access_check():
|
||||
@@ -1215,7 +1218,11 @@ def cuda_device_count_stateless() -> int:
|
||||
return _cuda_device_count_stateless(os.environ.get("CUDA_VISIBLE_DEVICES", None))
|
||||
|
||||
|
||||
def dataclass_to_string_truncated(data, max_length=2048):
|
||||
def dataclass_to_string_truncated(
|
||||
data, max_length=2048, skip_names: Optional[Set[str]] = None
|
||||
):
|
||||
if skip_names is None:
|
||||
skip_names = set()
|
||||
if isinstance(data, str):
|
||||
if len(data) > max_length:
|
||||
half_length = max_length // 2
|
||||
@@ -1234,6 +1241,7 @@ def dataclass_to_string_truncated(data, max_length=2048):
|
||||
+ ", ".join(
|
||||
f"'{k}': {dataclass_to_string_truncated(v, max_length)}"
|
||||
for k, v in data.items()
|
||||
if k not in skip_names
|
||||
)
|
||||
+ "}"
|
||||
)
|
||||
@@ -1244,6 +1252,7 @@ def dataclass_to_string_truncated(data, max_length=2048):
|
||||
+ ", ".join(
|
||||
f"{f.name}={dataclass_to_string_truncated(getattr(data, f.name), max_length)}"
|
||||
for f in fields
|
||||
if f.name not in skip_names
|
||||
)
|
||||
+ ")"
|
||||
)
|
||||
@@ -1322,9 +1331,9 @@ def pyspy_dump_schedulers():
|
||||
result = subprocess.run(
|
||||
cmd, shell=True, capture_output=True, text=True, check=True
|
||||
)
|
||||
logger.info(f"Profile for PID {pid}:\n{result.stdout}")
|
||||
logger.error(f"Pyspy dump for PID {pid}:\n{result.stdout}")
|
||||
except subprocess.CalledProcessError as e:
|
||||
logger.info(f"Failed to profile PID {pid}. Error: {e.stderr}")
|
||||
logger.error(f"Pyspy failed to dump PID {pid}. Error: {e.stderr}")
|
||||
|
||||
|
||||
def kill_itself_when_parent_died():
|
||||
@@ -1448,6 +1457,10 @@ def launch_dummy_health_check_server(host, port):
|
||||
)
|
||||
|
||||
|
||||
def create_checksum(directory: str):
    """Placeholder for computing a checksum of ``directory``.

    Not implemented: every call raises ``NotImplementedError``.
    """
    raise NotImplementedError
|
||||
|
||||
|
||||
def set_cuda_arch():
|
||||
if is_flashinfer_available():
|
||||
capability = torch.cuda.get_device_capability()
|
||||
|
||||
47
python/sglang/srt/warmup.py
Normal file
47
python/sglang/srt/warmup.py
Normal file
@@ -0,0 +1,47 @@
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
import tqdm
|
||||
|
||||
from sglang.srt.managers.io_struct import GenerateReqInput
|
||||
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
||||
|
||||
# Module-level logger. It is named after the module (``__name__``) rather than
# the file path so that it takes part in the dotted logger hierarchy and
# inherits the handlers/levels configured on its parent package loggers.
logger = logging.getLogger(__name__)

# Registry of custom warmup functions, keyed by the name given to ``@warmup``.
_warmup_registry = {}
|
||||
|
||||
|
||||
def warmup(name: str) -> callable:
    """Return a decorator that registers a function as the warmup ``name``.

    The decorated function is stored in ``_warmup_registry`` under ``name``
    and returned unchanged, so it can still be called directly.
    """

    def register(fn: callable):
        # Same effect as item assignment: insert or overwrite the entry.
        _warmup_registry.update({name: fn})
        return fn

    return register
|
||||
|
||||
|
||||
async def execute_warmups(warmup_names: List[str], tokenizer_manager: TokenizerManager):
    """Await the registered warmup functions named in ``warmup_names``, in order.

    Each registered function is called with ``tokenizer_manager`` and awaited
    before the next one starts. Names that are not in the registry are skipped
    after logging a warning.
    """
    for warmup_name in warmup_names:
        try:
            warmup_fn = _warmup_registry[warmup_name]
        except KeyError:
            logger.warning(f"Could not find custom warmup {warmup_name}")
            continue

        logger.info(f"Running warmup {warmup_name}")
        await warmup_fn(tokenizer_manager)
|
||||
|
||||
|
||||
@warmup("voice_chat")
async def voice_chat(tokenizer_manager: TokenizerManager):
    """Warmup registered as ``voice_chat``: send random prompts of growing length.

    Issues one generation request per step for prompt lengths 4, 8, ..., 2044
    tokens and waits for the first streamed result of each before continuing.
    """
    # this warms up the fused_moe triton kernels and caches them
    # if we don't do this we break real time inference for voice chat
    for step in tqdm.trange(1, 512):
        num_tokens = 4 * step
        # Random token ids drawn from [0, 2**16).
        random_ids = np.random.randint(2**16, size=[num_tokens]).tolist()
        request = GenerateReqInput(
            input_ids=random_ids,
            sampling_params={
                "max_new_tokens": 30,
                "temperature": 0.8,
                "stop_token_ids": [1],
                "min_p": 0.0,
            },
        )
        result_stream = tokenizer_manager.generate_request(request, None)
        # Only the first item yielded by the generator is awaited.
        await result_stream.__anext__()
|
||||
@@ -15,7 +15,7 @@
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Union
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -56,6 +56,13 @@ def get_top_logprobs(logits, k):
|
||||
return logprobs
|
||||
|
||||
|
||||
def get_token_ids_logprobs(logits, token_ids):
    """Return the log-probabilities of ``token_ids`` under ``logits``.

    A float32 log-softmax is taken over the last dimension of ``logits`` and
    the entries at ``token_ids`` are then selected along that dimension.
    """
    all_logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
    # Drop the local reference to the input before indexing.
    del logits
    return all_logprobs[..., token_ids]
|
||||
|
||||
|
||||
def _get_sentence_transformer_embedding_model(model_path, torch_dtype):
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from sentence_transformers.util import is_sentence_transformer_model
|
||||
@@ -84,8 +91,13 @@ class ModelOutput:
|
||||
output_ids: List[int] = None
|
||||
top_input_logprobs: List[torch.Tensor] = None
|
||||
top_output_logprobs: List[torch.Tensor] = None
|
||||
top_output_logprob_idx: List[List[int]] = None
|
||||
embed_logits: List[torch.Tensor] = None
|
||||
scores: List[float] = None
|
||||
input_token_logprobs_lst: List[List[Tuple[float, int, None]]] = None
|
||||
output_token_logprobs_lst: List[List[Tuple[float, int, None]]] = None
|
||||
token_ids_input_logprobs: List[torch.Tensor] = None
|
||||
token_ids_output_logprobs: List[torch.Tensor] = None
|
||||
|
||||
|
||||
class HFRunner:
|
||||
@@ -157,7 +169,7 @@ class HFRunner:
|
||||
|
||||
# Run forward
|
||||
while True:
|
||||
prompts, max_new_tokens, lora_paths = in_queue.get()
|
||||
prompts, max_new_tokens, lora_paths, token_ids_logprob = in_queue.get()
|
||||
if lora_paths is not None:
|
||||
assert len(prompts) == len(lora_paths)
|
||||
|
||||
@@ -165,16 +177,16 @@ class HFRunner:
|
||||
if self.model_type == "generation":
|
||||
out_queue.put(
|
||||
self.forward_generation_raw(
|
||||
base_model=self.base_model,
|
||||
prompts=prompts,
|
||||
max_new_tokens=max_new_tokens,
|
||||
base_model=self.base_model,
|
||||
tokenizer=self.tokenizer,
|
||||
lora_paths=lora_paths,
|
||||
torch_dtype=torch_dtype,
|
||||
output_str_only=self.output_str_only,
|
||||
token_ids_logprob=token_ids_logprob,
|
||||
)
|
||||
)
|
||||
|
||||
elif self.model_type == "embedding":
|
||||
assert not self.output_str_only
|
||||
logits = self.model.encode(prompts).tolist()
|
||||
@@ -199,10 +211,11 @@ class HFRunner:
|
||||
def forward(
|
||||
self,
|
||||
prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
|
||||
max_new_tokens=8,
|
||||
lora_paths=None,
|
||||
max_new_tokens: int = 8,
|
||||
lora_paths: Optional[List[str]] = None,
|
||||
token_ids_logprob: Optional[int] = None,
|
||||
):
|
||||
self.in_queue.put((prompts, max_new_tokens, lora_paths))
|
||||
self.in_queue.put((prompts, max_new_tokens, lora_paths, token_ids_logprob))
|
||||
return self.out_queue.get()
|
||||
|
||||
def terminate(self):
|
||||
@@ -218,17 +231,24 @@ class HFRunner:
|
||||
|
||||
@staticmethod
|
||||
def forward_generation_raw(
|
||||
prompts: Union[List[str], List[torch.Tensor]],
|
||||
max_new_tokens,
|
||||
base_model,
|
||||
prompts: Union[List[str], List[torch.Tensor]],
|
||||
max_new_tokens: int,
|
||||
tokenizer,
|
||||
lora_paths,
|
||||
torch_dtype: torch.dtype,
|
||||
output_str_only: bool,
|
||||
lora_paths: Optional[List[str]] = None,
|
||||
output_str_only: bool = False,
|
||||
token_ids_logprob: Optional[int] = None,
|
||||
) -> ModelOutput:
|
||||
output_strs = []
|
||||
top_input_logprobs = []
|
||||
top_output_logprobs = []
|
||||
if token_ids_logprob is not None:
|
||||
token_ids_input_logprobs = []
|
||||
token_ids_output_logprobs = []
|
||||
else:
|
||||
token_ids_input_logprobs = token_ids_output_logprobs = None
|
||||
|
||||
for i, p in enumerate(prompts):
|
||||
if isinstance(p, str):
|
||||
input_ids = tokenizer.encode(p, return_tensors="pt").cuda()
|
||||
@@ -275,18 +295,33 @@ class HFRunner:
|
||||
for logits in outputs.scores
|
||||
]
|
||||
)
|
||||
if token_ids_logprob is not None:
|
||||
token_ids_output_logprobs.append(
|
||||
[
|
||||
get_token_ids_logprobs(
|
||||
logits[0], token_ids_logprob
|
||||
).tolist()
|
||||
for logits in outputs.scores
|
||||
]
|
||||
)
|
||||
del outputs
|
||||
|
||||
input_logits = model.forward(input_ids).logits[0]
|
||||
top_input_logprobs.append(
|
||||
get_top_logprobs(input_logits, NUM_TOP_LOGPROBS).tolist()
|
||||
)
|
||||
if token_ids_logprob is not None:
|
||||
token_ids_input_logprobs.append(
|
||||
get_token_ids_logprobs(input_logits, token_ids_logprob).tolist()
|
||||
)
|
||||
del input_logits
|
||||
|
||||
return ModelOutput(
|
||||
output_strs=output_strs,
|
||||
top_input_logprobs=top_input_logprobs,
|
||||
top_output_logprobs=top_output_logprobs,
|
||||
token_ids_input_logprobs=token_ids_input_logprobs,
|
||||
token_ids_output_logprobs=token_ids_output_logprobs,
|
||||
)
|
||||
|
||||
|
||||
@@ -303,11 +338,31 @@ class SRTRunner:
|
||||
lora_backend: str = "triton",
|
||||
disable_cuda_graph: bool = False,
|
||||
disable_radix_cache: bool = False,
|
||||
chunked_prefill_size: Optional[int] = None,
|
||||
dp_size: int = 1,
|
||||
tokenizer_path: Optional[str] = None,
|
||||
enable_ep_moe: bool = False,
|
||||
mem_fraction_static: float = 0.65,
|
||||
trust_remote_code: bool = False,
|
||||
speculative_draft_model_path: Optional[str] = None,
|
||||
speculative_algorithm: Optional[str] = None,
|
||||
speculative_num_steps: Optional[int] = None,
|
||||
speculative_eagle_topk: Optional[int] = None,
|
||||
speculative_num_draft_tokens: Optional[int] = None,
|
||||
disable_overlap_schedule: bool = False,
|
||||
):
|
||||
self.model_type = model_type
|
||||
self.is_generation = model_type == "generation"
|
||||
enable_dp_attention = dp_size > 1
|
||||
|
||||
spec_kwargs = {}
|
||||
if speculative_draft_model_path:
|
||||
spec_kwargs["speculative_draft_model_path"] = speculative_draft_model_path
|
||||
spec_kwargs["speculative_algorithm"] = speculative_algorithm
|
||||
spec_kwargs["speculative_num_steps"] = speculative_num_steps
|
||||
spec_kwargs["speculative_eagle_topk"] = speculative_eagle_topk
|
||||
spec_kwargs["speculative_num_draft_tokens"] = speculative_num_draft_tokens
|
||||
|
||||
self.engine = Engine(
|
||||
model_path=model_path,
|
||||
tp_size=tp_size,
|
||||
@@ -321,21 +376,41 @@ class SRTRunner:
|
||||
lora_backend=lora_backend,
|
||||
disable_cuda_graph=disable_cuda_graph,
|
||||
disable_radix_cache=disable_radix_cache,
|
||||
chunked_prefill_size=chunked_prefill_size,
|
||||
enable_dp_attention=enable_dp_attention,
|
||||
dp_size=dp_size,
|
||||
tokenizer_path=tokenizer_path,
|
||||
enable_ep_moe=enable_ep_moe,
|
||||
disable_overlap_schedule=disable_overlap_schedule,
|
||||
cuda_graph_max_bs=4,
|
||||
**spec_kwargs,
|
||||
)
|
||||
self.tokenizer = get_tokenizer(model_path, trust_remote_code=trust_remote_code)
|
||||
|
||||
if tokenizer_path is None:
|
||||
self.tokenizer = get_tokenizer(
|
||||
model_path, trust_remote_code=trust_remote_code
|
||||
)
|
||||
else:
|
||||
self.tokenizer = None
|
||||
|
||||
def forward(
|
||||
self,
|
||||
prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
|
||||
max_new_tokens=8,
|
||||
lora_paths=None,
|
||||
max_new_tokens: int = 8,
|
||||
lora_paths: Optional[List[str]] = None,
|
||||
logprob_start_len: int = 0,
|
||||
top_k: Optional[int] = None,
|
||||
token_ids_logprob: Optional[List[int]] = None,
|
||||
):
|
||||
if self.is_generation:
|
||||
return self.forward_generation_raw(
|
||||
engine=self.engine,
|
||||
prompts=prompts,
|
||||
max_new_tokens=max_new_tokens,
|
||||
lora_paths=lora_paths,
|
||||
engine=self.engine,
|
||||
logprob_start_len=logprob_start_len,
|
||||
top_k=top_k,
|
||||
token_ids_logprob=token_ids_logprob,
|
||||
)
|
||||
else:
|
||||
response = self.engine.encode(prompts)
|
||||
@@ -358,10 +433,10 @@ class SRTRunner:
|
||||
"""
|
||||
if self.is_generation:
|
||||
return self.batch_forward_generation_raw(
|
||||
engine=self.engine,
|
||||
prompts=prompts,
|
||||
max_new_tokens=max_new_tokens,
|
||||
lora_paths=lora_paths,
|
||||
engine=self.engine,
|
||||
)
|
||||
else:
|
||||
response = self.engine.encode(prompts)
|
||||
@@ -381,24 +456,43 @@ class SRTRunner:
|
||||
|
||||
@staticmethod
|
||||
def forward_generation_raw(
|
||||
engine: Engine,
|
||||
prompts: Union[List[str], List[torch.Tensor]],
|
||||
max_new_tokens,
|
||||
lora_paths,
|
||||
engine,
|
||||
max_new_tokens: int = 8,
|
||||
lora_paths: Optional[List[str]] = None,
|
||||
logprob_start_len: int = 0,
|
||||
top_k: Optional[int] = None,
|
||||
token_ids_logprob: Optional[List[int]] = None,
|
||||
):
|
||||
# the return value contains logprobs from prefill
|
||||
output_strs = []
|
||||
output_ids = []
|
||||
# Input logprobs. Note that the last item in input logprob is equivalent to
|
||||
# the first item in the output logprob.
|
||||
top_input_logprobs = []
|
||||
input_token_logprobs_lst = []
|
||||
top_output_logprobs = []
|
||||
output_token_logprobs_lst = []
|
||||
top_output_logprob_idx = []
|
||||
if token_ids_logprob is not None:
|
||||
token_ids_input_logprobs = []
|
||||
token_ids_output_logprobs = []
|
||||
else:
|
||||
token_ids_input_logprobs = token_ids_output_logprobs = None
|
||||
|
||||
sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
|
||||
if top_k:
|
||||
sampling_params["top_k"] = top_k
|
||||
|
||||
for i, prompt in enumerate(prompts):
|
||||
response = engine.generate(
|
||||
prompt,
|
||||
lora_path=lora_paths[i] if lora_paths else None,
|
||||
sampling_params=sampling_params,
|
||||
return_logprob=True,
|
||||
logprob_start_len=0,
|
||||
logprob_start_len=logprob_start_len,
|
||||
top_logprobs_num=NUM_TOP_LOGPROBS,
|
||||
token_ids_logprob=token_ids_logprob,
|
||||
)
|
||||
text = response["text"]
|
||||
|
||||
@@ -408,12 +502,36 @@ class SRTRunner:
|
||||
"Received an empty text response. Please verify your input or model configuration."
|
||||
)
|
||||
output_strs.append(text)
|
||||
# output_ids.append(response["output_ids"])
|
||||
|
||||
input_token_logprobs = response["meta_info"]["input_token_logprobs"]
|
||||
output_token_logprobs = response["meta_info"]["output_token_logprobs"]
|
||||
# print(i, input_token_logprobs)
|
||||
# print(i, output_token_logprobs)
|
||||
logprobs = response["meta_info"]["input_top_logprobs"]
|
||||
if token_ids_logprob is not None:
|
||||
input_token_ids_logprobs = response["meta_info"][
|
||||
"input_token_ids_logprobs"
|
||||
][1:]
|
||||
else:
|
||||
input_token_ids_logprobs = None
|
||||
|
||||
num_prompt_tokens = response["meta_info"]["prompt_tokens"]
|
||||
assert len(input_token_logprobs) == num_prompt_tokens - logprob_start_len
|
||||
assert len(logprobs) == num_prompt_tokens - logprob_start_len
|
||||
|
||||
# The first token logprob has no meaning in sglang.
|
||||
input_token_logprobs = input_token_logprobs[1:]
|
||||
logprobs = logprobs[1:]
|
||||
assert len(input_token_logprobs) == len(logprobs)
|
||||
|
||||
input_token_logprobs_lst.append(
|
||||
input_token_logprobs + [output_token_logprobs[0]]
|
||||
)
|
||||
output_token_logprobs_lst.append(output_token_logprobs)
|
||||
|
||||
top_input_logprobs.append(
|
||||
[
|
||||
[tup[0] for tup in x[:NUM_TOP_LOGPROBS]]
|
||||
for x in response["meta_info"]["input_top_logprobs"][1:]
|
||||
]
|
||||
[[tup[0] for tup in x[:NUM_TOP_LOGPROBS]] for x in logprobs]
|
||||
+ [
|
||||
[
|
||||
tup[0]
|
||||
@@ -429,11 +547,41 @@ class SRTRunner:
|
||||
for x in response["meta_info"]["output_top_logprobs"]
|
||||
]
|
||||
)
|
||||
top_output_logprob_idx.append(
|
||||
[
|
||||
[tup[1] for tup in x[:NUM_TOP_LOGPROBS]]
|
||||
for x in response["meta_info"]["output_top_logprobs"]
|
||||
]
|
||||
)
|
||||
if token_ids_logprob is not None:
|
||||
token_ids_input_logprobs.append(
|
||||
[[tup[0] for tup in x] for x in input_token_ids_logprobs]
|
||||
+ [
|
||||
[
|
||||
tup[0]
|
||||
for tup in response["meta_info"][
|
||||
"output_token_ids_logprobs"
|
||||
][0]
|
||||
]
|
||||
]
|
||||
)
|
||||
token_ids_output_logprobs.append(
|
||||
[
|
||||
[tup[0] for tup in x]
|
||||
for x in response["meta_info"]["output_token_ids_logprobs"]
|
||||
]
|
||||
)
|
||||
|
||||
return ModelOutput(
|
||||
output_strs=output_strs,
|
||||
output_ids=output_ids,
|
||||
top_input_logprobs=top_input_logprobs,
|
||||
top_output_logprobs=top_output_logprobs,
|
||||
input_token_logprobs_lst=input_token_logprobs_lst,
|
||||
output_token_logprobs_lst=output_token_logprobs_lst,
|
||||
top_output_logprob_idx=top_output_logprob_idx,
|
||||
token_ids_input_logprobs=token_ids_input_logprobs,
|
||||
token_ids_output_logprobs=token_ids_output_logprobs,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
||||
88
python/sglang/test/send_one.py
Normal file
88
python/sglang/test/send_one.py
Normal file
@@ -0,0 +1,88 @@
|
||||
"""
|
||||
Run one test prompt.
|
||||
|
||||
Usage:
|
||||
python3 -m sglang.test.send_one
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
def send_one_prompt(args):
    """Send one generation request to the local server and print its speed.

    The request is POSTed to ``http://localhost:30000/generate``. When
    ``args.image`` is set, ``args.prompt`` is overwritten with an image
    description prompt and an example image URL is attached.

    Args:
        args: Namespace providing ``prompt``, ``image``, ``temperature``,
            ``max_new_tokens``, ``frequency_penalty``, ``presence_penalty``,
            ``return_logprob`` and ``stream``.

    Returns:
        A tuple ``(acc_length, speed)``. ``acc_length`` is
        ``completion_tokens / spec_verify_ct`` when the response reports a
        non-zero ``spec_verify_ct`` and ``1.0`` otherwise; ``speed`` is
        ``completion_tokens / e2e_latency``.

    Raises:
        RuntimeError: If a streaming response ends without any data chunk.
    """
    if args.image:
        args.prompt = (
            "Human: Describe this image in a very short sentence.\n\nAssistant:"
        )
        image_data = "https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png"
    else:
        image_data = None

    response = requests.post(
        "http://localhost:30000/generate",
        json={
            "text": args.prompt,
            "image_data": image_data,
            "sampling_params": {
                "temperature": args.temperature,
                "max_new_tokens": args.max_new_tokens,
                "frequency_penalty": args.frequency_penalty,
                "presence_penalty": args.presence_penalty,
            },
            "return_logprob": args.return_logprob,
            "stream": args.stream,
        },
        stream=args.stream,
    )

    if args.stream:
        # Keep only the last data chunk seen before the "[DONE]" marker.
        ret = None
        for chunk in response.iter_lines(decode_unicode=False):
            chunk = chunk.decode("utf-8")
            if chunk and chunk.startswith("data:"):
                if chunk == "data: [DONE]":
                    break
                ret = json.loads(chunk[5:].strip("\n"))
        if ret is None:
            # Without this guard an empty stream would surface as a NameError.
            raise RuntimeError(
                "The streaming response did not contain any data chunk."
            )
    else:
        ret = response.json()

    meta_info = ret["meta_info"]
    latency = meta_info["e2e_latency"]
    completion_tokens = meta_info["completion_tokens"]

    # NOTE(review): "spec_verify_ct" is presumably only reported when
    # speculative decoding is enabled -- confirm against the server.
    # A missing or zero count falls back to 1.0 instead of dividing by zero.
    spec_verify_ct = meta_info.get("spec_verify_ct")
    if spec_verify_ct:
        acc_length = completion_tokens / spec_verify_ct
    else:
        acc_length = 1.0

    speed = completion_tokens / latency

    print(ret["text"])
    print()
    print(f"{acc_length=:.2f}")
    print(f"{speed=:.2f} token/s")

    return acc_length, speed
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: parse the CLI options and send a single prompt.
    parser = argparse.ArgumentParser()

    # Numeric sampling options, as (flag, type, default).
    numeric_options = (
        ("--temperature", float, 0.0),
        ("--max-new-tokens", int, 512),
        ("--frequency-penalty", float, 0.0),
        ("--presence-penalty", float, 0.0),
    )
    for flag, value_type, default_value in numeric_options:
        parser.add_argument(flag, type=value_type, default=default_value)

    parser.add_argument("--return-logprob", action="store_true")
    parser.add_argument(
        "--prompt",
        type=str,
        default="Human: Give me a fully functional FastAPI server. Show the python code.\n\nAssistant:",
    )
    parser.add_argument("--image", action="store_true")
    parser.add_argument("--stream", action="store_true")

    args = parser.parse_args()
    send_one_prompt(args)
|
||||
@@ -8,10 +8,11 @@ import random
|
||||
import subprocess
|
||||
import threading
|
||||
import time
|
||||
import unittest
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from functools import partial
|
||||
from types import SimpleNamespace
|
||||
from typing import Callable, List, Optional
|
||||
from typing import Callable, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
@@ -408,26 +409,49 @@ def popen_launch_server(
|
||||
other_args: list[str] = (),
|
||||
env: Optional[dict] = None,
|
||||
return_stdout_stderr: Optional[tuple] = None,
|
||||
pd_seperated: bool = False,
|
||||
):
|
||||
_, host, port = base_url.split(":")
|
||||
host = host[2:]
|
||||
|
||||
if pd_seperated:
|
||||
command = "sglang.launch_pd_server"
|
||||
else:
|
||||
command = "sglang.launch_server"
|
||||
|
||||
command = [
|
||||
"python3",
|
||||
"-m",
|
||||
"sglang.launch_server",
|
||||
command,
|
||||
"--model-path",
|
||||
model,
|
||||
"--host",
|
||||
host,
|
||||
"--port",
|
||||
port,
|
||||
*other_args,
|
||||
*[str(x) for x in other_args],
|
||||
]
|
||||
|
||||
if pd_seperated:
|
||||
command.extend(
|
||||
[
|
||||
"--lb-host",
|
||||
host,
|
||||
"--lb-port",
|
||||
port,
|
||||
]
|
||||
)
|
||||
else:
|
||||
command.extend(
|
||||
[
|
||||
"--host",
|
||||
host,
|
||||
"--port",
|
||||
port,
|
||||
]
|
||||
)
|
||||
|
||||
if api_key:
|
||||
command += ["--api-key", api_key]
|
||||
|
||||
print(f"command={' '.join(command)}")
|
||||
|
||||
if return_stdout_stderr:
|
||||
process = subprocess.Popen(
|
||||
command,
|
||||
@@ -456,6 +480,8 @@ def popen_launch_server(
|
||||
except requests.RequestException:
|
||||
pass
|
||||
time.sleep(10)
|
||||
|
||||
kill_process_tree(process.pid)
|
||||
raise TimeoutError("Server failed to start within the timeout period.")
|
||||
|
||||
|
||||
@@ -488,9 +514,11 @@ def run_unittest_files(files: List[str], timeout_per_file: float):
|
||||
success = True
|
||||
|
||||
for filename in files:
|
||||
global process
|
||||
process = None
|
||||
|
||||
def run_one_file(filename):
|
||||
nonlocal process
|
||||
|
||||
filename = os.path.join(os.getcwd(), filename)
|
||||
print(f"\n\nRun:\npython3 {filename}\n\n", flush=True)
|
||||
process = subprocess.Popen(
|
||||
@@ -534,11 +562,14 @@ def get_benchmark_args(
|
||||
dataset_path="",
|
||||
tokenizer="",
|
||||
num_prompts=500,
|
||||
sharegpt_output_len=None,
|
||||
random_input_len=4096,
|
||||
random_output_len=2048,
|
||||
sharegpt_context_len=None,
|
||||
request_rate=float("inf"),
|
||||
disable_stream=False,
|
||||
disable_ignore_eos=False,
|
||||
pd_seperated: bool = False,
|
||||
):
|
||||
return SimpleNamespace(
|
||||
backend="sglang",
|
||||
@@ -550,8 +581,8 @@ def get_benchmark_args(
|
||||
model=None,
|
||||
tokenizer=tokenizer,
|
||||
num_prompts=num_prompts,
|
||||
sharegpt_output_len=None,
|
||||
sharegpt_context_len=None,
|
||||
sharegpt_output_len=sharegpt_output_len,
|
||||
sharegpt_context_len=sharegpt_context_len,
|
||||
random_input_len=random_input_len,
|
||||
random_output_len=random_output_len,
|
||||
random_range_ratio=0.0,
|
||||
@@ -567,6 +598,8 @@ def get_benchmark_args(
|
||||
apply_chat_template=False,
|
||||
profile=None,
|
||||
lora_name=None,
|
||||
prompt_suffix="",
|
||||
pd_seperated=pd_seperated,
|
||||
)
|
||||
|
||||
|
||||
@@ -580,6 +613,7 @@ def run_bench_serving(
|
||||
tokenizer=None,
|
||||
random_input_len=4096,
|
||||
random_output_len=2048,
|
||||
sharegpt_context_len=None,
|
||||
disable_stream=False,
|
||||
disable_ignore_eos=False,
|
||||
need_warmup=False,
|
||||
@@ -602,6 +636,7 @@ def run_bench_serving(
|
||||
num_prompts=num_prompts,
|
||||
random_input_len=random_input_len,
|
||||
random_output_len=random_output_len,
|
||||
sharegpt_context_len=sharegpt_context_len,
|
||||
request_rate=request_rate,
|
||||
disable_stream=disable_stream,
|
||||
disable_ignore_eos=disable_ignore_eos,
|
||||
@@ -626,6 +661,7 @@ def run_bench_serving_multi(
|
||||
other_server_args,
|
||||
benchmark_args,
|
||||
need_warmup=False,
|
||||
pd_seperated=False,
|
||||
):
|
||||
# Launch the server
|
||||
process = popen_launch_server(
|
||||
@@ -633,6 +669,7 @@ def run_bench_serving_multi(
|
||||
base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=other_server_args,
|
||||
pd_seperated=pd_seperated,
|
||||
)
|
||||
|
||||
# run benchmark for all
|
||||
@@ -665,7 +702,7 @@ def run_bench_one_batch(model, other_args):
|
||||
"128",
|
||||
"--output",
|
||||
"8",
|
||||
*other_args,
|
||||
*[str(x) for x in other_args],
|
||||
]
|
||||
process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
||||
|
||||
@@ -816,7 +853,7 @@ def run_command_and_capture_output(command, env: Optional[dict] = None):
|
||||
stdout = open(STDOUT_FILENAME, "w")
|
||||
stderr = open(STDERR_FILENAME, "w")
|
||||
process = subprocess.Popen(
|
||||
command, stdout=stdout, stderr=stderr, env=env, text=True
|
||||
command, stdout=stdout, stderr=stdout, env=env, text=True
|
||||
)
|
||||
|
||||
# Launch a thread to stream the output
|
||||
@@ -914,3 +951,78 @@ def run_mulit_request_test(
|
||||
def write_github_step_summary(content):
    """Append ``content`` to the file named by ``GITHUB_STEP_SUMMARY``.

    The target path is read from the ``GITHUB_STEP_SUMMARY`` environment
    variable; a ``KeyError`` propagates if that variable is not set.
    """
    summary_path = os.environ["GITHUB_STEP_SUMMARY"]
    # Append mode: earlier content written to the same file is kept.
    with open(summary_path, "a") as summary_file:
        summary_file.write(content)
|
||||
|
||||
|
||||
def run_logprob_check(self: unittest.TestCase, arg: Tuple):
    """Send one ``/generate`` request and check the returned logprob counts.

    Args:
        self: The running test case. Its ``base_url`` attribute is used to
            build the request URL and its assertion methods report failures.
        arg: A 6-tuple ``(input_len, output_len, temperature,
            logprob_start_len, return_logprob, top_logprobs_num)``.

    Raises:
        AssertionError: If a count in ``meta_info`` disagrees with the
            request, or (when ``temperature == 0``) the sampled token entry
            does not equal the top-ranked candidate or a candidate tied
            with it.
    """
    (
        input_len,
        output_len,
        temperature,
        logprob_start_len,
        return_logprob,
        top_logprobs_num,
    ) = arg

    response = requests.post(
        self.base_url + "/generate",
        json={
            # Synthetic prompt: token ids 0 .. input_len - 1.
            "input_ids": list(range(input_len)),
            "sampling_params": {
                "temperature": temperature,
                "max_new_tokens": output_len,
                # ignore_eos is set so the completion_tokens == output_len
                # check below can hold.
                "ignore_eos": True,
            },
            "return_logprob": return_logprob,
            "logprob_start_len": logprob_start_len,
            "top_logprobs_num": top_logprobs_num,
        },
    )
    meta_info = response.json()["meta_info"]

    self.assertEqual(meta_info["prompt_tokens"], input_len)
    self.assertEqual(meta_info["completion_tokens"], output_len)

    # Everything below only applies when logprobs were requested.
    if not return_logprob:
        return

    # Test the number of tokens are correct
    self.assertEqual(
        len(meta_info["input_token_logprobs"]) + logprob_start_len,
        meta_info["prompt_tokens"],
    )
    self.assertEqual(len(meta_info["output_token_logprobs"]), output_len)

    if not top_logprobs_num:
        return

    self.assertEqual(
        len(meta_info["input_top_logprobs"]) + logprob_start_len,
        meta_info["prompt_tokens"],
    )
    self.assertEqual(len(meta_info["output_top_logprobs"]), output_len)

    for i in range(output_len):
        candidates = meta_info["output_top_logprobs"][i]
        self.assertEqual(len(candidates), top_logprobs_num)

        # Test the top-1 tokens are the same as output tokens if temperature == 0
        if temperature != 0:
            continue

        sampled = meta_info["output_token_logprobs"][i]
        rank = 0
        while rank < len(candidates):
            try:
                self.assertListEqual(sampled, candidates[rank])
                break
            except AssertionError:
                # There's a tie. Allow the next item in this case.
                # The bounds check keeps a genuine mismatch at the last rank
                # reported as the original AssertionError rather than an
                # IndexError from reading candidates[rank + 1].
                # NOTE(review): element [0] is presumably the logprob value;
                # confirm against the server's response format.
                tied_with_next = (
                    rank + 1 < len(candidates)
                    and candidates[rank][0] == candidates[rank + 1][0]
                )
                if not tied_with_next:
                    raise
                rank += 1
|
||||
|
||||
Reference in New Issue
Block a user