diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py
index 0aee86f68..90f981c57 100644
--- a/python/sglang/srt/layers/logits_processor.py
+++ b/python/sglang/srt/layers/logits_processor.py
@@ -170,8 +170,6 @@ class LogitsMetadata:
         )
 
     def compute_dp_attention_metadata(self):
-        # TODO(ch-wan): gathered_buffer here is larger than the actual required size in draft extend,
-        # we may use a smaller buffer in draft extend.
         cumtokens = torch.cumsum(self.global_num_tokens_for_logprob_gpu, dim=0)
         dp_rank = get_attention_dp_rank()
 
@@ -186,6 +184,19 @@ class LogitsMetadata:
         self.dp_local_start_pos = dp_local_start_pos
         self.dp_local_num_tokens = dp_local_num_tokens
 
+        if self.global_num_tokens_for_logprob_cpu is not None:
+            # create a smaller buffer to reduce peak memory usage
+            self.gathered_buffer = torch.empty(
+                (
+                    sum(self.global_num_tokens_for_logprob_cpu),
+                    self.gathered_buffer.shape[1],
+                ),
+                dtype=self.gathered_buffer.dtype,
+                device=self.gathered_buffer.device,
+            )
+        else:
+            self.gathered_buffer = torch.empty_like(self.gathered_buffer)
 
 
 class LogitsProcessor(nn.Module):
     def __init__(
@@ -430,7 +441,7 @@ class LogitsProcessor(nn.Module):
         if self.do_tensor_parallel_all_gather_dp_attn:
             logits_metadata.compute_dp_attention_metadata()
             hidden_states, local_hidden_states = (
-                torch.empty_like(logits_metadata.gathered_buffer),
+                logits_metadata.gathered_buffer,
                 hidden_states,
             )
             dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata)
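
Note on the two changes above. When the per-rank token counts are available on the CPU (global_num_tokens_for_logprob_cpu), compute_dp_attention_metadata now allocates the gather target with exactly sum(counts) rows instead of the oversized buffer that the removed TODO flagged, reducing peak memory; otherwise it falls back to a fresh buffer of the original size via torch.empty_like. Below is a minimal standalone sketch of that sizing logic; make_gathered_buffer, the example counts, and dp_rank = 1 are hypothetical stand-ins for illustration, not sglang APIs.

    import torch

    def make_gathered_buffer(num_tokens_per_rank, hidden_size,
                             dtype=torch.float32, device="cpu"):
        # Rows = the exact total number of logprob tokens across all DP
        # ranks, instead of a padded worst-case allocation.
        return torch.empty(
            (sum(num_tokens_per_rank), hidden_size), dtype=dtype, device=device
        )

    counts = [3, 5, 2]  # hypothetical tokens-for-logprob per DP rank
    buf = make_gathered_buffer(counts, hidden_size=8)

    # Each rank's slice starts at the cumulative token count of the ranks
    # before it, mirroring the cumsum in compute_dp_attention_metadata().
    cum = torch.cumsum(torch.tensor(counts), dim=0)
    dp_rank = 1  # hypothetical local rank
    start = 0 if dp_rank == 0 else int(cum[dp_rank - 1])
    local_view = buf[start : start + counts[dp_rank]]  # this rank's region
    assert local_view.shape == (5, 8)

The second hunk follows from the first: since compute_dp_attention_metadata() now reassigns self.gathered_buffer to a freshly allocated tensor (exactly sized when the CPU counts are present, same-sized otherwise), LogitsProcessor can pass logits_metadata.gathered_buffer to dp_gather_replicate directly, and the caller's extra torch.empty_like allocation becomes redundant.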