[feature] Ascend NPU graph support (#9399)

Co-authored-by: ronnie_zheng <zl19940307@163.com>
Co-authored-by: yezhifeng (D) <y00897525@china.huawei.com>
Co-authored-by: anon189Ty <Stari_Falcon@outlook.com>
Co-authored-by: Maksim <makcum888e@mail.ru>
Co-authored-by: ssshinigami <44640852+ssshinigami@users.noreply.github.com>
This commit is contained in:
VDV1985
2025-08-21 07:13:27 +03:00
committed by GitHub
parent 7cd2ee06d7
commit 2c4b4b786b
9 changed files with 470 additions and 48 deletions

View File

@@ -55,7 +55,7 @@ _is_npu = is_npu()
@dataclass
class GraphCaptureContext:
stream: torch.cuda.Stream
stream: torch.cuda.Stream if not _is_npu else torch.npu.Stream
TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])
@@ -252,8 +252,11 @@ class GroupCoordinator:
if is_cuda_alike():
self.device = torch.device(f"cuda:{local_rank}")
elif _is_npu:
self.device = torch.device(f"npu:{local_rank}")
else:
self.device = torch.device("cpu")
self.device_module = torch.get_device_module(self.device)
self.use_pynccl = use_pynccl
self.use_pymscclpp = use_pymscclpp
@@ -402,7 +405,7 @@ class GroupCoordinator:
self, graph_capture_context: Optional[GraphCaptureContext] = None
):
if graph_capture_context is None:
stream = torch.cuda.Stream()
stream = self.device_module.Stream()
graph_capture_context = GraphCaptureContext(stream)
else:
stream = graph_capture_context.stream
@@ -413,11 +416,11 @@ class GroupCoordinator:
# ensure all initialization operations complete before attempting to
# capture the graph on another stream
curr_stream = torch.cuda.current_stream()
curr_stream = self.device_module.current_stream()
if curr_stream != stream:
stream.wait_stream(curr_stream)
with torch.cuda.stream(stream), maybe_ca_context:
with self.device_module.stream(stream), maybe_ca_context:
# In graph mode, we have to be very careful about the collective
# operations. The current status is:
# allreduce \ Mode | Eager | Graph |
@@ -1641,6 +1644,8 @@ def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
)
elif hasattr(torch, "xpu") and torch.xpu.is_available():
torch.xpu.empty_cache()
elif hasattr(torch, "npu") and torch.npu.is_available():
torch.npu.empty_cache()
def in_the_same_node_as(pg: ProcessGroup, source_rank: int = 0) -> List[bool]: