Log whether CUDA graph is used & extend CUDA graph capture to cuda-graph-max-bs (#6201)

Author: Lianmin Zheng
Co-authored-by: SangBin Cho <rkooo567@gmail.com>
Date: 2025-05-12 00:17:33 -07:00
Committed by: GitHub
Parent: 7d3a3d4510
Commit: fba8eccd7e
27 changed files with 293 additions and 121 deletions

View File

@@ -160,6 +160,7 @@ class GenerationBatchResult:
extend_input_len_per_req: List[int]
extend_logprob_start_len_per_req: List[int]
bid: int
can_run_cuda_graph: bool
@dataclass
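For orientation, GenerationBatchResult is the object that carries a batch's outputs from the worker back to the scheduler, and it now also carries the CUDA-graph flag. A minimal sketch of the dataclass; only the fields visible in this diff are listed, and the types of the first two are stand-ins for the real LogitsProcessorOutput and tensor types:

from dataclasses import dataclass
from typing import Any, List, Optional

import torch

@dataclass
class GenerationBatchResult:
    logits_output: Any                          # LogitsProcessorOutput in the real code
    next_token_ids: Optional[torch.Tensor]
    extend_input_len_per_req: List[int]
    extend_logprob_start_len_per_req: List[int]
    bid: int
    # New: whether this batch was executed by replaying a captured CUDA graph.
    can_run_cuda_graph: bool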
@@ -323,13 +324,14 @@ class Scheduler(
set_random_seed(self.random_seed)
# Print debug info
logger.info(
f"max_total_num_tokens={self.max_total_num_tokens}, "
f"chunked_prefill_size={server_args.chunked_prefill_size}, "
f"max_prefill_tokens={self.max_prefill_tokens}, "
f"max_running_requests={self.max_running_requests}, "
f"context_len={self.model_config.context_len}"
)
if tp_rank == 0:
logger.info(
f"max_total_num_tokens={self.max_total_num_tokens}, "
f"chunked_prefill_size={server_args.chunked_prefill_size}, "
f"max_prefill_tokens={self.max_prefill_tokens}, "
f"max_running_requests={self.max_running_requests}, "
f"context_len={self.model_config.context_len}"
)
# Init memory pool and cache
self.init_memory_pool_and_cache()
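The startup summary above is now printed only on TP rank 0, so a multi-GPU launch logs the configuration once instead of once per rank. A minimal, stand-alone sketch of that gating pattern (rank and values are stand-ins):

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

tp_rank = 0                     # stand-in: this worker's tensor-parallel rank
max_total_num_tokens = 100_000  # stand-in value

# Only the first TP rank logs the configuration summary.
if tp_rank == 0:
    logger.info(f"max_total_num_tokens={max_total_num_tokens}, ...")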
@@ -752,6 +754,7 @@ class Scheduler(
extend_input_len_per_req=None,
extend_logprob_start_len_per_req=None,
bid=bids[next_mb_id],
can_run_cuda_graph=result.can_run_cuda_graph,
)
self.process_batch_result(mbs[next_mb_id], output_result)
last_mbs[next_mb_id] = mbs[next_mb_id]
@@ -1159,7 +1162,9 @@ class Scheduler(
self.metrics_collector.log_stats(self.stats)
def log_decode_stats(self, running_batch=None):
def log_decode_stats(
self, can_run_cuda_graph: bool, running_batch: ScheduleBatch = None
):
batch = running_batch or self.running_batch
gap_latency = time.time() - self.last_decode_stats_tic
@@ -1199,6 +1204,7 @@ class Scheduler(
msg += f"pre-allocated usage: {self.num_tokens_pre_allocated / self.max_total_num_tokens:.2f}, "
msg += (
f"cuda graph: {can_run_cuda_graph}, "
f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
f"#queue-req: {len(self.waiting_queue)}"
)
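To show how the new flag surfaces in the periodic decode log, here is a stand-alone reconstruction of the message built above (values are stand-ins and the preceding usage fields are omitted):

can_run_cuda_graph = True      # stand-in: whether the last decode batch replayed a CUDA graph
last_gen_throughput = 1234.56  # stand-in for self.last_gen_throughput
waiting_queue = []             # stand-in for self.waiting_queue

msg = (
    f"cuda graph: {can_run_cuda_graph}, "
    f"gen throughput (token/s): {last_gen_throughput:.2f}, "
    f"#queue-req: {len(waiting_queue)}"
)
print(msg)  # -> cuda graph: True, gen throughput (token/s): 1234.56, #queue-req: 0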
@@ -1524,11 +1530,11 @@ class Scheduler(
if self.spec_algorithm.is_none():
model_worker_batch = batch.get_model_worker_batch()
if self.pp_group.is_last_rank:
logits_output, next_token_ids = (
logits_output, next_token_ids, can_run_cuda_graph = (
self.tp_worker.forward_batch_generation(model_worker_batch)
)
else:
pp_hidden_states_proxy_tensors, _ = (
pp_hidden_states_proxy_tensors, _, can_run_cuda_graph = (
self.tp_worker.forward_batch_generation(model_worker_batch)
)
bid = model_worker_batch.bid
@@ -1538,6 +1544,7 @@ class Scheduler(
next_token_ids,
bid,
num_accepted_tokens,
can_run_cuda_graph,
) = self.draft_worker.forward_batch_speculative_generation(batch)
self.spec_num_total_accepted_tokens += (
num_accepted_tokens + batch.batch_size()
@@ -1571,6 +1578,7 @@ class Scheduler(
extend_input_len_per_req=extend_input_len_per_req,
extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
bid=bid,
can_run_cuda_graph=can_run_cuda_graph,
)
else: # embedding or reward model
model_worker_batch = batch.get_model_worker_batch()
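Both pipeline-parallel branches and the speculative-decoding branch now receive the flag before it is stored on GenerationBatchResult. A toy, self-contained sketch of that branching; the workers here are stubs for self.tp_worker and self.draft_worker, and the leading return values of the speculative path are abbreviated:

def fake_tp_forward(is_last_pp_rank: bool):
    # Stand-in for tp_worker.forward_batch_generation.
    can_run_cuda_graph = True
    if is_last_pp_rank:
        return object(), [7, 8, 9], can_run_cuda_graph        # logits, token ids, flag
    return {"hidden_states": None}, None, can_run_cuda_graph  # PP proxy tensors, _, flag

def fake_spec_forward():
    # Stand-in for draft_worker.forward_batch_speculative_generation; the flag
    # is appended after num_accepted_tokens (earlier return values assumed).
    logits_output, next_token_ids, bid, num_accepted_tokens = object(), [7], 0, 3
    return logits_output, next_token_ids, bid, num_accepted_tokens, True

logits_output, next_token_ids, can_run_cuda_graph = fake_tp_forward(is_last_pp_rank=True)
pp_proxy_tensors, _, can_run_cuda_graph = fake_tp_forward(is_last_pp_rank=False)
*_, can_run_cuda_graph = fake_spec_forward()
print(can_run_cuda_graph)  # -> True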

View File

@@ -38,20 +38,16 @@ class SchedulerOutputProcessorMixin:
next_token_ids,
extend_input_len_per_req,
extend_logprob_start_len_per_req,
bid,
) = (
result.logits_output,
result.next_token_ids,
result.extend_input_len_per_req,
result.extend_logprob_start_len_per_req,
result.bid,
)
if self.enable_overlap:
logits_output, next_token_ids = (
self.tp_worker.resolve_last_batch_result(
launch_done,
)
logits_output, next_token_ids, _ = (
self.tp_worker.resolve_last_batch_result(launch_done)
)
else:
# Move next_token_ids and logprobs to cpu
@@ -189,16 +185,16 @@ class SchedulerOutputProcessorMixin:
result: GenerationBatchResult,
launch_done: Optional[threading.Event] = None,
):
logits_output, next_token_ids, bid = (
logits_output, next_token_ids, can_run_cuda_graph = (
result.logits_output,
result.next_token_ids,
result.bid,
result.can_run_cuda_graph,
)
self.num_generated_tokens += len(batch.reqs)
if self.enable_overlap:
logits_output, next_token_ids = self.tp_worker.resolve_last_batch_result(
launch_done
logits_output, next_token_ids, can_run_cuda_graph = (
self.tp_worker.resolve_last_batch_result(launch_done)
)
next_token_logprobs = logits_output.next_token_logprobs
elif batch.spec_algorithm.is_none():
@@ -280,7 +276,7 @@ class SchedulerOutputProcessorMixin:
self.attn_tp_rank == 0
and self.forward_ct_decode % self.server_args.decode_log_interval == 0
):
self.log_decode_stats(running_batch=batch)
self.log_decode_stats(can_run_cuda_graph, running_batch=batch)
def add_input_logprob_return_values(
self: Scheduler,

View File

@@ -923,12 +923,13 @@ class TokenizerManager:
):
await self.send_to_scheduler.send_pyobj(obj)
async def get_internal_state(self) -> Dict[Any, Any]:
async def get_internal_state(self) -> List[Dict[Any, Any]]:
req = GetInternalStateReq()
res: List[GetInternalStateReqOutput] = (
responses: List[GetInternalStateReqOutput] = (
await self.get_internal_state_communicator(req)
)
return res[0].internal_state
# Many DP ranks
return [res.internal_state for res in responses]
def get_log_request_metadata(self):
max_length = None
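Because data parallelism can put several scheduler ranks behind the tokenizer manager, get_internal_state now returns one state dict per rank instead of only the first. A self-contained sketch of the new shape; the response class and dict contents are made up for illustration:

from typing import Any, Dict, List

class _FakeResponse:
    # Stand-in for GetInternalStateReqOutput: one object per DP rank,
    # each carrying an internal_state dict.
    def __init__(self, internal_state: Dict[Any, Any]):
        self.internal_state = internal_state

responses: List[_FakeResponse] = [
    _FakeResponse({"dp_rank": 0}),
    _FakeResponse({"dp_rank": 1}),
]

# Old behavior: responses[0].internal_state (a single dict).
# New behavior: a list with one dict per DP rank.
internal_states: List[Dict[Any, Any]] = [res.internal_state for res in responses]
print(internal_states)  # -> [{'dp_rank': 0}, {'dp_rank': 1}]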

View File

@@ -20,7 +20,7 @@ from typing import Optional, Tuple, Union
import torch
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.distributed import get_pp_group, get_tp_group, get_world_group
from sglang.srt.distributed import get_pp_group, get_world_group
from sglang.srt.hf_transformers_utils import (
get_processor,
get_tokenizer,
@@ -183,8 +183,11 @@ class TpModelWorker:
def forward_batch_generation(
self,
model_worker_batch: ModelWorkerBatch,
launch_done: Optional[threading.Event] = None,
skip_sample: bool = False,
) -> Tuple[Union[LogitsProcessorOutput, torch.Tensor], Optional[torch.Tensor]]:
) -> Tuple[
Union[LogitsProcessorOutput, torch.Tensor], Optional[torch.Tensor], bool
]:
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
pp_proxy_tensors = None
@@ -196,11 +199,11 @@ class TpModelWorker:
)
if self.pp_group.is_last_rank:
logits_output = self.model_runner.forward(
logits_output, can_run_cuda_graph = self.model_runner.forward(
forward_batch, pp_proxy_tensors=pp_proxy_tensors
)
if model_worker_batch.launch_done is not None:
model_worker_batch.launch_done.set()
if launch_done is not None:
launch_done.set()
if skip_sample:
next_token_ids = None
@@ -209,17 +212,17 @@ class TpModelWorker:
logits_output, model_worker_batch
)
return logits_output, next_token_ids
return logits_output, next_token_ids, can_run_cuda_graph
else:
pp_proxy_tensors = self.model_runner.forward(
pp_proxy_tensors, can_run_cuda_graph = self.model_runner.forward(
forward_batch,
pp_proxy_tensors=pp_proxy_tensors,
)
return pp_proxy_tensors.tensors, None
return pp_proxy_tensors.tensors, None, can_run_cuda_graph
def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
logits_output = self.model_runner.forward(forward_batch)
logits_output, _ = self.model_runner.forward(forward_batch)
embeddings = logits_output.embeddings
return embeddings
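Two worker-level contracts change in this file: ModelRunner.forward now returns an (output, can_run_cuda_graph) pair, and TpModelWorker.forward_batch_generation forwards that flag as a third return value, while the embedding path simply discards it. A hedged, self-contained sketch of the new shape, with stub functions standing in for the real runner and worker:

from typing import Any, Tuple

def fake_model_runner_forward(forward_batch: Any) -> Tuple[Any, bool]:
    # Stand-in for ModelRunner.forward: besides the logits (or, on earlier
    # pipeline ranks, the PP proxy tensors), it reports whether a captured
    # CUDA graph was replayed for this batch.
    logits_output = object()
    can_run_cuda_graph = True
    return logits_output, can_run_cuda_graph

def fake_forward_batch_generation(forward_batch: Any, skip_sample: bool = False):
    logits_output, can_run_cuda_graph = fake_model_runner_forward(forward_batch)
    next_token_ids = None if skip_sample else [3, 1, 4]  # stand-in for sampled ids
    # Callers (scheduler and overlap worker) now unpack three values.
    return logits_output, next_token_ids, can_run_cuda_graph

_, _, flag = fake_forward_batch_generation(forward_batch=None)
print(flag)  # -> True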

View File

@@ -18,7 +18,7 @@ import logging
import signal
import threading
from queue import Queue
from typing import Optional
from typing import Optional, Tuple
import psutil
import torch
@@ -145,8 +145,10 @@ class TpModelWorkerClient:
resolve_future_token_ids(input_ids, self.future_token_ids_map)
# Run forward
logits_output, next_token_ids = self.worker.forward_batch_generation(
model_worker_batch
logits_output, next_token_ids, can_run_cuda_graph = (
self.worker.forward_batch_generation(
model_worker_batch, model_worker_batch.launch_done
)
)
# Update the future token ids map
@@ -171,14 +173,18 @@ class TpModelWorkerClient:
next_token_ids = next_token_ids.to("cpu", non_blocking=True)
copy_done.record()
self.output_queue.put((copy_done, logits_output, next_token_ids))
self.output_queue.put(
(copy_done, logits_output, next_token_ids, can_run_cuda_graph)
)
def resolve_last_batch_result(self, launch_done: Optional[threading.Event] = None):
"""
This function is called to resolve the last batch result and
wait for the current batch to be launched. Used in overlap mode.
"""
copy_done, logits_output, next_token_ids = self.output_queue.get()
copy_done, logits_output, next_token_ids, can_run_cuda_graph = (
self.output_queue.get()
)
if launch_done is not None:
launch_done.wait()
@@ -193,9 +199,11 @@ class TpModelWorkerClient:
logits_output.input_token_logprobs.tolist()
)
next_token_ids = next_token_ids.tolist()
return logits_output, next_token_ids
return logits_output, next_token_ids, can_run_cuda_graph
def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
def forward_batch_generation(
self, model_worker_batch: ModelWorkerBatch
) -> Tuple[None, torch.Tensor, bool]:
# Create a new copy of sampling_info because it will be updated in-place by the scheduler for the next batch.
sampling_info = model_worker_batch.sampling_info
sampling_info.update_penalties()
@@ -223,7 +231,7 @@ class TpModelWorkerClient:
self.future_token_ids_ct = (
self.future_token_ids_ct + bs
) % self.future_token_ids_limit
return None, future_next_token_ids
return None, future_next_token_ids, False
def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
success, message = self.worker.update_weights_from_disk(recv_req)
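In overlap mode the flag travels through the worker's output queue and comes back from resolve_last_batch_result, while the non-blocking forward_batch_generation reports False because the real answer is only known after the background thread has run the batch. A simplified, self-contained sketch of that flow; the CUDA copy event is replaced by a threading.Event and all values are stand-ins:

import queue
import threading

output_queue = queue.Queue()

def forward_thread(batch):
    # Background thread: run the batch, then publish the result, now including
    # whether it replayed a captured CUDA graph.
    logits_output, next_token_ids, can_run_cuda_graph = object(), [11, 12], True
    copy_done = threading.Event()  # the real code records a CUDA event here
    copy_done.set()
    output_queue.put((copy_done, logits_output, next_token_ids, can_run_cuda_graph))

def resolve_last_batch_result():
    copy_done, logits_output, next_token_ids, can_run_cuda_graph = output_queue.get()
    copy_done.wait()               # the real code synchronizes on the copy event
    return logits_output, next_token_ids, can_run_cuda_graph

def forward_batch_generation(batch):
    # Non-blocking entry point: the flag is unknown at this point, so the
    # placeholder result reports False.
    future_next_token_ids = [-1, -2]  # stand-in for the future token-id placeholders
    return None, future_next_token_ids, False

forward_thread(batch=None)
print(resolve_last_batch_result()[2])           # -> True
print(forward_batch_generation(batch=None)[2])  # -> False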