Clean up unused resources after DeepEP operation (#4996)

This commit is contained in:
fzyzcjy
2025-04-04 15:36:04 +08:00
committed by GitHub
parent 77e929a1a2
commit 6ff9c6a5e7

View File

@@ -184,11 +184,6 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
(num_experts + 1,), device=hidden_states.device, dtype=torch.int64
)
# TODO
# masked_m = torch.empty(
# (self.num_local_experts,), device=hidden_states.device, dtype=torch.int64
# )
# expected_m = 0
masked_m = expected_m = None
return (
@@ -327,6 +322,8 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
def combine_b(self, output, previous_event):
    """Finish the combine step and release per-operation state.

    Calls ``self._combine_core(output, previous_event)``, which yields the
    combined hidden states together with an event object. When
    ``self.async_finish`` is truthy the event's ``current_stream_wait()``
    is invoked. Afterwards ``self.handle`` and ``self.src2dst`` are reset
    to ``None``.

    Args:
        output: Forwarded unchanged to ``_combine_core`` as its first
            argument.
        previous_event: Forwarded unchanged to ``_combine_core``.

    Returns:
        The hidden states returned by ``_combine_core``.
    """
    hidden_states, event = self._combine_core(output, previous_event)
    # Only wait on the event when async finish is enabled; a plain ``if``
    # replaces the former side-effect-only conditional expression.
    if self.async_finish:
        event.current_stream_wait()
    # Drop the references held on the dispatcher once the combine is done,
    # so they are not kept alive until the next operation overwrites them.
    self.handle = None
    self.src2dst = None
    return hidden_states
def _combine_core(self, x: torch.Tensor, previous_event):
@@ -402,13 +399,6 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
):
hook() if self.return_recv_hook else event.current_stream_wait()
# TODO
# reorder_topk_ids = torch.empty(
# (0,), device=hidden_states.device, dtype=torch.int64
# )
# seg_indptr = torch.zeros(
# (num_experts + 1,), device=hidden_states.device, dtype=torch.int64
# )
reorder_topk_ids = seg_indptr = None
return (
@@ -508,6 +498,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
return_recv_hook=self.return_recv_hook,
)
)
self.handle = None
return combined_hidden_states, event, hook