Log if cuda graph is used & extend cuda graph capture to cuda-graph-max-bs (#6201)
Co-authored-by: SangBin Cho <rkooo567@gmail.com>
This commit is contained in:
@@ -259,7 +259,9 @@ def throughput_test_once(
|
|||||||
measurement_results["total_input_tokens"]
|
measurement_results["total_input_tokens"]
|
||||||
+ measurement_results["total_output_tokens"]
|
+ measurement_results["total_output_tokens"]
|
||||||
) / latency
|
) / latency
|
||||||
measurement_results["last_gen_throughput"] = server_info["last_gen_throughput"]
|
measurement_results["last_gen_throughput"] = server_info["internal_states"][0][
|
||||||
|
"last_gen_throughput"
|
||||||
|
]
|
||||||
|
|
||||||
return measurement_results
|
return measurement_results
|
||||||
|
|
||||||
|
|||||||
@@ -246,7 +246,7 @@ def extend(reqs, model_runner):
|
|||||||
_maybe_prepare_dp_attn_batch(batch, model_runner)
|
_maybe_prepare_dp_attn_batch(batch, model_runner)
|
||||||
model_worker_batch = batch.get_model_worker_batch()
|
model_worker_batch = batch.get_model_worker_batch()
|
||||||
forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
|
forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
|
||||||
logits_output = model_runner.forward(forward_batch)
|
logits_output, _ = model_runner.forward(forward_batch)
|
||||||
next_token_ids = model_runner.sample(logits_output, forward_batch)
|
next_token_ids = model_runner.sample(logits_output, forward_batch)
|
||||||
return next_token_ids, logits_output.next_token_logits, batch
|
return next_token_ids, logits_output.next_token_logits, batch
|
||||||
|
|
||||||
@@ -258,7 +258,7 @@ def decode(input_token_ids, batch, model_runner):
|
|||||||
_maybe_prepare_dp_attn_batch(batch, model_runner)
|
_maybe_prepare_dp_attn_batch(batch, model_runner)
|
||||||
model_worker_batch = batch.get_model_worker_batch()
|
model_worker_batch = batch.get_model_worker_batch()
|
||||||
forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
|
forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
|
||||||
logits_output = model_runner.forward(forward_batch)
|
logits_output, _ = model_runner.forward(forward_batch)
|
||||||
next_token_ids = model_runner.sample(logits_output, forward_batch)
|
next_token_ids = model_runner.sample(logits_output, forward_batch)
|
||||||
return next_token_ids, logits_output.next_token_logits
|
return next_token_ids, logits_output.next_token_logits
|
||||||
|
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ import requests
|
|||||||
from sglang.srt.entrypoints.http_server import launch_server
|
from sglang.srt.entrypoints.http_server import launch_server
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
from sglang.srt.utils import kill_process_tree
|
from sglang.srt.utils import kill_process_tree
|
||||||
|
from sglang.test.test_utils import is_in_ci, write_github_step_summary
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
@@ -33,9 +34,13 @@ class BenchArgs:
|
|||||||
batch_size: Tuple[int] = (1,)
|
batch_size: Tuple[int] = (1,)
|
||||||
input_len: Tuple[int] = (1024,)
|
input_len: Tuple[int] = (1024,)
|
||||||
output_len: Tuple[int] = (16,)
|
output_len: Tuple[int] = (16,)
|
||||||
|
temperature: float = 0.0
|
||||||
|
return_logprob: bool = False
|
||||||
|
input_len_step_percentage: float = 0.0
|
||||||
result_filename: str = "result.jsonl"
|
result_filename: str = "result.jsonl"
|
||||||
base_url: str = ""
|
base_url: str = ""
|
||||||
skip_warmup: bool = False
|
skip_warmup: bool = False
|
||||||
|
show_report: bool = False
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def add_cli_args(parser: argparse.ArgumentParser):
|
def add_cli_args(parser: argparse.ArgumentParser):
|
||||||
@@ -49,11 +54,19 @@ class BenchArgs:
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--output-len", type=int, nargs="+", default=BenchArgs.output_len
|
"--output-len", type=int, nargs="+", default=BenchArgs.output_len
|
||||||
)
|
)
|
||||||
|
parser.add_argument("--temperature", type=float, default=BenchArgs.temperature)
|
||||||
|
parser.add_argument("--return-logprob", action="store_true")
|
||||||
|
parser.add_argument(
|
||||||
|
"--input-len-step-percentage",
|
||||||
|
type=float,
|
||||||
|
default=BenchArgs.input_len_step_percentage,
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--result-filename", type=str, default=BenchArgs.result_filename
|
"--result-filename", type=str, default=BenchArgs.result_filename
|
||||||
)
|
)
|
||||||
parser.add_argument("--base-url", type=str, default=BenchArgs.base_url)
|
parser.add_argument("--base-url", type=str, default=BenchArgs.base_url)
|
||||||
parser.add_argument("--skip-warmup", action="store_true")
|
parser.add_argument("--skip-warmup", action="store_true")
|
||||||
|
parser.add_argument("--show-report", action="store_true")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_cli_args(cls, args: argparse.Namespace):
|
def from_cli_args(cls, args: argparse.Namespace):
|
||||||
@@ -99,36 +112,89 @@ def run_one_case(
|
|||||||
batch_size: int,
|
batch_size: int,
|
||||||
input_len: int,
|
input_len: int,
|
||||||
output_len: int,
|
output_len: int,
|
||||||
|
temperature: float,
|
||||||
|
return_logprob: bool,
|
||||||
|
input_len_step_percentage: float,
|
||||||
run_name: str,
|
run_name: str,
|
||||||
result_filename: str,
|
result_filename: str,
|
||||||
):
|
):
|
||||||
input_ids = [
|
requests.post(url + "/flush_cache")
|
||||||
[int(x) for x in np.random.randint(0, high=16384, size=(input_len,))]
|
input_lens = [
|
||||||
for _ in range(batch_size)
|
int(input_len * (1 + (i - (batch_size - 1) / 2) * input_len_step_percentage))
|
||||||
|
for i in range(batch_size)
|
||||||
]
|
]
|
||||||
|
input_ids = [
|
||||||
|
[int(x) for x in np.random.randint(0, high=16384, size=(input_lens[i],))]
|
||||||
|
for i in range(batch_size)
|
||||||
|
]
|
||||||
|
|
||||||
|
use_structured_outputs = False
|
||||||
|
if use_structured_outputs:
|
||||||
|
texts = []
|
||||||
|
for _ in range(batch_size):
|
||||||
|
texts.append(
|
||||||
|
"Human: What is the capital city of france? can you give as many trivial information as possible about that city? answer in json.\n"
|
||||||
|
* 50
|
||||||
|
+ "Assistant:"
|
||||||
|
)
|
||||||
|
json_schema = "$$ANY$$"
|
||||||
|
else:
|
||||||
|
json_schema = None
|
||||||
|
|
||||||
tic = time.time()
|
tic = time.time()
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
url + "/generate",
|
url + "/generate",
|
||||||
json={
|
json={
|
||||||
|
# "text": texts,
|
||||||
"input_ids": input_ids,
|
"input_ids": input_ids,
|
||||||
"sampling_params": {
|
"sampling_params": {
|
||||||
"temperature": 0,
|
"temperature": temperature,
|
||||||
"max_new_tokens": output_len,
|
"max_new_tokens": output_len,
|
||||||
"ignore_eos": True,
|
"ignore_eos": True,
|
||||||
|
"json_schema": json_schema,
|
||||||
},
|
},
|
||||||
|
"return_logprob": return_logprob,
|
||||||
|
"stream": True,
|
||||||
},
|
},
|
||||||
|
stream=True,
|
||||||
)
|
)
|
||||||
latency = time.time() - tic
|
|
||||||
|
|
||||||
_ = response.json()
|
# The TTFT of the last request in the batch
|
||||||
output_throughput = batch_size * output_len / latency
|
ttft = 0.0
|
||||||
|
for chunk in response.iter_lines(decode_unicode=False):
|
||||||
|
chunk = chunk.decode("utf-8")
|
||||||
|
if chunk and chunk.startswith("data:"):
|
||||||
|
if chunk == "data: [DONE]":
|
||||||
|
break
|
||||||
|
data = json.loads(chunk[5:].strip("\n"))
|
||||||
|
if "error" in data:
|
||||||
|
raise RuntimeError(f"Request has failed. {data}.")
|
||||||
|
|
||||||
|
assert (
|
||||||
|
data["meta_info"]["finish_reason"] is None
|
||||||
|
or data["meta_info"]["finish_reason"]["type"] == "length"
|
||||||
|
)
|
||||||
|
if data["meta_info"]["completion_tokens"] == 1:
|
||||||
|
ttft = time.time() - tic
|
||||||
|
|
||||||
|
latency = time.time() - tic
|
||||||
|
input_throughput = batch_size * input_len / ttft
|
||||||
|
output_throughput = batch_size * output_len / (latency - ttft)
|
||||||
overall_throughput = batch_size * (input_len + output_len) / latency
|
overall_throughput = batch_size * (input_len + output_len) / latency
|
||||||
|
|
||||||
|
server_info = requests.get(url + "/get_server_info").json()
|
||||||
|
acc_length = server_info["internal_states"][0].get("avg_spec_accept_length", None)
|
||||||
|
last_gen_throughput = server_info["internal_states"][0]["last_gen_throughput"]
|
||||||
|
|
||||||
print(f"batch size: {batch_size}")
|
print(f"batch size: {batch_size}")
|
||||||
|
print(f"input_len: {input_len}")
|
||||||
|
print(f"output_len: {output_len}")
|
||||||
print(f"latency: {latency:.2f} s")
|
print(f"latency: {latency:.2f} s")
|
||||||
print(f"output throughput: {output_throughput:.2f} token/s")
|
print(f"ttft: {ttft:.2f} s")
|
||||||
print(f"(input + output) throughput: {overall_throughput:.2f} token/s")
|
print(f"Last generation throughput: {last_gen_throughput:.2f} tok/s")
|
||||||
|
print(f"Input throughput: {input_throughput:.2f} tok/s")
|
||||||
|
if output_len != 1:
|
||||||
|
print(f"output throughput: {output_throughput:.2f} tok/s")
|
||||||
|
|
||||||
if result_filename:
|
if result_filename:
|
||||||
with open(result_filename, "a") as fout:
|
with open(result_filename, "a") as fout:
|
||||||
@@ -140,9 +206,21 @@ def run_one_case(
|
|||||||
"latency": round(latency, 4),
|
"latency": round(latency, 4),
|
||||||
"output_throughput": round(output_throughput, 2),
|
"output_throughput": round(output_throughput, 2),
|
||||||
"overall_throughput": round(overall_throughput, 2),
|
"overall_throughput": round(overall_throughput, 2),
|
||||||
|
"last_gen_throughput": round(last_gen_throughput, 2),
|
||||||
}
|
}
|
||||||
fout.write(json.dumps(res) + "\n")
|
fout.write(json.dumps(res) + "\n")
|
||||||
|
|
||||||
|
return (
|
||||||
|
batch_size,
|
||||||
|
latency,
|
||||||
|
ttft,
|
||||||
|
input_throughput,
|
||||||
|
output_throughput,
|
||||||
|
overall_throughput,
|
||||||
|
last_gen_throughput,
|
||||||
|
acc_length,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
|
def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
|
||||||
if bench_args.base_url:
|
if bench_args.base_url:
|
||||||
@@ -152,27 +230,38 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
|
|||||||
|
|
||||||
# warmup
|
# warmup
|
||||||
if not bench_args.skip_warmup:
|
if not bench_args.skip_warmup:
|
||||||
|
print("=" * 8 + " Warmup Begin " + "=" * 8)
|
||||||
run_one_case(
|
run_one_case(
|
||||||
base_url,
|
base_url,
|
||||||
batch_size=16,
|
batch_size=16,
|
||||||
input_len=1024,
|
input_len=1024,
|
||||||
output_len=16,
|
output_len=16,
|
||||||
|
temperature=bench_args.temperature,
|
||||||
|
return_logprob=bench_args.return_logprob,
|
||||||
|
input_len_step_percentage=bench_args.input_len_step_percentage,
|
||||||
run_name="",
|
run_name="",
|
||||||
result_filename="",
|
result_filename="",
|
||||||
)
|
)
|
||||||
|
print("=" * 8 + " Warmup End " + "=" * 8 + "\n")
|
||||||
|
|
||||||
# benchmark
|
# benchmark
|
||||||
|
result = []
|
||||||
try:
|
try:
|
||||||
for bs, il, ol in itertools.product(
|
for bs, il, ol in itertools.product(
|
||||||
bench_args.batch_size, bench_args.input_len, bench_args.output_len
|
bench_args.batch_size, bench_args.input_len, bench_args.output_len
|
||||||
):
|
):
|
||||||
run_one_case(
|
result.append(
|
||||||
base_url,
|
run_one_case(
|
||||||
bs,
|
base_url,
|
||||||
il,
|
bs,
|
||||||
ol,
|
il,
|
||||||
bench_args.run_name,
|
ol,
|
||||||
bench_args.result_filename,
|
temperature=bench_args.temperature,
|
||||||
|
return_logprob=bench_args.return_logprob,
|
||||||
|
input_len_step_percentage=bench_args.input_len_step_percentage,
|
||||||
|
run_name=bench_args.run_name,
|
||||||
|
result_filename=bench_args.result_filename,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
finally:
|
finally:
|
||||||
if proc:
|
if proc:
|
||||||
@@ -180,6 +269,45 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
|
|||||||
|
|
||||||
print(f"\nResults are saved to {bench_args.result_filename}")
|
print(f"\nResults are saved to {bench_args.result_filename}")
|
||||||
|
|
||||||
|
if not bench_args.show_report:
|
||||||
|
return
|
||||||
|
|
||||||
|
summary = " | batch size | latency (s) | input throughput (tok/s) | output throughput (tok/s) | acc length | ITL (ms) | input price ($/1M) | output price ($/1M) |\n"
|
||||||
|
summary += "| ---------- | ----------- | ------------------------- | ------------------------- | ---------- | -------- | ------------------ | ------------------- |\n"
|
||||||
|
|
||||||
|
for (
|
||||||
|
batch_size,
|
||||||
|
latency,
|
||||||
|
ttft,
|
||||||
|
input_throughput,
|
||||||
|
output_throughput,
|
||||||
|
overall_throughput,
|
||||||
|
last_gen_throughput,
|
||||||
|
acc_length,
|
||||||
|
) in result:
|
||||||
|
hourly_cost = 2 * server_args.tp_size # $2/hour for one H100
|
||||||
|
input_util = 0.7
|
||||||
|
accept_length = round(acc_length, 2) if acc_length is not None else "n/a"
|
||||||
|
line = (
|
||||||
|
f"| {batch_size} | "
|
||||||
|
f"{latency:.2f} | "
|
||||||
|
f"{input_throughput:.2f} | "
|
||||||
|
f"{output_throughput:.2f} | "
|
||||||
|
f"{accept_length} | "
|
||||||
|
f"{1 / (output_throughput/batch_size) * 1000:.2f} | "
|
||||||
|
f"{1e6 / (input_throughput * input_util) / 3600 * hourly_cost:.2f} | "
|
||||||
|
f"{1e6 / output_throughput / 3600 * hourly_cost:.2f} |\n"
|
||||||
|
)
|
||||||
|
summary += line
|
||||||
|
|
||||||
|
# print metrics table
|
||||||
|
print(summary)
|
||||||
|
|
||||||
|
if is_in_ci():
|
||||||
|
write_github_step_summary(
|
||||||
|
f"### Test Nightly Benchmark (bench_one_batch) \n{summary}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
|||||||
@@ -1103,7 +1103,7 @@ async def benchmark(
|
|||||||
lora_names: List[str],
|
lora_names: List[str],
|
||||||
extra_request_body: Dict[str, Any],
|
extra_request_body: Dict[str, Any],
|
||||||
profile: bool,
|
profile: bool,
|
||||||
pd_seperated: bool = False,
|
pd_separated: bool = False,
|
||||||
flush_cache: bool = False,
|
flush_cache: bool = False,
|
||||||
warmup_requests: int = 1,
|
warmup_requests: int = 1,
|
||||||
):
|
):
|
||||||
@@ -1239,12 +1239,14 @@ async def benchmark(
|
|||||||
|
|
||||||
if "sglang" in backend:
|
if "sglang" in backend:
|
||||||
server_info = requests.get(base_url + "/get_server_info")
|
server_info = requests.get(base_url + "/get_server_info")
|
||||||
if pd_seperated:
|
if pd_separated:
|
||||||
accept_length = server_info.json()["decode"][0].get(
|
accept_length = server_info.json()["decode"][0]["internal_states"][0].get(
|
||||||
"avg_spec_accept_length", None
|
"avg_spec_accept_length", None
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
accept_length = server_info.json().get("avg_spec_accept_length", None)
|
accept_length = server_info.json()["internal_states"][0].get(
|
||||||
|
"avg_spec_accept_length", None
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
accept_length = None
|
accept_length = None
|
||||||
|
|
||||||
@@ -1541,7 +1543,7 @@ def run_benchmark(args_: argparse.Namespace):
|
|||||||
lora_names=args.lora_name,
|
lora_names=args.lora_name,
|
||||||
extra_request_body=extra_request_body,
|
extra_request_body=extra_request_body,
|
||||||
profile=args.profile,
|
profile=args.profile,
|
||||||
pd_seperated=args.pd_seperated,
|
pd_separated=args.pd_separated,
|
||||||
flush_cache=args.flush_cache,
|
flush_cache=args.flush_cache,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -37,6 +37,12 @@ class BaseGrammarObject:
|
|||||||
"""
|
"""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def rollback(self, k: int):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def is_terminated(self):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
def allocate_vocab_mask(
|
def allocate_vocab_mask(
|
||||||
self, vocab_size: int, batch_size: int, device
|
self, vocab_size: int, batch_size: int, device
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
|||||||
@@ -277,19 +277,17 @@ class SchedulerDisaggregationPrefillMixin:
|
|||||||
next_token_ids,
|
next_token_ids,
|
||||||
extend_input_len_per_req,
|
extend_input_len_per_req,
|
||||||
extend_logprob_start_len_per_req,
|
extend_logprob_start_len_per_req,
|
||||||
bid,
|
|
||||||
) = (
|
) = (
|
||||||
result.logits_output,
|
result.logits_output,
|
||||||
result.next_token_ids,
|
result.next_token_ids,
|
||||||
result.extend_input_len_per_req,
|
result.extend_input_len_per_req,
|
||||||
result.extend_logprob_start_len_per_req,
|
result.extend_logprob_start_len_per_req,
|
||||||
result.bid,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Transfer kv for prefill completed requests and add it into disagg_prefill_infight_queue
|
# Transfer kv for prefill completed requests and add it into disagg_prefill_infight_queue
|
||||||
if self.enable_overlap:
|
if self.enable_overlap:
|
||||||
# wait
|
# wait
|
||||||
_, next_token_ids = self.tp_worker.resolve_last_batch_result(launch_done)
|
_, next_token_ids, _ = self.tp_worker.resolve_last_batch_result(launch_done)
|
||||||
else:
|
else:
|
||||||
next_token_ids = result.next_token_ids.tolist()
|
next_token_ids = result.next_token_ids.tolist()
|
||||||
|
|
||||||
|
|||||||
@@ -330,7 +330,7 @@ class Engine(EngineBase):
|
|||||||
return {
|
return {
|
||||||
**dataclasses.asdict(self.tokenizer_manager.server_args),
|
**dataclasses.asdict(self.tokenizer_manager.server_args),
|
||||||
**self.scheduler_info,
|
**self.scheduler_info,
|
||||||
**internal_states,
|
"internal_states": internal_states,
|
||||||
"version": __version__,
|
"version": __version__,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -222,7 +222,7 @@ async def get_server_info():
|
|||||||
return {
|
return {
|
||||||
**dataclasses.asdict(_global_state.tokenizer_manager.server_args),
|
**dataclasses.asdict(_global_state.tokenizer_manager.server_args),
|
||||||
**_global_state.scheduler_info,
|
**_global_state.scheduler_info,
|
||||||
**internal_states,
|
"internal_states": internal_states,
|
||||||
"version": __version__,
|
"version": __version__,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -28,7 +28,8 @@ def create_flashinfer_kv_indices_triton(
|
|||||||
|
|
||||||
num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
|
num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
|
||||||
for i in range(num_loop):
|
for i in range(num_loop):
|
||||||
offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
|
# index into req_to_token_ptr needs to be int64
|
||||||
|
offset = tl.arange(0, BLOCK_SIZE).to(tl.int64) + i * BLOCK_SIZE
|
||||||
mask = offset < kv_end - kv_start
|
mask = offset < kv_end - kv_start
|
||||||
data = tl.load(
|
data = tl.load(
|
||||||
req_to_token_ptr
|
req_to_token_ptr
|
||||||
@@ -70,8 +71,9 @@ def create_flashmla_kv_indices_triton(
|
|||||||
num_pages_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
|
num_pages_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
|
||||||
|
|
||||||
for i in range(num_pages_loop):
|
for i in range(num_pages_loop):
|
||||||
|
# index into req_to_token_ptr needs to be int64
|
||||||
paged_offset = (
|
paged_offset = (
|
||||||
tl.arange(0, NUM_PAGE_PER_BLOCK) + i * NUM_PAGE_PER_BLOCK
|
tl.arange(0, NUM_PAGE_PER_BLOCK).to(tl.int64) + i * NUM_PAGE_PER_BLOCK
|
||||||
) * PAGED_SIZE
|
) * PAGED_SIZE
|
||||||
paged_offset_out = tl.arange(0, NUM_PAGE_PER_BLOCK) + i * NUM_PAGE_PER_BLOCK
|
paged_offset_out = tl.arange(0, NUM_PAGE_PER_BLOCK) + i * NUM_PAGE_PER_BLOCK
|
||||||
|
|
||||||
|
|||||||
@@ -160,6 +160,7 @@ class GenerationBatchResult:
|
|||||||
extend_input_len_per_req: List[int]
|
extend_input_len_per_req: List[int]
|
||||||
extend_logprob_start_len_per_req: List[int]
|
extend_logprob_start_len_per_req: List[int]
|
||||||
bid: int
|
bid: int
|
||||||
|
can_run_cuda_graph: bool
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -323,13 +324,14 @@ class Scheduler(
|
|||||||
set_random_seed(self.random_seed)
|
set_random_seed(self.random_seed)
|
||||||
|
|
||||||
# Print debug info
|
# Print debug info
|
||||||
logger.info(
|
if tp_rank == 0:
|
||||||
f"max_total_num_tokens={self.max_total_num_tokens}, "
|
logger.info(
|
||||||
f"chunked_prefill_size={server_args.chunked_prefill_size}, "
|
f"max_total_num_tokens={self.max_total_num_tokens}, "
|
||||||
f"max_prefill_tokens={self.max_prefill_tokens}, "
|
f"chunked_prefill_size={server_args.chunked_prefill_size}, "
|
||||||
f"max_running_requests={self.max_running_requests}, "
|
f"max_prefill_tokens={self.max_prefill_tokens}, "
|
||||||
f"context_len={self.model_config.context_len}"
|
f"max_running_requests={self.max_running_requests}, "
|
||||||
)
|
f"context_len={self.model_config.context_len}"
|
||||||
|
)
|
||||||
|
|
||||||
# Init memory pool and cache
|
# Init memory pool and cache
|
||||||
self.init_memory_pool_and_cache()
|
self.init_memory_pool_and_cache()
|
||||||
@@ -752,6 +754,7 @@ class Scheduler(
|
|||||||
extend_input_len_per_req=None,
|
extend_input_len_per_req=None,
|
||||||
extend_logprob_start_len_per_req=None,
|
extend_logprob_start_len_per_req=None,
|
||||||
bid=bids[next_mb_id],
|
bid=bids[next_mb_id],
|
||||||
|
can_run_cuda_graph=result.can_run_cuda_graph,
|
||||||
)
|
)
|
||||||
self.process_batch_result(mbs[next_mb_id], output_result)
|
self.process_batch_result(mbs[next_mb_id], output_result)
|
||||||
last_mbs[next_mb_id] = mbs[next_mb_id]
|
last_mbs[next_mb_id] = mbs[next_mb_id]
|
||||||
@@ -1159,7 +1162,9 @@ class Scheduler(
|
|||||||
|
|
||||||
self.metrics_collector.log_stats(self.stats)
|
self.metrics_collector.log_stats(self.stats)
|
||||||
|
|
||||||
def log_decode_stats(self, running_batch=None):
|
def log_decode_stats(
|
||||||
|
self, can_run_cuda_graph: bool, running_batch: ScheduleBatch = None
|
||||||
|
):
|
||||||
batch = running_batch or self.running_batch
|
batch = running_batch or self.running_batch
|
||||||
|
|
||||||
gap_latency = time.time() - self.last_decode_stats_tic
|
gap_latency = time.time() - self.last_decode_stats_tic
|
||||||
@@ -1199,6 +1204,7 @@ class Scheduler(
|
|||||||
msg += f"pre-allocated usage: {self.num_tokens_pre_allocated / self.max_total_num_tokens:.2f}, "
|
msg += f"pre-allocated usage: {self.num_tokens_pre_allocated / self.max_total_num_tokens:.2f}, "
|
||||||
|
|
||||||
msg += (
|
msg += (
|
||||||
|
f"cuda graph: {can_run_cuda_graph}, "
|
||||||
f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
|
f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
|
||||||
f"#queue-req: {len(self.waiting_queue)}"
|
f"#queue-req: {len(self.waiting_queue)}"
|
||||||
)
|
)
|
||||||
@@ -1524,11 +1530,11 @@ class Scheduler(
|
|||||||
if self.spec_algorithm.is_none():
|
if self.spec_algorithm.is_none():
|
||||||
model_worker_batch = batch.get_model_worker_batch()
|
model_worker_batch = batch.get_model_worker_batch()
|
||||||
if self.pp_group.is_last_rank:
|
if self.pp_group.is_last_rank:
|
||||||
logits_output, next_token_ids = (
|
logits_output, next_token_ids, can_run_cuda_graph = (
|
||||||
self.tp_worker.forward_batch_generation(model_worker_batch)
|
self.tp_worker.forward_batch_generation(model_worker_batch)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
pp_hidden_states_proxy_tensors, _ = (
|
pp_hidden_states_proxy_tensors, _, can_run_cuda_graph = (
|
||||||
self.tp_worker.forward_batch_generation(model_worker_batch)
|
self.tp_worker.forward_batch_generation(model_worker_batch)
|
||||||
)
|
)
|
||||||
bid = model_worker_batch.bid
|
bid = model_worker_batch.bid
|
||||||
@@ -1538,6 +1544,7 @@ class Scheduler(
|
|||||||
next_token_ids,
|
next_token_ids,
|
||||||
bid,
|
bid,
|
||||||
num_accepted_tokens,
|
num_accepted_tokens,
|
||||||
|
can_run_cuda_graph,
|
||||||
) = self.draft_worker.forward_batch_speculative_generation(batch)
|
) = self.draft_worker.forward_batch_speculative_generation(batch)
|
||||||
self.spec_num_total_accepted_tokens += (
|
self.spec_num_total_accepted_tokens += (
|
||||||
num_accepted_tokens + batch.batch_size()
|
num_accepted_tokens + batch.batch_size()
|
||||||
@@ -1571,6 +1578,7 @@ class Scheduler(
|
|||||||
extend_input_len_per_req=extend_input_len_per_req,
|
extend_input_len_per_req=extend_input_len_per_req,
|
||||||
extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
|
extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
|
||||||
bid=bid,
|
bid=bid,
|
||||||
|
can_run_cuda_graph=can_run_cuda_graph,
|
||||||
)
|
)
|
||||||
else: # embedding or reward model
|
else: # embedding or reward model
|
||||||
model_worker_batch = batch.get_model_worker_batch()
|
model_worker_batch = batch.get_model_worker_batch()
|
||||||
|
|||||||
@@ -38,20 +38,16 @@ class SchedulerOutputProcessorMixin:
|
|||||||
next_token_ids,
|
next_token_ids,
|
||||||
extend_input_len_per_req,
|
extend_input_len_per_req,
|
||||||
extend_logprob_start_len_per_req,
|
extend_logprob_start_len_per_req,
|
||||||
bid,
|
|
||||||
) = (
|
) = (
|
||||||
result.logits_output,
|
result.logits_output,
|
||||||
result.next_token_ids,
|
result.next_token_ids,
|
||||||
result.extend_input_len_per_req,
|
result.extend_input_len_per_req,
|
||||||
result.extend_logprob_start_len_per_req,
|
result.extend_logprob_start_len_per_req,
|
||||||
result.bid,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.enable_overlap:
|
if self.enable_overlap:
|
||||||
logits_output, next_token_ids = (
|
logits_output, next_token_ids, _ = (
|
||||||
self.tp_worker.resolve_last_batch_result(
|
self.tp_worker.resolve_last_batch_result(launch_done)
|
||||||
launch_done,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Move next_token_ids and logprobs to cpu
|
# Move next_token_ids and logprobs to cpu
|
||||||
@@ -189,16 +185,16 @@ class SchedulerOutputProcessorMixin:
|
|||||||
result: GenerationBatchResult,
|
result: GenerationBatchResult,
|
||||||
launch_done: Optional[threading.Event] = None,
|
launch_done: Optional[threading.Event] = None,
|
||||||
):
|
):
|
||||||
logits_output, next_token_ids, bid = (
|
logits_output, next_token_ids, can_run_cuda_graph = (
|
||||||
result.logits_output,
|
result.logits_output,
|
||||||
result.next_token_ids,
|
result.next_token_ids,
|
||||||
result.bid,
|
result.can_run_cuda_graph,
|
||||||
)
|
)
|
||||||
self.num_generated_tokens += len(batch.reqs)
|
self.num_generated_tokens += len(batch.reqs)
|
||||||
|
|
||||||
if self.enable_overlap:
|
if self.enable_overlap:
|
||||||
logits_output, next_token_ids = self.tp_worker.resolve_last_batch_result(
|
logits_output, next_token_ids, can_run_cuda_graph = (
|
||||||
launch_done
|
self.tp_worker.resolve_last_batch_result(launch_done)
|
||||||
)
|
)
|
||||||
next_token_logprobs = logits_output.next_token_logprobs
|
next_token_logprobs = logits_output.next_token_logprobs
|
||||||
elif batch.spec_algorithm.is_none():
|
elif batch.spec_algorithm.is_none():
|
||||||
@@ -280,7 +276,7 @@ class SchedulerOutputProcessorMixin:
|
|||||||
self.attn_tp_rank == 0
|
self.attn_tp_rank == 0
|
||||||
and self.forward_ct_decode % self.server_args.decode_log_interval == 0
|
and self.forward_ct_decode % self.server_args.decode_log_interval == 0
|
||||||
):
|
):
|
||||||
self.log_decode_stats(running_batch=batch)
|
self.log_decode_stats(can_run_cuda_graph, running_batch=batch)
|
||||||
|
|
||||||
def add_input_logprob_return_values(
|
def add_input_logprob_return_values(
|
||||||
self: Scheduler,
|
self: Scheduler,
|
||||||
|
|||||||
@@ -923,12 +923,13 @@ class TokenizerManager:
|
|||||||
):
|
):
|
||||||
await self.send_to_scheduler.send_pyobj(obj)
|
await self.send_to_scheduler.send_pyobj(obj)
|
||||||
|
|
||||||
async def get_internal_state(self) -> Dict[Any, Any]:
|
async def get_internal_state(self) -> List[Dict[Any, Any]]:
|
||||||
req = GetInternalStateReq()
|
req = GetInternalStateReq()
|
||||||
res: List[GetInternalStateReqOutput] = (
|
responses: List[GetInternalStateReqOutput] = (
|
||||||
await self.get_internal_state_communicator(req)
|
await self.get_internal_state_communicator(req)
|
||||||
)
|
)
|
||||||
return res[0].internal_state
|
# There may be many DP ranks: return the internal state of every response, not just the first
|
||||||
|
return [res.internal_state for res in responses]
|
||||||
|
|
||||||
def get_log_request_metadata(self):
|
def get_log_request_metadata(self):
|
||||||
max_length = None
|
max_length = None
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ from typing import Optional, Tuple, Union
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sglang.srt.configs.model_config import ModelConfig
|
from sglang.srt.configs.model_config import ModelConfig
|
||||||
from sglang.srt.distributed import get_pp_group, get_tp_group, get_world_group
|
from sglang.srt.distributed import get_pp_group, get_world_group
|
||||||
from sglang.srt.hf_transformers_utils import (
|
from sglang.srt.hf_transformers_utils import (
|
||||||
get_processor,
|
get_processor,
|
||||||
get_tokenizer,
|
get_tokenizer,
|
||||||
@@ -183,8 +183,11 @@ class TpModelWorker:
|
|||||||
def forward_batch_generation(
|
def forward_batch_generation(
|
||||||
self,
|
self,
|
||||||
model_worker_batch: ModelWorkerBatch,
|
model_worker_batch: ModelWorkerBatch,
|
||||||
|
launch_done: Optional[threading.Event] = None,
|
||||||
skip_sample: bool = False,
|
skip_sample: bool = False,
|
||||||
) -> Tuple[Union[LogitsProcessorOutput, torch.Tensor], Optional[torch.Tensor]]:
|
) -> Tuple[
|
||||||
|
Union[LogitsProcessorOutput, torch.Tensor], Optional[torch.Tensor], bool
|
||||||
|
]:
|
||||||
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
|
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
|
||||||
|
|
||||||
pp_proxy_tensors = None
|
pp_proxy_tensors = None
|
||||||
@@ -196,11 +199,11 @@ class TpModelWorker:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if self.pp_group.is_last_rank:
|
if self.pp_group.is_last_rank:
|
||||||
logits_output = self.model_runner.forward(
|
logits_output, can_run_cuda_graph = self.model_runner.forward(
|
||||||
forward_batch, pp_proxy_tensors=pp_proxy_tensors
|
forward_batch, pp_proxy_tensors=pp_proxy_tensors
|
||||||
)
|
)
|
||||||
if model_worker_batch.launch_done is not None:
|
if launch_done is not None:
|
||||||
model_worker_batch.launch_done.set()
|
launch_done.set()
|
||||||
|
|
||||||
if skip_sample:
|
if skip_sample:
|
||||||
next_token_ids = None
|
next_token_ids = None
|
||||||
@@ -209,17 +212,17 @@ class TpModelWorker:
|
|||||||
logits_output, model_worker_batch
|
logits_output, model_worker_batch
|
||||||
)
|
)
|
||||||
|
|
||||||
return logits_output, next_token_ids
|
return logits_output, next_token_ids, can_run_cuda_graph
|
||||||
else:
|
else:
|
||||||
pp_proxy_tensors = self.model_runner.forward(
|
pp_proxy_tensors, can_run_cuda_graph = self.model_runner.forward(
|
||||||
forward_batch,
|
forward_batch,
|
||||||
pp_proxy_tensors=pp_proxy_tensors,
|
pp_proxy_tensors=pp_proxy_tensors,
|
||||||
)
|
)
|
||||||
return pp_proxy_tensors.tensors, None
|
return pp_proxy_tensors.tensors, None, can_run_cuda_graph
|
||||||
|
|
||||||
def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
|
def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
|
||||||
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
|
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
|
||||||
logits_output = self.model_runner.forward(forward_batch)
|
logits_output, _ = self.model_runner.forward(forward_batch)
|
||||||
embeddings = logits_output.embeddings
|
embeddings = logits_output.embeddings
|
||||||
return embeddings
|
return embeddings
|
||||||
|
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ import logging
|
|||||||
import signal
|
import signal
|
||||||
import threading
|
import threading
|
||||||
from queue import Queue
|
from queue import Queue
|
||||||
from typing import Optional
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
import psutil
|
import psutil
|
||||||
import torch
|
import torch
|
||||||
@@ -145,8 +145,10 @@ class TpModelWorkerClient:
|
|||||||
resolve_future_token_ids(input_ids, self.future_token_ids_map)
|
resolve_future_token_ids(input_ids, self.future_token_ids_map)
|
||||||
|
|
||||||
# Run forward
|
# Run forward
|
||||||
logits_output, next_token_ids = self.worker.forward_batch_generation(
|
logits_output, next_token_ids, can_run_cuda_graph = (
|
||||||
model_worker_batch
|
self.worker.forward_batch_generation(
|
||||||
|
model_worker_batch, model_worker_batch.launch_done
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Update the future token ids map
|
# Update the future token ids map
|
||||||
@@ -171,14 +173,18 @@ class TpModelWorkerClient:
|
|||||||
next_token_ids = next_token_ids.to("cpu", non_blocking=True)
|
next_token_ids = next_token_ids.to("cpu", non_blocking=True)
|
||||||
copy_done.record()
|
copy_done.record()
|
||||||
|
|
||||||
self.output_queue.put((copy_done, logits_output, next_token_ids))
|
self.output_queue.put(
|
||||||
|
(copy_done, logits_output, next_token_ids, can_run_cuda_graph)
|
||||||
|
)
|
||||||
|
|
||||||
def resolve_last_batch_result(self, launch_done: Optional[threading.Event] = None):
|
def resolve_last_batch_result(self, launch_done: Optional[threading.Event] = None):
|
||||||
"""
|
"""
|
||||||
This function is called to resolve the last batch result and
|
This function is called to resolve the last batch result and
|
||||||
wait for the current batch to be launched. Used in overlap mode.
|
wait for the current batch to be launched. Used in overlap mode.
|
||||||
"""
|
"""
|
||||||
copy_done, logits_output, next_token_ids = self.output_queue.get()
|
copy_done, logits_output, next_token_ids, can_run_cuda_graph = (
|
||||||
|
self.output_queue.get()
|
||||||
|
)
|
||||||
|
|
||||||
if launch_done is not None:
|
if launch_done is not None:
|
||||||
launch_done.wait()
|
launch_done.wait()
|
||||||
@@ -193,9 +199,11 @@ class TpModelWorkerClient:
|
|||||||
logits_output.input_token_logprobs.tolist()
|
logits_output.input_token_logprobs.tolist()
|
||||||
)
|
)
|
||||||
next_token_ids = next_token_ids.tolist()
|
next_token_ids = next_token_ids.tolist()
|
||||||
return logits_output, next_token_ids
|
return logits_output, next_token_ids, can_run_cuda_graph
|
||||||
|
|
||||||
def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
|
def forward_batch_generation(
|
||||||
|
self, model_worker_batch: ModelWorkerBatch
|
||||||
|
) -> Tuple[None, torch.Tensor, bool]:
|
||||||
# Create a new copy of sampling_info because it will be updated in-place by the scheduler for the next batch.
|
# Create a new copy of sampling_info because it will be updated in-place by the scheduler for the next batch.
|
||||||
sampling_info = model_worker_batch.sampling_info
|
sampling_info = model_worker_batch.sampling_info
|
||||||
sampling_info.update_penalties()
|
sampling_info.update_penalties()
|
||||||
@@ -223,7 +231,7 @@ class TpModelWorkerClient:
|
|||||||
self.future_token_ids_ct = (
|
self.future_token_ids_ct = (
|
||||||
self.future_token_ids_ct + bs
|
self.future_token_ids_ct + bs
|
||||||
) % self.future_token_ids_limit
|
) % self.future_token_ids_limit
|
||||||
return None, future_next_token_ids
|
return None, future_next_token_ids, False
|
||||||
|
|
||||||
def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
|
def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
|
||||||
success, message = self.worker.update_weights_from_disk(recv_req)
|
success, message = self.worker.update_weights_from_disk(recv_req)
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ import bisect
|
|||||||
import inspect
|
import inspect
|
||||||
import os
|
import os
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from typing import TYPE_CHECKING, Callable
|
from typing import TYPE_CHECKING, Callable, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import tqdm
|
import tqdm
|
||||||
@@ -40,15 +40,12 @@ from sglang.srt.patch_torch import monkey_patch_torch_compile
|
|||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
get_available_gpu_memory,
|
get_available_gpu_memory,
|
||||||
get_device_memory_capacity,
|
get_device_memory_capacity,
|
||||||
is_hip,
|
|
||||||
rank0_log,
|
rank0_log,
|
||||||
)
|
)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||||
|
|
||||||
_is_hip = is_hip()
|
|
||||||
|
|
||||||
|
|
||||||
def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int):
|
def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int):
|
||||||
for sub in model._modules.values():
|
for sub in model._modules.values():
|
||||||
@@ -137,7 +134,6 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
|
|||||||
)
|
)
|
||||||
|
|
||||||
gpu_mem = get_device_memory_capacity()
|
gpu_mem = get_device_memory_capacity()
|
||||||
# Batch size of each rank will not become so large when DP is on
|
|
||||||
if gpu_mem is not None and gpu_mem > 96 * 1024:
|
if gpu_mem is not None and gpu_mem > 96 * 1024:
|
||||||
capture_bs += list(range(160, 257, 8))
|
capture_bs += list(range(160, 257, 8))
|
||||||
|
|
||||||
@@ -148,12 +144,15 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
|
|||||||
model_runner.req_to_token_pool.size
|
model_runner.req_to_token_pool.size
|
||||||
]
|
]
|
||||||
|
|
||||||
capture_bs = list(sorted(set(capture_bs)))
|
|
||||||
|
|
||||||
assert len(capture_bs) > 0 and capture_bs[0] > 0
|
|
||||||
capture_bs = [bs for bs in capture_bs if bs <= model_runner.req_to_token_pool.size]
|
|
||||||
if server_args.cuda_graph_max_bs:
|
if server_args.cuda_graph_max_bs:
|
||||||
capture_bs = [bs for bs in capture_bs if bs <= server_args.cuda_graph_max_bs]
|
capture_bs = [bs for bs in capture_bs if bs <= server_args.cuda_graph_max_bs]
|
||||||
|
if max(capture_bs) < server_args.cuda_graph_max_bs:
|
||||||
|
capture_bs += list(
|
||||||
|
range(max(capture_bs), server_args.cuda_graph_max_bs + 1, 16)
|
||||||
|
)
|
||||||
|
capture_bs = [bs for bs in capture_bs if bs <= model_runner.req_to_token_pool.size]
|
||||||
|
capture_bs = list(sorted(set(capture_bs)))
|
||||||
|
assert len(capture_bs) > 0 and capture_bs[0] > 0
|
||||||
compile_bs = (
|
compile_bs = (
|
||||||
[bs for bs in capture_bs if bs <= server_args.torch_compile_max_bs]
|
[bs for bs in capture_bs if bs <= server_args.torch_compile_max_bs]
|
||||||
if server_args.enable_torch_compile
|
if server_args.enable_torch_compile
|
||||||
|
|||||||
@@ -1085,32 +1085,33 @@ class ModelRunner:
|
|||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
skip_attn_backend_init: bool = False,
|
skip_attn_backend_init: bool = False,
|
||||||
pp_proxy_tensors: Optional[PPProxyTensors] = None,
|
pp_proxy_tensors: Optional[PPProxyTensors] = None,
|
||||||
) -> Union[LogitsProcessorOutput, PPProxyTensors]:
|
) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
|
||||||
can_run_cuda_graph = bool(
|
can_run_cuda_graph = bool(
|
||||||
forward_batch.forward_mode.is_cuda_graph()
|
forward_batch.forward_mode.is_cuda_graph()
|
||||||
and self.cuda_graph_runner
|
and self.cuda_graph_runner
|
||||||
and self.cuda_graph_runner.can_run(forward_batch)
|
and self.cuda_graph_runner.can_run(forward_batch)
|
||||||
)
|
)
|
||||||
if can_run_cuda_graph:
|
if can_run_cuda_graph:
|
||||||
return self.cuda_graph_runner.replay(
|
ret = self.cuda_graph_runner.replay(
|
||||||
forward_batch,
|
forward_batch,
|
||||||
skip_attn_backend_init=skip_attn_backend_init,
|
skip_attn_backend_init=skip_attn_backend_init,
|
||||||
pp_proxy_tensors=pp_proxy_tensors,
|
pp_proxy_tensors=pp_proxy_tensors,
|
||||||
)
|
)
|
||||||
|
elif forward_batch.forward_mode.is_decode():
|
||||||
if forward_batch.forward_mode.is_decode():
|
ret = self.forward_decode(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
|
||||||
return self.forward_decode(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
|
|
||||||
elif forward_batch.forward_mode.is_extend():
|
elif forward_batch.forward_mode.is_extend():
|
||||||
return self.forward_extend(
|
ret = self.forward_extend(
|
||||||
forward_batch,
|
forward_batch,
|
||||||
skip_attn_backend_init=skip_attn_backend_init,
|
skip_attn_backend_init=skip_attn_backend_init,
|
||||||
pp_proxy_tensors=pp_proxy_tensors,
|
pp_proxy_tensors=pp_proxy_tensors,
|
||||||
)
|
)
|
||||||
elif forward_batch.forward_mode.is_idle():
|
elif forward_batch.forward_mode.is_idle():
|
||||||
return self.forward_idle(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
|
ret = self.forward_idle(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")
|
raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")
|
||||||
|
|
||||||
|
return ret, can_run_cuda_graph
|
||||||
|
|
||||||
def _preprocess_logits(
|
def _preprocess_logits(
|
||||||
self, logits_output: LogitsProcessorOutput, sampling_info: SamplingBatchInfo
|
self, logits_output: LogitsProcessorOutput, sampling_info: SamplingBatchInfo
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -1086,7 +1086,7 @@ class ServerArgs:
|
|||||||
"--cuda-graph-max-bs",
|
"--cuda-graph-max-bs",
|
||||||
type=int,
|
type=int,
|
||||||
default=ServerArgs.cuda_graph_max_bs,
|
default=ServerArgs.cuda_graph_max_bs,
|
||||||
help="Set the maximum batch size for cuda graph.",
|
help="Set the maximum batch size for cuda graph. It will extend the cuda graph capture batch size to this value.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--cuda-graph-bs",
|
"--cuda-graph-bs",
|
||||||
|
|||||||
@@ -251,8 +251,8 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
if batch.forward_mode.is_decode():
|
if batch.forward_mode.is_decode():
|
||||||
with self.draft_tp_context(self.draft_model_runner.tp_group):
|
with self.draft_tp_context(self.draft_model_runner.tp_group):
|
||||||
spec_info = self.draft(batch)
|
spec_info = self.draft(batch)
|
||||||
logits_output, verify_output, model_worker_batch = self.verify(
|
logits_output, verify_output, model_worker_batch, can_run_cuda_graph = (
|
||||||
batch, spec_info
|
self.verify(batch, spec_info)
|
||||||
)
|
)
|
||||||
|
|
||||||
# If it is None, it means all requests are finished
|
# If it is None, it means all requests are finished
|
||||||
@@ -264,21 +264,22 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
verify_output.verified_id,
|
verify_output.verified_id,
|
||||||
model_worker_batch.bid,
|
model_worker_batch.bid,
|
||||||
sum(verify_output.accept_length_per_req_cpu),
|
sum(verify_output.accept_length_per_req_cpu),
|
||||||
|
can_run_cuda_graph,
|
||||||
)
|
)
|
||||||
elif batch.forward_mode.is_idle():
|
elif batch.forward_mode.is_idle():
|
||||||
model_worker_batch = batch.get_model_worker_batch()
|
model_worker_batch = batch.get_model_worker_batch()
|
||||||
logits_output, next_token_ids = self.target_worker.forward_batch_generation(
|
logits_output, next_token_ids, _ = (
|
||||||
model_worker_batch
|
self.target_worker.forward_batch_generation(model_worker_batch)
|
||||||
)
|
)
|
||||||
|
|
||||||
return logits_output, next_token_ids, model_worker_batch.bid, 0
|
return logits_output, next_token_ids, model_worker_batch.bid, 0, False
|
||||||
else:
|
else:
|
||||||
logits_output, next_token_ids, bid = self.forward_target_extend(batch)
|
logits_output, next_token_ids, bid = self.forward_target_extend(batch)
|
||||||
with self.draft_tp_context(self.draft_model_runner.tp_group):
|
with self.draft_tp_context(self.draft_model_runner.tp_group):
|
||||||
self.forward_draft_extend(
|
self.forward_draft_extend(
|
||||||
batch, logits_output.hidden_states, next_token_ids
|
batch, logits_output.hidden_states, next_token_ids
|
||||||
)
|
)
|
||||||
return logits_output, next_token_ids, bid, 0
|
return logits_output, next_token_ids, bid, 0, False
|
||||||
|
|
||||||
def forward_target_extend(
|
def forward_target_extend(
|
||||||
self, batch: ScheduleBatch
|
self, batch: ScheduleBatch
|
||||||
@@ -297,7 +298,7 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
# We need the full hidden states to prefill the KV cache of the draft model.
|
# We need the full hidden states to prefill the KV cache of the draft model.
|
||||||
model_worker_batch = batch.get_model_worker_batch()
|
model_worker_batch = batch.get_model_worker_batch()
|
||||||
model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
|
model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
|
||||||
logits_output, next_token_ids = self.target_worker.forward_batch_generation(
|
logits_output, next_token_ids, _ = self.target_worker.forward_batch_generation(
|
||||||
model_worker_batch
|
model_worker_batch
|
||||||
)
|
)
|
||||||
return logits_output, next_token_ids, model_worker_batch.bid
|
return logits_output, next_token_ids, model_worker_batch.bid
|
||||||
@@ -478,8 +479,10 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
batch.forward_mode = ForwardMode.TARGET_VERIFY
|
batch.forward_mode = ForwardMode.TARGET_VERIFY
|
||||||
batch.spec_info = spec_info
|
batch.spec_info = spec_info
|
||||||
model_worker_batch = batch.get_model_worker_batch()
|
model_worker_batch = batch.get_model_worker_batch()
|
||||||
logits_output, _ = self.target_worker.forward_batch_generation(
|
logits_output, _, can_run_cuda_graph = (
|
||||||
model_worker_batch, skip_sample=True
|
self.target_worker.forward_batch_generation(
|
||||||
|
model_worker_batch, skip_sample=True
|
||||||
|
)
|
||||||
)
|
)
|
||||||
self._detect_nan_if_needed(logits_output)
|
self._detect_nan_if_needed(logits_output)
|
||||||
spec_info.hidden_states = logits_output.hidden_states
|
spec_info.hidden_states = logits_output.hidden_states
|
||||||
@@ -504,7 +507,7 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
if batch.return_logprob:
|
if batch.return_logprob:
|
||||||
self.add_logprob_values(batch, res, logits_output)
|
self.add_logprob_values(batch, res, logits_output)
|
||||||
|
|
||||||
return logits_output, res, model_worker_batch
|
return logits_output, res, model_worker_batch, can_run_cuda_graph
|
||||||
|
|
||||||
def add_logprob_values(
|
def add_logprob_values(
|
||||||
self,
|
self,
|
||||||
@@ -590,7 +593,7 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
model_worker_batch, self.draft_model_runner
|
model_worker_batch, self.draft_model_runner
|
||||||
)
|
)
|
||||||
forward_batch.return_logprob = False
|
forward_batch.return_logprob = False
|
||||||
logits_output = self.draft_model_runner.forward(forward_batch)
|
logits_output, _ = self.draft_model_runner.forward(forward_batch)
|
||||||
self._detect_nan_if_needed(logits_output)
|
self._detect_nan_if_needed(logits_output)
|
||||||
assert isinstance(forward_batch.spec_info, EagleDraftInput)
|
assert isinstance(forward_batch.spec_info, EagleDraftInput)
|
||||||
assert forward_batch.spec_info is batch.spec_info
|
assert forward_batch.spec_info is batch.spec_info
|
||||||
@@ -617,7 +620,7 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Run
|
# Run
|
||||||
logits_output = self.draft_model_runner.forward(forward_batch)
|
logits_output, _ = self.draft_model_runner.forward(forward_batch)
|
||||||
|
|
||||||
self._detect_nan_if_needed(logits_output)
|
self._detect_nan_if_needed(logits_output)
|
||||||
self.capture_for_decode(logits_output, forward_batch.spec_info)
|
self.capture_for_decode(logits_output, forward_batch.spec_info)
|
||||||
|
|||||||
@@ -395,12 +395,12 @@ def popen_launch_server(
|
|||||||
other_args: list[str] = (),
|
other_args: list[str] = (),
|
||||||
env: Optional[dict] = None,
|
env: Optional[dict] = None,
|
||||||
return_stdout_stderr: Optional[tuple] = None,
|
return_stdout_stderr: Optional[tuple] = None,
|
||||||
pd_seperated: bool = False,
|
pd_separated: bool = False,
|
||||||
):
|
):
|
||||||
_, host, port = base_url.split(":")
|
_, host, port = base_url.split(":")
|
||||||
host = host[2:]
|
host = host[2:]
|
||||||
|
|
||||||
if pd_seperated:
|
if pd_separated:
|
||||||
command = "sglang.launch_pd_server"
|
command = "sglang.launch_pd_server"
|
||||||
else:
|
else:
|
||||||
command = "sglang.launch_server"
|
command = "sglang.launch_server"
|
||||||
@@ -414,7 +414,7 @@ def popen_launch_server(
|
|||||||
*[str(x) for x in other_args],
|
*[str(x) for x in other_args],
|
||||||
]
|
]
|
||||||
|
|
||||||
if pd_seperated:
|
if pd_separated:
|
||||||
command.extend(
|
command.extend(
|
||||||
[
|
[
|
||||||
"--lb-host",
|
"--lb-host",
|
||||||
@@ -656,7 +656,7 @@ def get_benchmark_args(
|
|||||||
disable_stream=False,
|
disable_stream=False,
|
||||||
disable_ignore_eos=False,
|
disable_ignore_eos=False,
|
||||||
seed: int = 0,
|
seed: int = 0,
|
||||||
pd_seperated: bool = False,
|
pd_separated: bool = False,
|
||||||
):
|
):
|
||||||
return SimpleNamespace(
|
return SimpleNamespace(
|
||||||
backend="sglang",
|
backend="sglang",
|
||||||
@@ -686,7 +686,7 @@ def get_benchmark_args(
|
|||||||
profile=None,
|
profile=None,
|
||||||
lora_name=None,
|
lora_name=None,
|
||||||
prompt_suffix="",
|
prompt_suffix="",
|
||||||
pd_seperated=pd_seperated,
|
pd_separated=pd_separated,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -750,7 +750,7 @@ def run_bench_serving_multi(
|
|||||||
other_server_args,
|
other_server_args,
|
||||||
benchmark_args,
|
benchmark_args,
|
||||||
need_warmup=False,
|
need_warmup=False,
|
||||||
pd_seperated=False,
|
pd_separated=False,
|
||||||
):
|
):
|
||||||
# Launch the server
|
# Launch the server
|
||||||
process = popen_launch_server(
|
process = popen_launch_server(
|
||||||
@@ -758,7 +758,7 @@ def run_bench_serving_multi(
|
|||||||
base_url,
|
base_url,
|
||||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
other_args=other_server_args,
|
other_args=other_server_args,
|
||||||
pd_seperated=pd_seperated,
|
pd_separated=pd_separated,
|
||||||
)
|
)
|
||||||
|
|
||||||
# run benchmark for all
|
# run benchmark for all
|
||||||
|
|||||||
@@ -101,8 +101,8 @@ suites = {
|
|||||||
# TestFile("test_deepep_intranode.py", 50),
|
# TestFile("test_deepep_intranode.py", 50),
|
||||||
# TestFile("test_deepep_low_latency.py", 50),
|
# TestFile("test_deepep_low_latency.py", 50),
|
||||||
# TestFile("test_moe_deepep_eval_accuracy_large.py", 250),
|
# TestFile("test_moe_deepep_eval_accuracy_large.py", 250),
|
||||||
|
# TestFile("test_disaggregation.py", 90),
|
||||||
TestFile("test_local_attn.py", 250),
|
TestFile("test_local_attn.py", 250),
|
||||||
TestFile("test_disaggregation.py", 90),
|
|
||||||
TestFile("test_full_deepseek_v3.py", 250),
|
TestFile("test_full_deepseek_v3.py", 250),
|
||||||
TestFile("test_pp_single_node.py", 150),
|
TestFile("test_pp_single_node.py", 150),
|
||||||
],
|
],
|
||||||
|
|||||||
@@ -97,7 +97,9 @@ class TestEAGLEEngine(CustomTestCase):
|
|||||||
|
|
||||||
print(f"{engine.get_server_info()=}")
|
print(f"{engine.get_server_info()=}")
|
||||||
|
|
||||||
avg_spec_accept_length = engine.get_server_info()["avg_spec_accept_length"]
|
avg_spec_accept_length = engine.get_server_info()["internal_states"][0][
|
||||||
|
"avg_spec_accept_length"
|
||||||
|
]
|
||||||
print(f"{avg_spec_accept_length=}")
|
print(f"{avg_spec_accept_length=}")
|
||||||
self.assertGreater(avg_spec_accept_length, 1.9)
|
self.assertGreater(avg_spec_accept_length, 1.9)
|
||||||
|
|
||||||
@@ -296,7 +298,9 @@ class TestEAGLEServer(CustomTestCase):
|
|||||||
self.assertGreater(metrics["accuracy"], 0.20)
|
self.assertGreater(metrics["accuracy"], 0.20)
|
||||||
|
|
||||||
server_info = requests.get(self.base_url + "/get_server_info").json()
|
server_info = requests.get(self.base_url + "/get_server_info").json()
|
||||||
avg_spec_accept_length = server_info["avg_spec_accept_length"]
|
avg_spec_accept_length = server_info["internal_states"][0][
|
||||||
|
"avg_spec_accept_length"
|
||||||
|
]
|
||||||
print(f"{avg_spec_accept_length=}")
|
print(f"{avg_spec_accept_length=}")
|
||||||
|
|
||||||
speculative_eagle_topk = server_info["speculative_eagle_topk"]
|
speculative_eagle_topk = server_info["speculative_eagle_topk"]
|
||||||
|
|||||||
@@ -111,7 +111,9 @@ class BaseFlashAttentionTest(CustomTestCase):
|
|||||||
|
|
||||||
if self.speculative_decode:
|
if self.speculative_decode:
|
||||||
server_info = requests.get(self.base_url + "/get_server_info")
|
server_info = requests.get(self.base_url + "/get_server_info")
|
||||||
avg_spec_accept_length = server_info.json()["avg_spec_accept_length"]
|
avg_spec_accept_length = server_info.json()["internal_states"][0][
|
||||||
|
"avg_spec_accept_length"
|
||||||
|
]
|
||||||
print(f"{avg_spec_accept_length=}")
|
print(f"{avg_spec_accept_length=}")
|
||||||
self.assertGreater(avg_spec_accept_length, self.spec_decode_threshold)
|
self.assertGreater(avg_spec_accept_length, self.spec_decode_threshold)
|
||||||
|
|
||||||
|
|||||||
@@ -118,7 +118,9 @@ class TestDeepseekV3MTP(CustomTestCase):
|
|||||||
print(f"{metrics=}")
|
print(f"{metrics=}")
|
||||||
|
|
||||||
server_info = requests.get(self.base_url + "/get_server_info")
|
server_info = requests.get(self.base_url + "/get_server_info")
|
||||||
avg_spec_accept_length = server_info.json()["avg_spec_accept_length"]
|
avg_spec_accept_length = server_info.json()["internal_states"][0][
|
||||||
|
"avg_spec_accept_length"
|
||||||
|
]
|
||||||
print(f"{avg_spec_accept_length=}")
|
print(f"{avg_spec_accept_length=}")
|
||||||
|
|
||||||
if is_in_ci():
|
if is_in_ci():
|
||||||
|
|||||||
@@ -100,7 +100,9 @@ class TestDeepseekV3MTP(CustomTestCase):
|
|||||||
self.assertGreater(metrics["accuracy"], 0.60)
|
self.assertGreater(metrics["accuracy"], 0.60)
|
||||||
|
|
||||||
server_info = requests.get(self.base_url + "/get_server_info")
|
server_info = requests.get(self.base_url + "/get_server_info")
|
||||||
avg_spec_accept_length = server_info.json()["avg_spec_accept_length"]
|
avg_spec_accept_length = server_info.json()["internal_states"][0][
|
||||||
|
"avg_spec_accept_length"
|
||||||
|
]
|
||||||
print(f"{avg_spec_accept_length=}")
|
print(f"{avg_spec_accept_length=}")
|
||||||
self.assertGreater(avg_spec_accept_length, 2.5)
|
self.assertGreater(avg_spec_accept_length, 2.5)
|
||||||
|
|
||||||
@@ -159,7 +161,9 @@ class TestDeepseekV3MTPWithDraft(CustomTestCase):
|
|||||||
self.assertGreater(metrics["accuracy"], 0.60)
|
self.assertGreater(metrics["accuracy"], 0.60)
|
||||||
|
|
||||||
server_info = requests.get(self.base_url + "/get_server_info")
|
server_info = requests.get(self.base_url + "/get_server_info")
|
||||||
avg_spec_accept_length = server_info.json()["avg_spec_accept_length"]
|
avg_spec_accept_length = server_info.json()["internal_states"][0][
|
||||||
|
"avg_spec_accept_length"
|
||||||
|
]
|
||||||
print(f"{avg_spec_accept_length=}")
|
print(f"{avg_spec_accept_length=}")
|
||||||
self.assertGreater(avg_spec_accept_length, 2.5)
|
self.assertGreater(avg_spec_accept_length, 2.5)
|
||||||
|
|
||||||
|
|||||||
@@ -158,7 +158,9 @@ class TestFlashinferMLAMTP(CustomTestCase):
|
|||||||
|
|
||||||
server_info = requests.get(self.base_url + "/get_server_info")
|
server_info = requests.get(self.base_url + "/get_server_info")
|
||||||
print(f"{server_info=}")
|
print(f"{server_info=}")
|
||||||
avg_spec_accept_length = server_info.json()["avg_spec_accept_length"]
|
avg_spec_accept_length = server_info.json()["internal_states"][0][
|
||||||
|
"avg_spec_accept_length"
|
||||||
|
]
|
||||||
print(f"{avg_spec_accept_length=}")
|
print(f"{avg_spec_accept_length=}")
|
||||||
self.assertGreater(avg_spec_accept_length, 2.5)
|
self.assertGreater(avg_spec_accept_length, 2.5)
|
||||||
|
|
||||||
|
|||||||
@@ -105,7 +105,9 @@ class TestDeepseekV3MTPChannelInt8(CustomTestCase):
|
|||||||
self.assertGreater(metrics["accuracy"], 0.60)
|
self.assertGreater(metrics["accuracy"], 0.60)
|
||||||
|
|
||||||
server_info = requests.get(self.base_url + "/get_server_info")
|
server_info = requests.get(self.base_url + "/get_server_info")
|
||||||
avg_spec_accept_length = server_info.json()["avg_spec_accept_length"]
|
avg_spec_accept_length = server_info.json()["internal_states"][0][
|
||||||
|
"avg_spec_accept_length"
|
||||||
|
]
|
||||||
print(f"{avg_spec_accept_length=}")
|
print(f"{avg_spec_accept_length=}")
|
||||||
self.assertGreater(avg_spec_accept_length, 2.5)
|
self.assertGreater(avg_spec_accept_length, 2.5)
|
||||||
|
|
||||||
@@ -199,7 +201,9 @@ class TestDeepseekV3MTPBlockInt8(CustomTestCase):
|
|||||||
self.assertGreater(metrics["accuracy"], 0.60)
|
self.assertGreater(metrics["accuracy"], 0.60)
|
||||||
|
|
||||||
server_info = requests.get(self.base_url + "/get_server_info")
|
server_info = requests.get(self.base_url + "/get_server_info")
|
||||||
avg_spec_accept_length = server_info.json()["avg_spec_accept_length"]
|
avg_spec_accept_length = server_info.json()["internal_states"][0][
|
||||||
|
"avg_spec_accept_length"
|
||||||
|
]
|
||||||
print(f"{avg_spec_accept_length=}")
|
print(f"{avg_spec_accept_length=}")
|
||||||
self.assertGreater(avg_spec_accept_length, 2.5)
|
self.assertGreater(avg_spec_accept_length, 2.5)
|
||||||
|
|
||||||
|
|||||||
@@ -492,9 +492,6 @@ class TestSRTEndpoint(CustomTestCase):
|
|||||||
max_total_num_tokens = response_json["max_total_num_tokens"]
|
max_total_num_tokens = response_json["max_total_num_tokens"]
|
||||||
self.assertIsInstance(max_total_num_tokens, int)
|
self.assertIsInstance(max_total_num_tokens, int)
|
||||||
|
|
||||||
attention_backend = response_json["attention_backend"]
|
|
||||||
self.assertIsInstance(attention_backend, str)
|
|
||||||
|
|
||||||
version = response_json["version"]
|
version = response_json["version"]
|
||||||
self.assertIsInstance(version, str)
|
self.assertIsInstance(version, str)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user