Save peak memory in logits processor (#8343)
@@ -170,8 +170,6 @@ class LogitsMetadata:
         )
 
     def compute_dp_attention_metadata(self):
-        # TODO(ch-wan): gathered_buffer here is larger than the actual required size in draft extend,
-        # we may use a smaller buffer in draft extend.
 
         cumtokens = torch.cumsum(self.global_num_tokens_for_logprob_gpu, dim=0)
         dp_rank = get_attention_dp_rank()
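For orientation, the cumulative sum above is what the metadata code (elided between these two hunks) uses to locate each DP rank's slice of the gathered logprob tokens. The snippet below is a minimal standalone sketch of that bookkeeping, not the actual SGLang implementation; the helper name and the example counts are hypothetical.

import torch

def dp_slice_from_counts(global_num_tokens: torch.Tensor, dp_rank: int):
    # Illustrative only: derive this rank's (start, length) in the gathered
    # buffer from per-rank token counts, mirroring the cumsum in the hunk above.
    cumtokens = torch.cumsum(global_num_tokens, dim=0)
    start = 0 if dp_rank == 0 else int(cumtokens[dp_rank - 1].item())
    length = int(global_num_tokens[dp_rank].item())
    return start, length

# Example with 4 DP ranks holding uneven numbers of logprob tokens:
counts = torch.tensor([3, 5, 2, 4])
print(dp_slice_from_counts(counts, dp_rank=2))  # (8, 2)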
@@ -186,6 +184,19 @@ class LogitsMetadata:
         self.dp_local_start_pos = dp_local_start_pos
         self.dp_local_num_tokens = dp_local_num_tokens
+
+        if self.global_num_tokens_for_logprob_cpu is not None:
+            # create a smaller buffer to reduce peak memory usage
+            self.gathered_buffer = torch.empty(
+                (
+                    sum(self.global_num_tokens_for_logprob_cpu),
+                    self.gathered_buffer.shape[1],
+                ),
+                dtype=self.gathered_buffer.dtype,
+                device=self.gathered_buffer.device,
+            )
+        else:
+            self.gathered_buffer = torch.empty_like(self.gathered_buffer)
 
 
 class LogitsProcessor(nn.Module):
     def __init__(
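The new branch sizes the gather target to exactly the number of tokens that need logprobs in this step (taken from the CPU-side counts), and falls back to a fresh full-size buffer when those counts are unavailable. A toy illustration of the sizing, with hypothetical shapes and counts:

import torch

# Hypothetical worst-case buffer and per-DP-rank logprob token counts.
preallocated = torch.empty(8192, 4096, dtype=torch.bfloat16)
global_num_tokens_for_logprob_cpu = [3, 5, 2, 4]

# Same pattern as the diff: allocate only sum(counts) rows instead of 8192.
gathered_buffer = torch.empty(
    (sum(global_num_tokens_for_logprob_cpu), preallocated.shape[1]),
    dtype=preallocated.dtype,
    device=preallocated.device,
)
print(gathered_buffer.shape)  # torch.Size([14, 4096])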
@@ -430,7 +441,7 @@ class LogitsProcessor(nn.Module):
         if self.do_tensor_parallel_all_gather_dp_attn:
             logits_metadata.compute_dp_attention_metadata()
             hidden_states, local_hidden_states = (
-                torch.empty_like(logits_metadata.gathered_buffer),
+                logits_metadata.gathered_buffer,
                 hidden_states,
             )
             dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata)
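Taken together, the two hunks shrink the gather target used by dp_gather_replicate: previously the forward path allocated it with torch.empty_like on the worst-case-sized gathered_buffer (the size the removed TODO complained about), whereas now compute_dp_attention_metadata allocates a buffer with exactly the required number of rows and the forward path passes it through directly. A back-of-the-envelope comparison with hypothetical sizes, not measured numbers:

hidden_size = 4096
bytes_per_elem = 2                    # assuming bf16 activations
padded_tokens = 8192                  # hypothetical worst-case buffer rows
actual_tokens = [300, 120, 512, 64]   # hypothetical per-rank logprob counts

# Gather-target allocation before the change: empty_like of the padded buffer.
before = padded_tokens * hidden_size * bytes_per_elem
# After the change: exactly sum(actual_tokens) rows.
after = sum(actual_tokens) * hidden_size * bytes_per_elem
print(f"gather-target bytes: before~{before:,}  after~{after:,}")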