[CUDA Graph] save cuda graph memory by using next_token_logits_buffer (#8579)

This commit is contained in:
Cheng Wan
2025-08-03 03:06:47 -07:00
committed by GitHub
parent 7a91330149
commit cb099d2095
5 changed files with 36 additions and 1 deletions

View File

@@ -375,6 +375,11 @@ class CudaGraphRunner:
dtype=torch.bool,
device="cuda",
)
self.next_token_logits_buffer = torch.zeros(
(self.max_num_token, self.model_runner.model_config.vocab_size),
dtype=torch.float,
device="cuda",
)
# Capture
try:
@@ -520,6 +525,7 @@ class CudaGraphRunner:
else:
encoder_lens = None
mrope_positions = self.mrope_positions[:, :bs]
next_token_logits_buffer = self.next_token_logits_buffer[:num_tokens]
self.num_token_non_padded[...] = num_tokens
# pipeline parallelism
@@ -582,6 +588,7 @@ class CudaGraphRunner:
input_ids=input_ids,
req_pool_indices=req_pool_indices,
seq_lens=seq_lens,
next_token_logits_buffer=next_token_logits_buffer,
req_to_token_pool=self.model_runner.req_to_token_pool,
token_to_kv_pool=self.model_runner.token_to_kv_pool,
attn_backend=self.model_runner.attn_backend,

View File

@@ -189,6 +189,7 @@ class ForwardBatch:
token_ids_logprobs: Optional[List[List[int]]] = None
# For logits and logprobs post processing
next_token_logits_buffer: torch.Tensor = None
temp_scaled_logprobs: bool = False
temperature: torch.Tensor = None
top_p_normalized_logprobs: bool = False