xc-llm-ascend/vllm_ascend/distributed/parallel_state.py
wemaster 0ae9ee0f8a [BUGFIX] main-sd-bugfix && [UT] add mtp UT (#593)
### What this PR does / why we need it?
This PR fixes several bugs in spec decode / MTP.
It also adds an MTP e2e UT, `test_mtp_correctness.py`.

**vllm_ascend/attention/attention.py**
1. support the case where `self.attn_mask_cache` has only 1 element, to cover the scenario
in which both spec decode and chunked prefill are enabled.

**vllm_ascend/distributed/parallel_state.py**
1. remove 2 asserts, because the spec decode worker calls init_worker
twice

**vllm_ascend/models/deepseek_mtp.py**
1. remove unused params;
2. add w8a8 support in `CustomDeepSeekMTP`

**vllm_ascend/quantization/quant_config.py**
1. use `AscendUnquantizedFusedMoEMethod` instead of
`UnquantizedFusedMoEMethod`

**other**
1. replace `from vllm.logger import init_logger` with `from vllm.logger
import logger` throughout the vllm-ascend project



### Does this PR introduce _any_ user-facing change?


### How was this patch tested?

Signed-off-by: mengwei805 <mengwei25@huawei.com>
2025-04-21 19:25:51 +08:00

72 lines
2.3 KiB
Python

from typing import Optional

import torch
from vllm.distributed.parallel_state import (GroupCoordinator, get_world_group,
                                             init_model_parallel_group)

# vllm-ascend maintains its own EP GroupCoordinator and ETP GroupCoordinator
# for its customized parallel solution
_EP: Optional[GroupCoordinator] = None
_ETP: Optional[GroupCoordinator] = None


def get_ep_group() -> GroupCoordinator:
    assert _EP is not None, ("expert model parallel group is not initialized")
    return _EP


def get_etp_group() -> GroupCoordinator:
    assert _ETP is not None, (
        "expert tensor parallel group is not initialized")
    return _ETP


def init_ascend_model_parallel(
    tensor_model_parallel_size: int = 1,
    pipeline_model_parallel_size: int = 1,
    expert_tensor_parallel_size: int = 1,
    backend: Optional[str] = None,
):
    assert torch.distributed.is_initialized()
    world_size: int = torch.distributed.get_world_size()
    backend = backend or torch.distributed.get_backend(
        get_world_group().device_group)
    num_expert_parallel_groups: int = expert_tensor_parallel_size
    num_expert_tensor_parallel_groups: int = (world_size //
                                              expert_tensor_parallel_size)

    # EP groups: ranks strided by the number of EP groups.
    global _EP
    group_ranks = []
    for i in range(num_expert_parallel_groups):
        ranks = list(range(i, world_size, num_expert_parallel_groups))
        group_ranks.append(ranks)

    _EP = init_model_parallel_group(group_ranks,
                                    get_world_group().local_rank,
                                    backend,
                                    group_name="ep")

    # ETP groups: contiguous blocks of expert_tensor_parallel_size ranks.
    group_ranks = []
    global _ETP
    for i in range(num_expert_tensor_parallel_groups):
        ranks = list(
            range(i * expert_tensor_parallel_size,
                  (i + 1) * expert_tensor_parallel_size))
        group_ranks.append(ranks)

    _ETP = init_model_parallel_group(group_ranks,
                                     get_world_group().local_rank,
                                     backend,
                                     group_name="etp")


def destory_ascend_model_parallel():
    global _EP
    if _EP:
        _EP.destroy()
    _EP = None

    global _ETP
    if _ETP:
        _ETP.destroy()
    _ETP = None
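
The rank layout that `init_ascend_model_parallel` builds can be checked without a distributed runtime. The following is a minimal sketch that reproduces only the two `group_ranks` loops from the file above in plain Python. The helper name `ascend_group_ranks` and the divisibility check are assumptions added for illustration and are not part of vllm-ascend.

```python
from typing import List, Tuple


def ascend_group_ranks(
    world_size: int,
    expert_tensor_parallel_size: int = 1,
) -> Tuple[List[List[int]], List[List[int]]]:
    """Hypothetical helper: return (ep_groups, etp_groups) rank lists.

    Mirrors the loops in init_ascend_model_parallel, but creates no
    process groups.
    """
    # Assumption for this sketch only: the original code does not validate
    # divisibility; we reject uneven sizes to keep the example well defined.
    if world_size % expert_tensor_parallel_size != 0:
        raise ValueError(
            "world_size must be divisible by expert_tensor_parallel_size")

    num_ep_groups = expert_tensor_parallel_size
    num_etp_groups = world_size // expert_tensor_parallel_size

    # EP groups: ranks strided by the number of EP groups.
    ep_groups = [
        list(range(i, world_size, num_ep_groups))
        for i in range(num_ep_groups)
    ]

    # ETP groups: contiguous blocks of expert_tensor_parallel_size ranks.
    etp_groups = [
        list(
            range(i * expert_tensor_parallel_size,
                  (i + 1) * expert_tensor_parallel_size))
        for i in range(num_etp_groups)
    ]
    return ep_groups, etp_groups


if __name__ == "__main__":
    ep, etp = ascend_group_ranks(world_size=8, expert_tensor_parallel_size=2)
    print("EP groups: ", ep)
    print("ETP groups:", etp)
```

With 8 ranks and `expert_tensor_parallel_size=2`, every rank belongs to exactly one EP group (4 ranks, stride 2) and one ETP group (2 adjacent ranks). With the default size of 1, there is a single EP group spanning all ranks and each ETP group holds one rank.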