[CUDA Graph] Save CUDA graph memory by using next_token_logits_buffer (#8579)
This commit is contained in:
@@ -142,6 +142,22 @@ class EAGLEDraftExtendCudaGraphRunner:
|
||||
self.global_num_tokens_for_logprob_gpu = None
|
||||
self.gathered_buffer = None
|
||||
|
||||
if hasattr(
|
||||
self.model_runner.model_config.hf_config, "draft_vocab_size"
|
||||
): # llama_eagle
|
||||
vocab_size = self.model_runner.model_config.hf_config.draft_vocab_size
|
||||
elif hasattr(
|
||||
self.model_runner.model_config.hf_config, "hot_vocab_size"
|
||||
): # llama_eagle3
|
||||
vocab_size = self.model_runner.model_config.hf_config.hot_vocab_size
|
||||
else:
|
||||
vocab_size = self.model_runner.model_config.vocab_size
|
||||
|
||||
self.next_token_logits_buffer = torch.zeros(
|
||||
(self.max_bs, vocab_size),
|
||||
dtype=torch.float,
|
||||
)
|
||||
|
||||
# Capture
|
||||
try:
|
||||
with model_capture_mode():
|
||||
@@ -189,6 +205,7 @@ class EAGLEDraftExtendCudaGraphRunner:
|
||||
out_cache_loc = self.out_cache_loc[:num_tokens]
|
||||
positions = self.positions[:num_tokens]
|
||||
hidden_states = self.hidden_states[:num_tokens]
|
||||
next_token_logits_buffer = self.next_token_logits_buffer[:bs]
|
||||
|
||||
if self.require_mlp_tp_gather:
|
||||
self.global_num_tokens_gpu.copy_(
|
||||
@@ -238,6 +255,7 @@ class EAGLEDraftExtendCudaGraphRunner:
|
||||
input_ids=input_ids,
|
||||
req_pool_indices=req_pool_indices,
|
||||
seq_lens=seq_lens,
|
||||
next_token_logits_buffer=next_token_logits_buffer,
|
||||
req_to_token_pool=self.model_runner.req_to_token_pool,
|
||||
token_to_kv_pool=self.model_runner.token_to_kv_pool,
|
||||
out_cache_loc=out_cache_loc,
|
||||
|
||||
Reference in New Issue
Block a user