This PR makes format.sh work as expected. Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
73 lines
2.3 KiB
Python
73 lines
2.3 KiB
Python
from typing import Optional
|
|
|
|
import torch
|
|
from vllm.distributed.parallel_state import (GroupCoordinator, get_world_group,
|
|
init_model_parallel_group)
|
|
|
|
# vllm-ascend maintains its own EP (expert parallel) and ETP (expert tensor
# parallel) GroupCoordinators to customize its parallelism solution.
# Both are populated by init_ascend_model_parallel() and torn down by
# destory_ascend_model_parallel(); they stay None until then.
_EP: Optional[GroupCoordinator] = None
_ETP: Optional[GroupCoordinator] = None
|
|
|
|
|
|
def get_ep_group() -> GroupCoordinator:
    """Return the expert-parallel (EP) group coordinator.

    The coordinator is created by ``init_ascend_model_parallel``; calling
    this accessor before that raises an ``AssertionError``.
    """
    assert _EP is not None, "expert model parallel group is not initialized"
    return _EP
|
|
|
|
|
|
def get_etp_group() -> GroupCoordinator:
    """Return the expert-tensor-parallel (ETP) group coordinator.

    The coordinator is created by ``init_ascend_model_parallel``; calling
    this accessor before that raises an ``AssertionError``.
    """
    assert _ETP is not None, "expert tensor parallel group is not initialized"
    return _ETP
|
|
|
|
|
|
def init_ascend_model_parallel(
    tensor_model_parallel_size: int = 1,
    pipeline_model_parallel_size: int = 1,
    expert_tensor_parallel_size: int = 1,
    backend: Optional[str] = None,
):
    """Create the Ascend-specific EP and ETP group coordinators.

    Must be called after ``torch.distributed`` has been initialized.
    Populates the module globals ``_EP`` and ``_ETP``.

    Args:
        tensor_model_parallel_size: accepted for interface parity with the
            vLLM initializer; not read by this function.
        pipeline_model_parallel_size: accepted for interface parity; not
            read by this function.
        expert_tensor_parallel_size: ranks per expert-tensor-parallel group.
        backend: distributed backend name; defaults to the backend of the
            world group's device group.
    """
    assert torch.distributed.is_initialized()
    world_size: int = torch.distributed.get_world_size()
    if backend is None:
        backend = torch.distributed.get_backend(
            get_world_group().device_group)

    # NOTE(review): the EP group count is taken from
    # expert_tensor_parallel_size (and vice versa below). This mirrors the
    # upstream layout — EP groups stride across ETP groups — but confirm the
    # naming is intentional.
    num_expert_parallel_groups: int = expert_tensor_parallel_size
    num_expert_tensor_parallel_groups: int = (world_size //
                                              expert_tensor_parallel_size)

    global _EP
    # EP group i holds ranks {i, i + stride, i + 2*stride, ...} — strided
    # across the flat rank space.
    ep_ranks = [
        list(range(offset, world_size, num_expert_parallel_groups))
        for offset in range(num_expert_parallel_groups)
    ]
    _EP = init_model_parallel_group(ep_ranks,
                                    get_world_group().local_rank,
                                    backend,
                                    group_name="ep")

    global _ETP
    # ETP group j holds the contiguous rank slab
    # [j * size, (j + 1) * size).
    etp_ranks = [
        list(range(idx * expert_tensor_parallel_size,
                   (idx + 1) * expert_tensor_parallel_size))
        for idx in range(num_expert_tensor_parallel_groups)
    ]
    _ETP = init_model_parallel_group(etp_ranks,
                                     get_world_group().local_rank,
                                     backend,
                                     group_name="etp")
|
|
|
|
|
|
def destory_ascend_model_parallel():
    """Destroy the EP and ETP coordinators and reset the module globals.

    Safe to call when the groups were never initialized (both globals are
    simply left as ``None``), and idempotent across repeated calls.

    NOTE(review): "destory" is a typo for "destroy", but the name is part
    of the public interface, so it is kept for backward compatibility.
    """
    global _EP, _ETP

    if _EP:
        _EP.destroy()
    _EP = None

    if _ETP:
        _ETP.destroy()
    _ETP = None
|