Files
xc-llm-ascend/tests/multicard/test_pyhccl_distributed.py
wangxiyuan b917361ca5 [MISC] Clean up torch_npu (#688)
torch_npu 2.5.1 now supports autoloading. This patch:
1. removes unused torch_npu imports
2. replaces `torch_npu.npu` with `torch.npu`.

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
2025-04-29 18:03:38 +08:00

111 lines
3.7 KiB
Python

#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
# Adapted from vllm/tests/basic_correctness/test_basic_correctness.py
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import functools
import multiprocessing
import os

import torch
from vllm.distributed.parallel_state import (get_world_group,
                                             init_distributed_environment)
from vllm.utils import update_environment_variables

from vllm_ascend.distributed.device_communicators.pyhccl import \
    PyHcclCommunicator


def distributed_run(fn, world_size):
    """Run ``fn`` in ``world_size`` child processes and require all to succeed.

    Every child is a ``multiprocessing.Process`` whose single positional
    argument is a ``dict[str, str]`` carrying the per-rank variables
    ``RANK``, ``LOCAL_RANK``, ``WORLD_SIZE``, ``LOCAL_WORLD_SIZE``,
    ``MASTER_ADDR`` and ``MASTER_PORT``. The dict is only passed as an
    argument; applying it to ``os.environ`` is left to ``fn``.

    Args:
        fn: Callable executed in each child as ``fn(env)``.
        world_size: Number of child processes to start.

    Raises:
        AssertionError: If at least one child finishes with a non-zero exit
            code. The message maps each failing rank to its exit code.
    """
    processes: list[multiprocessing.Process] = []
    for rank in range(world_size):
        env: dict[str, str] = {
            'RANK': str(rank),
            # All processes run on one host, so the local rank and local
            # world size are identical to the global ones.
            'LOCAL_RANK': str(rank),
            'WORLD_SIZE': str(world_size),
            'LOCAL_WORLD_SIZE': str(world_size),
            'MASTER_ADDR': 'localhost',
            'MASTER_PORT': '12345',
        }
        proc = multiprocessing.Process(target=fn, args=(env, ))
        processes.append(proc)
        proc.start()

    # Join every child before inspecting results so that none is still
    # running when the assertion below fires.
    for proc in processes:
        proc.join()

    # Collect all failures at once instead of stopping at the first one, so
    # the report shows exactly which ranks went wrong and how.
    failed = {
        rank: proc.exitcode
        for rank, proc in enumerate(processes) if proc.exitcode != 0
    }
    assert not failed, (
        f"worker processes exited with non-zero code (rank: exitcode): "
        f"{failed}")
def worker_fn_wrapper(fn):
    """Turn a zero-argument worker into a one-argument process target.

    The returned function takes the environment dict as its only argument.
    When called it applies that dict via ``update_environment_variables``,
    selects the NPU named by ``LOCAL_RANK``, initialises the distributed
    environment with the ``hccl`` backend and finally calls ``fn()``.

    Args:
        fn: Worker body, called without arguments once setup has finished.

    Returns:
        A callable ``wrapped_fn(env)`` carrying the name, qualified name and
        docstring of ``fn``.
    """

    # `multiprocessing.Process` cannot accept environment variables directly,
    # so they are passed as an argument and applied inside the child process.
    # `functools.wraps` keeps the worker's own name in tracebacks; because the
    # wrapper then shares the qualified name of the module-level function it
    # replaces, it can also be pickled by reference when used as a decorator.
    @functools.wraps(fn)
    def wrapped_fn(env):
        update_environment_variables(env)
        local_rank = os.environ['LOCAL_RANK']
        device = torch.device(f"npu:{local_rank}")
        torch.npu.set_device(device)
        init_distributed_environment(backend="hccl")
        fn()

    return wrapped_fn
@worker_fn_wrapper
def worker_fn():
    """All-reduce a tensor of ones and check the result on every rank.

    Builds a ``PyHcclCommunicator`` on the world group, all-reduces a
    ``16 x 1024 x 1024`` float32 tensor of ones and asserts that every
    element equals the communicator's world size (i.e. the reduction is
    expected to behave as a sum across ranks).
    """
    pyhccl_comm = PyHcclCommunicator(get_world_group().cpu_group,
                                     device=get_world_group().device)
    # Allocate directly on the communicator's device rather than indexing
    # devices by global rank, matching how `broadcast_worker_fn` places its
    # tensors.
    tensor = torch.ones(16,
                        1024,
                        1024,
                        dtype=torch.float32,
                        device=pyhccl_comm.device)
    tensor = pyhccl_comm.all_reduce(tensor)
    # Wait for outstanding device work before reading the result on the host.
    torch.npu.synchronize()
    assert torch.all(tensor == pyhccl_comm.world_size).cpu().item()


# def test_pyhccl():
# distributed_run(worker_fn, 2)
@worker_fn_wrapper
def broadcast_worker_fn():
    """Broadcast from every rank in turn and verify what each rank receives.

    One receive buffer is kept per rank. The buffer at this process's own
    index is filled with its rank value; the others start uninitialised.
    Broadcasting buffer ``k`` from source ``k`` for every ``k`` should leave
    buffer ``k`` filled with the value ``k`` on all ranks, which is
    essentially an all-gather.
    """
    comm = PyHcclCommunicator(get_world_group().cpu_group,
                              device=get_world_group().device)
    shape = (16, 1024, 1024)

    buffers = [
        torch.empty(shape, dtype=torch.float32, device=comm.device)
        for _ in range(comm.world_size)
    ]
    # This rank's own slot holds the payload it will send: its rank value.
    own_payload = torch.ones(shape, dtype=torch.float32, device=comm.device)
    buffers[comm.rank] = own_payload * comm.rank

    for src_rank, buf in enumerate(buffers):
        comm.broadcast(buf, src=src_rank)
        # The broadcast op might be launched in a different stream, so
        # synchronize to make sure the tensor is ready before checking it.
        torch.npu.synchronize()
        assert torch.all(buf == src_rank).cpu().item()


# def test_pyhccl_broadcast():
# distributed_run(broadcast_worker_fn, 4)