[CUDA Graph] save cuda graph memory by using next_token_logits_buffer (#8579)

This commit is contained in:
Cheng Wan
2025-08-03 03:06:47 -07:00
committed by GitHub
parent 7a91330149
commit cb099d2095
5 changed files with 36 additions and 1 deletions

View File

@@ -142,6 +142,22 @@ class EAGLEDraftExtendCudaGraphRunner:
self.global_num_tokens_for_logprob_gpu = None
self.gathered_buffer = None
if hasattr(
self.model_runner.model_config.hf_config, "draft_vocab_size"
): # llama_eagle
vocab_size = self.model_runner.model_config.hf_config.draft_vocab_size
elif hasattr(
self.model_runner.model_config.hf_config, "hot_vocab_size"
): # llama_eagle3
vocab_size = self.model_runner.model_config.hf_config.hot_vocab_size
else:
vocab_size = self.model_runner.model_config.vocab_size
self.next_token_logits_buffer = torch.zeros(
(self.max_bs, vocab_size),
dtype=torch.float,
)
# Capture
try:
with model_capture_mode():
@@ -189,6 +205,7 @@ class EAGLEDraftExtendCudaGraphRunner:
out_cache_loc = self.out_cache_loc[:num_tokens]
positions = self.positions[:num_tokens]
hidden_states = self.hidden_states[:num_tokens]
next_token_logits_buffer = self.next_token_logits_buffer[:bs]
if self.require_mlp_tp_gather:
self.global_num_tokens_gpu.copy_(
@@ -238,6 +255,7 @@ class EAGLEDraftExtendCudaGraphRunner:
input_ids=input_ids,
req_pool_indices=req_pool_indices,
seq_lens=seq_lens,
next_token_logits_buffer=next_token_logits_buffer,
req_to_token_pool=self.model_runner.req_to_token_pool,
token_to_kv_pool=self.model_runner.token_to_kv_pool,
out_cache_loc=out_cache_loc,