Revert "[feature] Ascend NPU graph support (#8027)" (#9348)

This commit is contained in:
Even Zhou
2025-08-20 01:11:23 +08:00
committed by GitHub
parent 01d47a27b6
commit f4fafacc5d
18 changed files with 878 additions and 1349 deletions

View File

@@ -55,7 +55,7 @@ _is_npu = is_npu()
@dataclass
class GraphCaptureContext:
-stream: torch.cuda.Stream if not _is_npu else torch.npu.Stream
+stream: torch.cuda.Stream
TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])
@@ -252,13 +252,9 @@ class GroupCoordinator:
if is_cuda_alike():
self.device = torch.device(f"cuda:{local_rank}")
-elif _is_npu:
-self.device = torch.device(f"npu:{local_rank}")
else:
self.device = torch.device("cpu")
-self.device_module = torch.get_device_module(self.device)
self.use_pynccl = use_pynccl
self.use_pymscclpp = use_pymscclpp
self.use_custom_allreduce = use_custom_allreduce
@@ -406,7 +402,7 @@ class GroupCoordinator:
self, graph_capture_context: Optional[GraphCaptureContext] = None
):
if graph_capture_context is None:
-stream = self.device_module.Stream()
+stream = torch.cuda.Stream()
graph_capture_context = GraphCaptureContext(stream)
else:
stream = graph_capture_context.stream
@@ -417,11 +413,11 @@ class GroupCoordinator:
# ensure all initialization operations complete before attempting to
# capture the graph on another stream
-curr_stream = self.device_module.current_stream()
+curr_stream = torch.cuda.current_stream()
if curr_stream != stream:
stream.wait_stream(curr_stream)
-with self.device_module.stream(stream), maybe_ca_context:
+with torch.cuda.stream(stream), maybe_ca_context:
# In graph mode, we have to be very careful about the collective
# operations. The current status is:
# allreduce \ Mode | Eager | Graph |
@@ -1645,8 +1641,6 @@ def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
)
elif hasattr(torch, "xpu") and torch.xpu.is_available():
torch.xpu.empty_cache()
-elif hasattr(torch, "npu") and torch.npu.is_available():
-torch.npu.empty_cache()
def in_the_same_node_as(pg: ProcessGroup, source_rank: int = 0) -> List[bool]: