Cleanup unused resources after DeepEP operation (#4996)
This commit is contained in:
@@ -184,11 +184,6 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
|
||||
(num_experts + 1,), device=hidden_states.device, dtype=torch.int64
|
||||
)
|
||||
|
||||
# TODO
|
||||
# masked_m = torch.empty(
|
||||
# (self.num_local_experts,), device=hidden_states.device, dtype=torch.int64
|
||||
# )
|
||||
# expected_m = 0
|
||||
masked_m = expected_m = None
|
||||
|
||||
return (
|
||||
@@ -327,6 +322,8 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
|
||||
def combine_b(self, output, previous_event):
    """Finish a DeepEP combine: run the core combine, optionally block on
    its completion event, then release per-dispatch state.

    Args:
        output: tensor handed through to ``_combine_core``.
        previous_event: event passed to ``_combine_core`` for stream ordering.

    Returns:
        The combined hidden-states tensor produced by ``_combine_core``.
    """
    hidden_states, event = self._combine_core(output, previous_event)
    # Plain `if` instead of the original
    # `event.current_stream_wait() if self.async_finish else ()`,
    # a conditional expression used as a statement with its value discarded.
    if self.async_finish:
        event.current_stream_wait()
    # Drop references so dispatch-time buffers can be reclaimed
    # ("Cleanup unused resources after DeepEP operation").
    self.handle = None
    self.src2dst = None
    return hidden_states
|
||||
|
||||
def _combine_core(self, x: torch.Tensor, previous_event):
|
||||
@@ -402,13 +399,6 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
|
||||
):
|
||||
hook() if self.return_recv_hook else event.current_stream_wait()
|
||||
|
||||
# TODO
|
||||
# reorder_topk_ids = torch.empty(
|
||||
# (0,), device=hidden_states.device, dtype=torch.int64
|
||||
# )
|
||||
# seg_indptr = torch.zeros(
|
||||
# (num_experts + 1,), device=hidden_states.device, dtype=torch.int64
|
||||
# )
|
||||
reorder_topk_ids = seg_indptr = None
|
||||
|
||||
return (
|
||||
@@ -508,6 +498,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
|
||||
return_recv_hook=self.return_recv_hook,
|
||||
)
|
||||
)
|
||||
self.handle = None
|
||||
return combined_hidden_states, event, hook
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user