Files
xc-llm-ascend/tests/ut/distributed/test_parallel_state.py
zzhxxx f7db812ed7 [refactor] Refactor the interface for shard weight and remove the flashcomm2 o_shared interface. (#5181)
### What this PR does / why we need it?
- Delete the environment variable
`VLLM_ASCEND_ENABLE_FLASHCOMM2_OSHARED`
- Introduce `layer_sharding` as a configurable feature in `additional_config`
- Revise the term "shared weight" to "shard weight."
Configuration: the feature is opt-in via the `additional_config` argument:
```
--additional-config '{
  "layer_sharding": ["o_proj", "q_b_proj"]
}'
```

This is orthogonal to standard tensor parallelism and weight replication
strategies and is treated as a separate, explicit feature. It can be used
in any scenario and combined with the flashcomm2 feature
(https://github.com/vllm-project/vllm-ascend/pull/3232) or the ShardedCP
feature (#4702) to achieve significant performance gains.
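
For illustration only, a hypothetical launch command enabling the feature
alongside standard tensor parallelism might look like the following (the
model name and parallel size are placeholders, not values taken from this PR):
```
vllm serve deepseek-ai/DeepSeek-V3 \
  --tensor-parallel-size 8 \
  --additional-config '{"layer_sharding": ["o_proj", "q_b_proj"]}'
```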



- vLLM version: v0.12.0
- vLLM main: ad32e3e19c

---------

Signed-off-by: zzhx1 <zzh_201018@outlook.com>
Signed-off-by: zzhxx <zhangzihang23@mails.ucas.ac.cn>
Signed-off-by: chenxiao <Jaychou1620@Gmail.com>
Co-authored-by: clrs97 <524936896@qq.com>
Co-authored-by: Levi-JQ <yujinqi2@huawei.com>
Co-authored-by: chenxiao <Jaychou1620@Gmail.com>
2026-01-08 09:05:02 +08:00


from unittest.mock import MagicMock, patch

import pytest
from vllm.config import ParallelConfig

from vllm_ascend.distributed.parallel_state import (
    _FLASHCOMM2_ODP, _FLASHCOMM2_OTP, _LMTP, _MC2, _OTP, _P_TP,
    destroy_ascend_model_parallel, get_flashcomm2_odp_group,
    get_flashcomm2_otp_group, get_lmhead_tp_group, get_mc2_group,
    get_otp_group, get_p_tp_group, init_ascend_model_parallel)


# Minimal parallel config: DP=2 x TP=4 x PP=2, matching the mocked world size of 16.
@pytest.fixture
def parallel_config():
    return ParallelConfig(
        data_parallel_size=2,
        tensor_parallel_size=4,
        pipeline_parallel_size=2,
    )


# Patch torch.distributed and the vLLM world/TP groups so that
# init_ascend_model_parallel can run without real devices.
@pytest.fixture
def mock_distributed():
    with patch('torch.distributed.is_initialized', return_value=True), \
            patch('torch.distributed.get_world_size', return_value=16), \
            patch('torch.distributed.get_backend', return_value='nccl'), \
            patch('vllm_ascend.distributed.parallel_state.get_world_group') as mock_group, \
            patch('vllm_ascend.distributed.parallel_state.get_tp_group') as mock_tp_group:
        mock_group.return_value.local_rank = 0
        mock_group.return_value.device_group = MagicMock()
        mock_tp_group.return_value.world_size = 4
        yield


def test_init_ascend_model_parallel(mock_distributed, parallel_config):
    # Ascend config with all fine-grained TP sizes set to 2.
    mock_ascend_config = MagicMock()
    mock_ascend_config.finegrained_tp_config.lmhead_tensor_parallel_size = 2
    mock_ascend_config.finegrained_tp_config.oproj_tensor_parallel_size = 2
    mock_ascend_config.finegrained_tp_config.embedding_tensor_parallel_size = 2
    mock_ascend_config.finegrained_tp_config.mlp_tensor_parallel_size = 2
    mock_ascend_config.flashcomm2_oproj_tensor_parallel_size = 2
    mock_ascend_config.pd_tp_ratio = 2
    mock_ascend_config.num_head_replica = 0
    mock_ascend_config.pd_head_ratio = 2

    mock_vllm_config = MagicMock()
    mock_vllm_config.kv_transfer_config.is_kv_producer = True

    mock_envs_ascend = MagicMock()
    mock_envs_ascend.VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE = 2
    mock_envs_ascend.VLLM_ASCEND_ENABLE_CONTEXT_PARALLEL = 0

    with patch('vllm_ascend.distributed.parallel_state.model_parallel_initialized', return_value=False), \
            patch('vllm_ascend.distributed.parallel_state.init_model_parallel_group'), \
            patch('vllm_ascend.distributed.parallel_state.get_current_vllm_config', return_value=mock_vllm_config), \
            patch('vllm_ascend.distributed.parallel_state.get_ascend_config', return_value=mock_ascend_config), \
            patch('vllm_ascend.utils.envs_ascend', new=mock_envs_ascend), \
            patch('vllm_ascend.utils.get_ascend_config', return_value=mock_ascend_config):
        init_ascend_model_parallel(parallel_config)

        # Every fine-grained process group should exist after initialization.
        mc2_group = get_mc2_group()
        lmheadtp_group = get_lmhead_tp_group()
        otp_group = get_otp_group()
        flashcomm2_otp_group = get_flashcomm2_otp_group()
        flashcomm2_odp_group = get_flashcomm2_odp_group()
        p_tp_group = get_p_tp_group()
        assert mc2_group is not None
        assert otp_group is not None
        assert flashcomm2_otp_group is not None
        assert flashcomm2_odp_group is not None
        assert lmheadtp_group is not None
        assert p_tp_group is not None

        # Tearing down should reset the module-level group handles to None.
        destroy_ascend_model_parallel()
        assert _MC2 is None
        assert _LMTP is None
        assert _OTP is None
        assert _FLASHCOMM2_OTP is None
        assert _FLASHCOMM2_ODP is None
        assert _P_TP is None
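
As a usage note, assuming the file lives at the path shown above relative to
the repository root, the test can be run on its own with pytest:
```
pytest tests/ut/distributed/test_parallel_state.py -q
```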