### What this PR does / why we need it?
Fix the pyhccl test so it runs on 2 cards.
### Does this PR introduce _any_ user-facing change?
N/A
### How was this patch tested?
CI passed with the existing test.
- vLLM version: v0.10.0
- vLLM main: 0d0cc9e150
Signed-off-by: MengqingCao <cmq0113@163.com>
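
For local reproduction, the test can also be run directly on a machine with at least two NPU cards. A minimal sketch of such an invocation; the file path below is an assumption (adjust it to wherever the test lives in the repo):

```python
# Hypothetical local invocation of the 2-card pyhccl test.
# Assumptions: the test file path, and a machine with >= 2 NPU cards.
import pytest

pytest.main(["-s", "tests/e2e/test_pyhccl.py"])
```

The full updated test file: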
```python
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
# Adapted from vllm/tests/basic_correctness/test_basic_correctness.py
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import multiprocessing
import os

import torch
from vllm.distributed.parallel_state import (get_world_group,
                                             init_distributed_environment)
from vllm.utils import update_environment_variables

from tests.e2e.conftest import cleanup_dist_env_and_memory
from vllm_ascend.distributed.device_communicators.pyhccl import \
    PyHcclCommunicator

os.environ["TOKENIZERS_PARALLELISM"] = "true"

multiprocessing.set_start_method("spawn", force=True)

def _worker_entry(env, fn):
    # `multiprocessing.Process` cannot accept environment variables directly,
    # so we pass them as an argument and update the environment
    # inside the worker function.
    update_environment_variables(env)

    rank = int(os.environ['RANK'])
    local_rank = int(os.environ['LOCAL_RANK'])
    world_size = int(os.environ['WORLD_SIZE'])

    distributed_init_method = "tcp://localhost:12345"

    device = torch.device(f"npu:{local_rank}")
    torch.npu.set_device(device)

    init_distributed_environment(
        world_size=world_size,
        rank=rank,
        distributed_init_method=distributed_init_method,
        local_rank=local_rank,
        backend="hccl")
    fn()
    cleanup_dist_env_and_memory()

def distributed_run(fn, world_size):
    number_of_processes = world_size
    processes: list[multiprocessing.Process] = []
    for i in range(number_of_processes):
        env: dict[str, str] = {}
        env['RANK'] = str(i)
        env['LOCAL_RANK'] = str(i)
        env['WORLD_SIZE'] = str(number_of_processes)
        env['LOCAL_WORLD_SIZE'] = str(number_of_processes)
        p = multiprocessing.Process(target=_worker_entry, args=(env, fn))
        processes.append(p)
        p.start()

    for p in processes:
        p.join()

    for p in processes:
        assert p.exitcode == 0

def worker_fn():
    pyhccl_comm = PyHcclCommunicator(get_world_group().cpu_group,
                                     device=get_world_group().device)
    tensor = torch.ones(16, 1024, 1024,
                        dtype=torch.float32).npu(pyhccl_comm.rank)
    tensor = pyhccl_comm.all_reduce(tensor)
    torch.npu.synchronize()
    assert torch.all(tensor == pyhccl_comm.world_size).cpu().item()


def test_pyhccl():
    distributed_run(worker_fn, 2)

def broadcast_worker_fn():
    # Test broadcast for every root rank.
    # Essentially this is an all-gather operation.
    pyhccl_comm = PyHcclCommunicator(get_world_group().cpu_group,
                                     device=get_world_group().device)
    recv_tensors = [
        torch.empty(16,
                    1024,
                    1024,
                    dtype=torch.float32,
                    device=pyhccl_comm.device)
        for i in range(pyhccl_comm.world_size)
    ]
    recv_tensors[pyhccl_comm.rank] = torch.ones(
        16, 1024, 1024, dtype=torch.float32,
        device=pyhccl_comm.device) * pyhccl_comm.rank

    for i in range(pyhccl_comm.world_size):
        pyhccl_comm.broadcast(recv_tensors[i], src=i)
        # the broadcast op might be launched in a different stream,
        # so we need to synchronize to make sure the tensor is ready
        torch.npu.synchronize()
        assert torch.all(recv_tensors[i] == i).cpu().item()


def test_pyhccl_broadcast():
    distributed_run(broadcast_worker_fn, 2)
```
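
A note on the assertion in `worker_fn`: assuming the all-reduce is a sum (which the assertion implies), every rank contributes a tensor of ones, so each element of the result equals the world size. A CPU-only illustration of that expectation, using plain torch with no NPU or HCCL involved:

```python
# Illustration of the expected sum all-reduce result; CPU only, no HCCL.
import torch

world_size = 2
contributions = [torch.ones(16, 1024, 1024) for _ in range(world_size)]
reduced = sum(contributions)  # what a sum all-reduce computes across ranks
assert torch.all(reduced == world_size)
```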