Save peak memory in logits processor (#8343)
This commit is contained in:
@@ -170,8 +170,6 @@ class LogitsMetadata:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def compute_dp_attention_metadata(self):
|
def compute_dp_attention_metadata(self):
|
||||||
# TODO(ch-wan): gathered_buffer here is larger than the actual required size in draft extend,
|
|
||||||
# we may use a smaller buffer in draft extend.
|
|
||||||
|
|
||||||
cumtokens = torch.cumsum(self.global_num_tokens_for_logprob_gpu, dim=0)
|
cumtokens = torch.cumsum(self.global_num_tokens_for_logprob_gpu, dim=0)
|
||||||
dp_rank = get_attention_dp_rank()
|
dp_rank = get_attention_dp_rank()
|
||||||
@@ -186,6 +184,19 @@ class LogitsMetadata:
|
|||||||
self.dp_local_start_pos = dp_local_start_pos
|
self.dp_local_start_pos = dp_local_start_pos
|
||||||
self.dp_local_num_tokens = dp_local_num_tokens
|
self.dp_local_num_tokens = dp_local_num_tokens
|
||||||
|
|
||||||
|
if self.global_num_tokens_for_logprob_cpu is not None:
|
||||||
|
# create a smaller buffer to reduce peak memory usage
|
||||||
|
self.gathered_buffer = torch.empty(
|
||||||
|
(
|
||||||
|
sum(self.global_num_tokens_for_logprob_cpu),
|
||||||
|
self.gathered_buffer.shape[1],
|
||||||
|
),
|
||||||
|
dtype=self.gathered_buffer.dtype,
|
||||||
|
device=self.gathered_buffer.device,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.gathered_buffer = torch.empty_like(self.gathered_buffer)
|
||||||
|
|
||||||
|
|
||||||
class LogitsProcessor(nn.Module):
|
class LogitsProcessor(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -430,7 +441,7 @@ class LogitsProcessor(nn.Module):
|
|||||||
if self.do_tensor_parallel_all_gather_dp_attn:
|
if self.do_tensor_parallel_all_gather_dp_attn:
|
||||||
logits_metadata.compute_dp_attention_metadata()
|
logits_metadata.compute_dp_attention_metadata()
|
||||||
hidden_states, local_hidden_states = (
|
hidden_states, local_hidden_states = (
|
||||||
torch.empty_like(logits_metadata.gathered_buffer),
|
logits_metadata.gathered_buffer,
|
||||||
hidden_states,
|
hidden_states,
|
||||||
)
|
)
|
||||||
dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata)
|
dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata)
|
||||||
|
|||||||
Reference in New Issue
Block a user