Files
xc-llm-ascend/vllm_ascend/distributed/parallel_state.py
NINGBENZHE 6ec64a3f96 [bugfix] some bugs maybe fail to run (#896)
### What this PR does / why we need it?
Fixes a bug where graph mode behaved identically for the prefill and decode roles, along with several other issues.
### Does this PR introduce _any_ user-facing change?
No user-facing change.
### How was this patch tested?
Verified via the end-to-end test.

Signed-off-by: ningbenzhe1 <ningbenzhe@huawei.com>
2025-06-03 11:07:33 +08:00

78 lines
2.3 KiB
Python

from typing import Optional

import torch
from vllm.distributed.parallel_state import (GroupCoordinator, get_world_group,
                                             init_model_parallel_group)

# vllm-ascend maintains its own GroupCoordinators for expert parallelism (EP)
# and expert tensor parallelism (ETP), separate from vLLM's built-in groups,
# to support its customized parallelism scheme.
# Both are populated by init_ascend_model_parallel() and torn down by
# destory_ascend_model_parallel(); None means "not initialized".
_EP: Optional[GroupCoordinator] = None
_ETP: Optional[GroupCoordinator] = None
def get_ep_group() -> GroupCoordinator:
    """Return the expert parallel (EP) group coordinator.

    Raises:
        AssertionError: if init_ascend_model_parallel() has not been called.
    """
    group = _EP
    assert group is not None, ("expert model parallel group is not initialized")
    return group
def get_etp_group() -> GroupCoordinator:
    """Return the expert tensor parallel (ETP) group coordinator.

    Raises:
        AssertionError: if init_ascend_model_parallel() has not been called.
    """
    group = _ETP
    assert group is not None, (
        "expert tensor parallel group is not initialized")
    return group
def model_parallel_initialized():
    """Return True when both the EP and ETP groups have been created."""
    return _EP is not None and _ETP is not None
def init_ascend_model_parallel(
    expert_parallel_size: int = 1,
    expert_tensor_parallel_size: int = 1,
    world_size: Optional[int] = None,
    backend: Optional[str] = None,
):
    """Create the ascend-specific EP and ETP process groups.

    Idempotent: returns immediately if both groups already exist.
    torch.distributed must be initialized before this is called.

    Args:
        expert_parallel_size: size of each expert parallel group.
        expert_tensor_parallel_size: size of each expert tensor parallel group.
        world_size: total number of ranks; defaults to the torch.distributed
            world size. NOTE(review): presumably this should equal
            expert_parallel_size * expert_tensor_parallel_size — not checked
            here; confirm against callers.
        backend: distributed backend name; defaults to the backend of the
            world group's device group.
    """
    if model_parallel_initialized():
        return
    assert torch.distributed.is_initialized()
    world_size = world_size or torch.distributed.get_world_size()
    backend = backend or torch.distributed.get_backend(
        get_world_group().device_group)
    # The number of EP groups equals the ETP group size and vice versa:
    # the world is laid out as a (ep_size x etp_size) grid, so each dimension
    # has as many groups as the size of the other dimension.
    num_expert_parallel_groups = expert_tensor_parallel_size
    num_expert_tensor_parallel_groups = expert_parallel_size
    global _EP
    group_ranks = []
    # EP groups take ranks strided by the number of EP groups, e.g. with
    # etp=2, world=4: [0, 2] and [1, 3].
    for i in range(num_expert_parallel_groups):
        ranks = list(range(i, world_size, num_expert_parallel_groups))
        group_ranks.append(ranks)
    _EP = init_model_parallel_group(group_ranks,
                                    get_world_group().local_rank,
                                    backend,
                                    group_name="ep")
    group_ranks = []
    global _ETP
    # ETP groups take contiguous rank blocks, e.g. with etp=2, world=4:
    # [0, 1] and [2, 3] — complementary to the strided EP layout above.
    for i in range(num_expert_tensor_parallel_groups):
        ranks = list(
            range(i * expert_tensor_parallel_size,
                  (i + 1) * expert_tensor_parallel_size))
        group_ranks.append(ranks)
    _ETP = init_model_parallel_group(group_ranks,
                                     get_world_group().local_rank,
                                     backend,
                                     group_name="etp")
def destory_ascend_model_parallel():
    """Destroy the EP and ETP groups and reset the module state to None.

    Safe to call when the groups were never created.
    NOTE(review): the name is misspelled ("destory"), but it is the public
    API and is kept for backward compatibility.
    """
    global _EP, _ETP
    if _EP:
        _EP.destroy()
    _EP = None
    if _ETP:
        _ETP.destroy()
    _ETP = None