[feature] Ascend NPU graph support (#9399)
Co-authored-by: ronnie_zheng <zl19940307@163.com> Co-authored-by: yezhifeng (D) <y00897525@china.huawei.com> Co-authored-by: anon189Ty <Stari_Falcon@outlook.com> Co-authored-by: Maksim <makcum888e@mail.ru> Co-authored-by: ssshinigami <44640852+ssshinigami@users.noreply.github.com>
This commit is contained in:
@@ -55,7 +55,7 @@ _is_npu = is_npu()
|
||||
|
||||
@dataclass
|
||||
class GraphCaptureContext:
|
||||
stream: torch.cuda.Stream
|
||||
stream: torch.cuda.Stream if not _is_npu else torch.npu.Stream
|
||||
|
||||
|
||||
TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])
|
||||
@@ -252,8 +252,11 @@ class GroupCoordinator:
|
||||
|
||||
if is_cuda_alike():
|
||||
self.device = torch.device(f"cuda:{local_rank}")
|
||||
elif _is_npu:
|
||||
self.device = torch.device(f"npu:{local_rank}")
|
||||
else:
|
||||
self.device = torch.device("cpu")
|
||||
self.device_module = torch.get_device_module(self.device)
|
||||
|
||||
self.use_pynccl = use_pynccl
|
||||
self.use_pymscclpp = use_pymscclpp
|
||||
@@ -402,7 +405,7 @@ class GroupCoordinator:
|
||||
self, graph_capture_context: Optional[GraphCaptureContext] = None
|
||||
):
|
||||
if graph_capture_context is None:
|
||||
stream = torch.cuda.Stream()
|
||||
stream = self.device_module.Stream()
|
||||
graph_capture_context = GraphCaptureContext(stream)
|
||||
else:
|
||||
stream = graph_capture_context.stream
|
||||
@@ -413,11 +416,11 @@ class GroupCoordinator:
|
||||
|
||||
# ensure all initialization operations complete before attempting to
|
||||
# capture the graph on another stream
|
||||
curr_stream = torch.cuda.current_stream()
|
||||
curr_stream = self.device_module.current_stream()
|
||||
if curr_stream != stream:
|
||||
stream.wait_stream(curr_stream)
|
||||
|
||||
with torch.cuda.stream(stream), maybe_ca_context:
|
||||
with self.device_module.stream(stream), maybe_ca_context:
|
||||
# In graph mode, we have to be very careful about the collective
|
||||
# operations. The current status is:
|
||||
# allreduce \ Mode | Eager | Graph |
|
||||
@@ -1641,6 +1644,8 @@ def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
|
||||
)
|
||||
elif hasattr(torch, "xpu") and torch.xpu.is_available():
|
||||
torch.xpu.empty_cache()
|
||||
elif hasattr(torch, "npu") and torch.npu.is_available():
|
||||
torch.npu.empty_cache()
|
||||
|
||||
|
||||
def in_the_same_node_as(pg: ProcessGroup, source_rank: int = 0) -> List[bool]:
|
||||
|
||||
Reference in New Issue
Block a user