[Platform] format platform to make it more clear (#610)
Platform should only contain the function that based from vllm. This PR move the unrelated function to the right place to make platform more clear. Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
@@ -62,7 +62,7 @@
|
|||||||
# inside of the repo, and needs a common interface to destroy them, this patch add the interface of destroy
|
# inside of the repo, and needs a common interface to destroy them, this patch add the interface of destroy
|
||||||
# platform owned `CoordinatorGroup` to make sure all the CoordinateGroup can be properly destroyed
|
# platform owned `CoordinatorGroup` to make sure all the CoordinateGroup can be properly destroyed
|
||||||
# How:
|
# How:
|
||||||
# Call platform method `destroy_platform_model_parallel` to destroy all the `CoordinateGroup`
|
# Call `vllm_ascend.distributed.parallel_state method `destroy_platform_model_parallel` to destroy all the `CoordinateGroup`
|
||||||
# Related PR (if no, explain why): no related PR, we want add this ability into vllm
|
# Related PR (if no, explain why): no related PR, we want add this ability into vllm
|
||||||
# Future Plan:
|
# Future Plan:
|
||||||
# Remove those patch when vllm merged them
|
# Remove those patch when vllm merged them
|
||||||
@@ -73,7 +73,7 @@
|
|||||||
# call to the `stateless_init_torch_distributed_process_group`, to enable other platform which may support
|
# call to the `stateless_init_torch_distributed_process_group`, to enable other platform which may support
|
||||||
# stateless process group initialize method
|
# stateless process group initialize method
|
||||||
# How:
|
# How:
|
||||||
# Call platform method `platform_has_backend_register` to judge if there is a stateless process group initialize
|
# rewrite stateless_init_torch_distributed_process_group to judge if there is a stateless process group initialize
|
||||||
# method and call platform method `platform_register_backend` to initialize them
|
# method and call platform method `platform_register_backend` to initialize them
|
||||||
# Related PR (if no, explain why): no related PR, we want add this ability into vllm
|
# Related PR (if no, explain why): no related PR, we want add this ability into vllm
|
||||||
# Future Plan:
|
# Future Plan:
|
||||||
|
|||||||
@@ -42,8 +42,9 @@ def ascend_destroy_model_parallel():
|
|||||||
if _DP:
|
if _DP:
|
||||||
_DP.destroy()
|
_DP.destroy()
|
||||||
_DP = None
|
_DP = None
|
||||||
from vllm.platforms import current_platform
|
from vllm_ascend.distributed.parallel_state import \
|
||||||
current_platform.destroy_platform_model_parallel()
|
destory_ascend_model_parallel
|
||||||
|
destory_ascend_model_parallel()
|
||||||
|
|
||||||
|
|
||||||
def ascend_stateless_init_torch_distributed_process_group(
|
def ascend_stateless_init_torch_distributed_process_group(
|
||||||
@@ -100,7 +101,6 @@ def ascend_stateless_init_torch_distributed_process_group(
|
|||||||
group_rank,
|
group_rank,
|
||||||
group_size,
|
group_size,
|
||||||
)
|
)
|
||||||
from vllm.platforms import current_platform
|
|
||||||
if backend == "gloo":
|
if backend == "gloo":
|
||||||
from torch.distributed.distributed_c10d import ProcessGroupGloo
|
from torch.distributed.distributed_c10d import ProcessGroupGloo
|
||||||
backend_class = ProcessGroupGloo(prefix_store,
|
backend_class = ProcessGroupGloo(prefix_store,
|
||||||
@@ -120,8 +120,18 @@ def ascend_stateless_init_torch_distributed_process_group(
|
|||||||
backend_options)
|
backend_options)
|
||||||
backend_type = ProcessGroup.BackendType.NCCL
|
backend_type = ProcessGroup.BackendType.NCCL
|
||||||
device = torch.device("cuda")
|
device = torch.device("cuda")
|
||||||
elif current_platform.platform_has_backend_register():
|
elif backend == "hccl":
|
||||||
current_platform.platform_register_backend()
|
from torch.distributed import is_hccl_available
|
||||||
|
assert is_hccl_available()
|
||||||
|
from torch_npu._C._distributed_c10d import ProcessGroupHCCL
|
||||||
|
backend_options = ProcessGroupHCCL.Options()
|
||||||
|
backend_options._timeout = timeout
|
||||||
|
backend_class = ProcessGroupHCCL(prefix_store, group_rank, group_size,
|
||||||
|
backend_options)
|
||||||
|
device = torch.device("npu")
|
||||||
|
backend_class._set_sequence_number_for_group()
|
||||||
|
backend_type = ProcessGroup.BackendType.CUSTOM
|
||||||
|
pg._register_backend(device, backend_type, backend_class)
|
||||||
return pg
|
return pg
|
||||||
else:
|
else:
|
||||||
raise RuntimeError(f"Unsupported torch distributed backend: {backend}")
|
raise RuntimeError(f"Unsupported torch distributed backend: {backend}")
|
||||||
|
|||||||
@@ -226,29 +226,3 @@ class NPUPlatform(Platform):
|
|||||||
model configuration.
|
model configuration.
|
||||||
"""
|
"""
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def destroy_platform_model_parallel(cls) -> None:
|
|
||||||
from vllm_ascend.distributed.parallel_state import \
|
|
||||||
destory_ascend_model_parallel
|
|
||||||
destory_ascend_model_parallel()
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def platform_has_backend_register(cls) -> bool:
|
|
||||||
return True
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def platform_register_backend(cls, pg, prefix_store, group_rank,
|
|
||||||
group_size, backend_options,
|
|
||||||
timeout) -> None:
|
|
||||||
from torch.distributed import ProcessGroup, is_hccl_available
|
|
||||||
assert is_hccl_available()
|
|
||||||
from torch_npu._C._distributed_c10d import ProcessGroupHCCL
|
|
||||||
backend_options = ProcessGroupHCCL.Options()
|
|
||||||
backend_options._timeout = timeout
|
|
||||||
backend_class = ProcessGroupHCCL(prefix_store, group_rank, group_size,
|
|
||||||
backend_options)
|
|
||||||
device = torch.device("npu")
|
|
||||||
backend_class._set_sequence_number_for_group()
|
|
||||||
backend_type = ProcessGroup.BackendType.CUSTOM
|
|
||||||
pg._register_backend(device, backend_type, backend_class)
|
|
||||||
|
|||||||
Reference in New Issue
Block a user