[CUDA Graph] save cuda graph memory by using next_token_logits_buffer (#8579)
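In short: rather than letting every captured CUDA graph allocate its own fp32 next-token logits tensor via `.float()`, each runner now preallocates one float32 buffer sized for its largest capture, and every graph copies its logits into a slice of that buffer. A toy sketch of the pattern (illustrative only; `produce_logits` and the sizes are made up — the real wiring is in the hunks below):

    # Illustrative sketch only (toy sizes, hypothetical produce_logits); the real
    # buffers live in CudaGraphRunner / EAGLEDraftExtendCudaGraphRunner below.
    from typing import Optional

    import torch

    MAX_TOKENS, VOCAB = 8, 32  # stand-ins for max_num_token and vocab_size

    # One shared fp32 buffer, allocated once for the largest captured batch.
    shared = torch.zeros((MAX_TOKENS, VOCAB), dtype=torch.float)

    def produce_logits(num_tokens: int, buffer: Optional[torch.Tensor]) -> torch.Tensor:
        raw = torch.randn(num_tokens, VOCAB, dtype=torch.bfloat16)  # fake model output
        if buffer is not None:
            out = buffer[:num_tokens]
            out.copy_(raw)      # in-place cast+copy: no new allocation per graph
            return out
        return raw.float()      # old path: a fresh fp32 tensor every call

    for bs in (2, 4, 8):
        logits = produce_logits(bs, shared)
        assert logits.data_ptr() == shared.data_ptr()  # every size reuses one allocation
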
@@ -83,6 +83,7 @@ class LogitsProcessorOutput:
 class LogitsMetadata:
     forward_mode: ForwardMode
     capture_hidden_mode: CaptureHiddenMode = CaptureHiddenMode.NULL
+    next_token_logits_buffer: Optional[torch.Tensor] = None
 
     extend_return_logprob: bool = False
     extend_return_top_logprob: bool = False

@@ -148,6 +149,7 @@ class LogitsMetadata:
         return cls(
             forward_mode=forward_batch.forward_mode,
             capture_hidden_mode=forward_batch.capture_hidden_mode,
+            next_token_logits_buffer=forward_batch.next_token_logits_buffer,
             extend_return_logprob=extend_return_logprob,
             extend_return_top_logprob=extend_return_top_logprob,
             extend_token_ids_logprob=extend_token_ids_logprob,

@@ -508,7 +510,13 @@ class LogitsProcessor(nn.Module):
             )
             dp_scatter(logits, global_logits, logits_metadata)
 
-        logits = logits[:, : self.config.vocab_size].float()
+        if logits_metadata.next_token_logits_buffer is not None:
+            logits_buffer = logits_metadata.next_token_logits_buffer
+            assert logits_buffer.dtype == torch.float
+            logits_buffer.copy_(logits[:, : self.config.vocab_size])
+            logits = logits_buffer
+        else:
+            logits = logits[:, : self.config.vocab_size].float()
 
         if self.final_logit_softcapping:
             fused_softcap(logits, self.final_logit_softcapping)

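Why this branch saves memory: tensors allocated while a CUDA graph is being captured are retained by that graph's private memory pool, so the old unconditional `.float()` upcast left a separate fp32 logits tensor pinned inside every captured graph. The assert guarantees the preallocated buffer is already fp32, so `copy_` performs the same upcast in place, and the in-place `fused_softcap` below then operates on the shared buffer. As a rough illustration with assumed numbers (a 128,256-entry vocab and 160 captured tokens): one shared buffer costs 160 × 128,256 × 4 B ≈ 82 MB once, instead of a comparable fp32 tensor being retained per captured batch size.
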
@@ -375,6 +375,11 @@ class CudaGraphRunner:
             dtype=torch.bool,
             device="cuda",
         )
+        self.next_token_logits_buffer = torch.zeros(
+            (self.max_num_token, self.model_runner.model_config.vocab_size),
+            dtype=torch.float,
+            device="cuda",
+        )
 
         # Capture
         try:

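The buffer is created in `__init__`, before any graph is captured, because a replayed CUDA graph reads and writes fixed device addresses; outputs must live in storage that outlasts the capture. A generic, self-contained PyTorch illustration of that constraint (this is the standard static-buffer pattern, not SGLang code; it needs a CUDA device to run):

    # Sketch: a captured graph replays into the same preallocated output tensor.
    import torch

    if torch.cuda.is_available():
        static_in = torch.zeros(4, 8, device="cuda", dtype=torch.bfloat16)
        static_out = torch.zeros(4, 8, device="cuda", dtype=torch.float)  # reused across replays

        # Warmup on a side stream is required before capture.
        s = torch.cuda.Stream()
        s.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(s):
            static_out.copy_(static_in * 2)
        torch.cuda.current_stream().wait_stream(s)

        g = torch.cuda.CUDAGraph()
        with torch.cuda.graph(g):
            static_out.copy_(static_in * 2)  # the write lands at a fixed address

        static_in.fill_(3.0)
        g.replay()  # replays into the same static_out storage
        assert torch.allclose(static_out, torch.full_like(static_out, 6.0))
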
@@ -520,6 +525,7 @@ class CudaGraphRunner:
         else:
             encoder_lens = None
         mrope_positions = self.mrope_positions[:, :bs]
+        next_token_logits_buffer = self.next_token_logits_buffer[:num_tokens]
         self.num_token_non_padded[...] = num_tokens
 
         # pipeline parallelism

@@ -582,6 +588,7 @@ class CudaGraphRunner:
             input_ids=input_ids,
             req_pool_indices=req_pool_indices,
             seq_lens=seq_lens,
+            next_token_logits_buffer=next_token_logits_buffer,
             req_to_token_pool=self.model_runner.req_to_token_pool,
             token_to_kv_pool=self.model_runner.token_to_kv_pool,
             attn_backend=self.model_runner.attn_backend,

@@ -189,6 +189,7 @@ class ForwardBatch:
     token_ids_logprobs: Optional[List[List[int]]] = None
 
     # For logits and logprobs post processing
+    next_token_logits_buffer: torch.Tensor = None
     temp_scaled_logprobs: bool = False
     temperature: torch.Tensor = None
     top_p_normalized_logprobs: bool = False

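Taken together, the plumbing is: CudaGraphRunner slices its shared buffer for the capture batch and attaches it to the ForwardBatch; LogitsMetadata.from_forward_batch (second hunk) forwards it; LogitsProcessor copies into it instead of allocating. The field is annotated `torch.Tensor = None` to match neighboring fields such as `temperature`, though it is effectively Optional.
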
@@ -142,6 +142,22 @@ class EAGLEDraftExtendCudaGraphRunner:
         self.global_num_tokens_for_logprob_gpu = None
         self.gathered_buffer = None
 
+        if hasattr(
+            self.model_runner.model_config.hf_config, "draft_vocab_size"
+        ):  # llama_eagle
+            vocab_size = self.model_runner.model_config.hf_config.draft_vocab_size
+        elif hasattr(
+            self.model_runner.model_config.hf_config, "hot_vocab_size"
+        ):  # llama_eagle3
+            vocab_size = self.model_runner.model_config.hf_config.hot_vocab_size
+        else:
+            vocab_size = self.model_runner.model_config.vocab_size
+
+        self.next_token_logits_buffer = torch.zeros(
+            (self.max_bs, vocab_size),
+            dtype=torch.float,
+        )
+
         # Capture
         try:
             with model_capture_mode():

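The `hasattr` chain picks the draft head's vocabulary size, since EAGLE draft heads can use a reduced vocab. Restated as a standalone helper (hypothetical name `pick_draft_vocab_size`; the attribute names are the ones the diff checks):

    def pick_draft_vocab_size(hf_config, default: int) -> int:
        # llama_eagle draft models expose draft_vocab_size; llama_eagle3 exposes
        # hot_vocab_size; otherwise fall back to the target model's vocab size.
        if hasattr(hf_config, "draft_vocab_size"):
            return hf_config.draft_vocab_size
        if hasattr(hf_config, "hot_vocab_size"):
            return hf_config.hot_vocab_size
        return default
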
@@ -189,6 +205,7 @@ class EAGLEDraftExtendCudaGraphRunner:
         out_cache_loc = self.out_cache_loc[:num_tokens]
         positions = self.positions[:num_tokens]
         hidden_states = self.hidden_states[:num_tokens]
+        next_token_logits_buffer = self.next_token_logits_buffer[:bs]
 
         if self.require_mlp_tp_gather:
             self.global_num_tokens_gpu.copy_(

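Note the slice here is `[:bs]` rather than `[:num_tokens]`: this runner allocates its buffer as `(self.max_bs, vocab_size)` because the draft-extend pass yields one row of next-token logits per request, while the main CudaGraphRunner sizes its buffer by `max_num_token`.
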
@@ -238,6 +255,7 @@ class EAGLEDraftExtendCudaGraphRunner:
             input_ids=input_ids,
             req_pool_indices=req_pool_indices,
             seq_lens=seq_lens,
+            next_token_logits_buffer=next_token_logits_buffer,
             req_to_token_pool=self.model_runner.req_to_token_pool,
             token_to_kv_pool=self.model_runner.token_to_kv_pool,
             out_cache_loc=out_cache_loc,

@@ -564,6 +564,7 @@ class TboForwardBatchPreparer:
                 mm_inputs=None,
                 top_logprobs_nums=None,
                 token_ids_logprobs=None,
+                next_token_logits_buffer=None,
             )
         )
 
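For the two-batch-overlap child batches, TboForwardBatchPreparer passes `next_token_logits_buffer=None` explicitly, so the logits processor takes the `else` branch above and falls back to the allocating `.float()` path.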