Cleanup unused resources after DeepEP operation (#4996)
This commit is contained in:
@@ -184,11 +184,6 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
|
|||||||
(num_experts + 1,), device=hidden_states.device, dtype=torch.int64
|
(num_experts + 1,), device=hidden_states.device, dtype=torch.int64
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO
|
|
||||||
# masked_m = torch.empty(
|
|
||||||
# (self.num_local_experts,), device=hidden_states.device, dtype=torch.int64
|
|
||||||
# )
|
|
||||||
# expected_m = 0
|
|
||||||
masked_m = expected_m = None
|
masked_m = expected_m = None
|
||||||
|
|
||||||
return (
|
return (
|
||||||
@@ -327,6 +322,8 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
|
|||||||
def combine_b(self, output, previous_event):
|
def combine_b(self, output, previous_event):
|
||||||
hidden_states, event = self._combine_core(output, previous_event)
|
hidden_states, event = self._combine_core(output, previous_event)
|
||||||
event.current_stream_wait() if self.async_finish else ()
|
event.current_stream_wait() if self.async_finish else ()
|
||||||
|
self.handle = None
|
||||||
|
self.src2dst = None
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
def _combine_core(self, x: torch.Tensor, previous_event):
|
def _combine_core(self, x: torch.Tensor, previous_event):
|
||||||
@@ -402,13 +399,6 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
|
|||||||
):
|
):
|
||||||
hook() if self.return_recv_hook else event.current_stream_wait()
|
hook() if self.return_recv_hook else event.current_stream_wait()
|
||||||
|
|
||||||
# TODO
|
|
||||||
# reorder_topk_ids = torch.empty(
|
|
||||||
# (0,), device=hidden_states.device, dtype=torch.int64
|
|
||||||
# )
|
|
||||||
# seg_indptr = torch.zeros(
|
|
||||||
# (num_experts + 1,), device=hidden_states.device, dtype=torch.int64
|
|
||||||
# )
|
|
||||||
reorder_topk_ids = seg_indptr = None
|
reorder_topk_ids = seg_indptr = None
|
||||||
|
|
||||||
return (
|
return (
|
||||||
@@ -508,6 +498,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
|
|||||||
return_recv_hook=self.return_recv_hook,
|
return_recv_hook=self.return_recv_hook,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
self.handle = None
|
||||||
return combined_hidden_states, event, hook
|
return combined_hidden_states, event, hook
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user