[Nightly] Move ops to the correct path (#5642)
### What this PR does / why we need it?
Move ops to the correct path where they belong.
- vLLM version: v0.13.0
- vLLM main: 2f4e6548ef
Signed-off-by: wangli <wangli858794774@gmail.com>
@@ -1,168 +0,0 @@

import random

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch_npu
from torch.distributed.distributed_c10d import _get_default_group

from vllm_ascend.utils import enable_custom_op

enable_custom_op()

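# Multi-process smoke test for the fused dispatch_ffn_combine custom op:
# every rank builds an HCCL communicator, pushes small random tensors
# through the kernel, and reports success back to the parent process.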
class TestDisptachFFNCombine:

    def __init__(self, rank, world_size, port):
        self.rank = rank
        self.world_size = world_size
        self.master_ip = "127.0.0.1"
        self.port = port

    def get_hcomm(self, comm_group):
        # Resolve the HCCL communicator name for this rank; the accessor
        # moved behind _get_backend() after torch 2.0.1.
        hcomm_info = None
        if torch.__version__ > "2.0.1":
            hcomm_info = comm_group._get_backend(
                torch.device("npu")).get_hccl_comm_name(self.rank)
        else:
            hcomm_info = comm_group.get_hccl_comm_name(self.rank)
        return hcomm_info

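    # Build expert-parallel (EP) and tensor-parallel (TP) subgroups over a
    # tp_size x ep_size rank grid. For example, with world_size=4, ep_size=2,
    # tp_size=2 this yields EP groups [0, 1], [2, 3] and TP groups [0, 2],
    # [1, 3]; every rank belongs to exactly one group of each kind.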
    def setup_ep_tp(
        self,
        rank,
        tp_size,
        ep_size,
        backend_type,
        ep_ranks_list=None,
        tp_ranks_list=None,
    ):
        for i in range(tp_size):
            if ep_ranks_list:
                ep_ranks = ep_ranks_list[i]
            else:
                ep_ranks = [x + ep_size * i for x in range(ep_size)]
            ep_group = dist.new_group(backend=backend_type, ranks=ep_ranks)
            if rank in ep_ranks:
                ep_group_tmp = ep_group
        for i in range(ep_size):
            if tp_ranks_list:
                tp_ranks = tp_ranks_list[i]
            else:
                tp_ranks = [x * ep_size + i for x in range(tp_size)]
            tp_group = dist.new_group(backend=backend_type, ranks=tp_ranks)
            if rank in tp_ranks:
                tp_group_tmp = tp_group
        return ep_group_tmp, tp_group_tmp

    def generate_hcom(self):
        torch_npu.npu.set_device(self.rank)
        dist.init_process_group(
            backend="hccl",
            rank=self.rank,
            world_size=self.world_size,
            init_method=f"tcp://127.0.0.1:{self.port}",
        )

        # ep_size is hardcoded to 0 here, so the test always takes the
        # default-process-group branch below instead of EP/TP subgroups.
        ep_size = 0
        tp_size = self.world_size
        hcomm_info_dist = {
            "default_pg_info": None,
            "ep_hcomm_info": None,
            "group_ep": None,
            "tp_hcomm_info": None,
            "group_tp": None,
        }
        if ep_size and tp_size:
            group_ep, group_tp = self.setup_ep_tp(self.rank, tp_size, ep_size,
                                                  "hccl", None, None)
            hcomm_info_dist["ep_hcomm_info"] = self.get_hcomm(group_ep)
            hcomm_info_dist["tp_hcomm_info"] = self.get_hcomm(group_tp)
            hcomm_info_dist["group_ep"] = group_ep
            hcomm_info_dist["group_tp"] = group_tp
        else:
            if dist.is_available():
                default_pg = _get_default_group()
                hcomm_info_dist["default_pg_info"] = self.get_hcomm(default_pg)
            hcomm_info = hcomm_info_dist["default_pg_info"]
        self.hcomm_info = hcomm_info

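    # Shapes are deliberately tiny (the comments give the realistic sizes);
    # this checks that the fused kernel launches and completes, not that its
    # numerics match a reference.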
    def run_npu_out(self) -> bool:
        torch_npu.npu.set_device(self.rank)
        m = 2  # token num (realistic size: 32)
        k = 4  # hidden size (realistic size: 7168)
        n = 4  # intermediate hidden size (realistic size: 4096)
        topk = 2
        e = 2  # experts per rank (realistic size: 16)
        k2 = n // 2
        n2 = k

        torch_npu.npu.config.allow_internal_format = True
        x = self.generate_random_tensor((m, k), dtype=torch.bfloat16).npu()
        weight1 = self.generate_random_tensor((e, k, n),
                                              dtype=torch.int8).npu()
        # Cast the int8 weights to the on-device NZ layout (format 29).
        weight1 = torch_npu.npu_format_cast(weight1, 29)
        weight2 = self.generate_random_tensor((e, k2, n2),
                                              dtype=torch.int8).npu()
        weight2 = torch_npu.npu_format_cast(weight2, 29)

        expert_idx = torch.randint(0,
                                   self.world_size * e, (m, topk),
                                   dtype=torch.int32).npu()
        scale1 = torch.randint(0, 1, (e, n), dtype=torch.int64).npu()
        scale2 = torch.randint(0, 1, (e, n2), dtype=torch.int64).npu()
        probs = torch.randn(size=(m, topk), dtype=torch.float32).npu()
        out = self.generate_random_tensor((m, k), dtype=torch.bfloat16).npu()

        torch.ops._C_ascend.dispatch_ffn_combine(
            x=x,
            weight1=weight1,
            weight2=weight2,
            expert_idx=expert_idx,
            scale1=scale1,
            scale2=scale2,
            probs=probs,
            group=self.hcomm_info,
            max_output_size=512,
            out=out,
        )
        return True

    def generate_random_tensor(self, size, dtype):
        if dtype in [torch.float16, torch.bfloat16, torch.float32]:
            return torch.randn(size=size, dtype=dtype)
        elif dtype is torch.int8:
            return torch.randint(-16, 16, size=size, dtype=dtype)
        elif dtype is torch.int32:
            return torch.randint(-1024, 1024, size=size, dtype=dtype)
        else:
            raise ValueError(f"Invalid dtype: {dtype}")

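# One worker process per rank; each funnels its pass/fail result back to the
# parent through a shared SimpleQueue.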
def worker(rank: int, world_size: int, port: int, q: mp.SimpleQueue):
    op = TestDisptachFFNCombine(rank, world_size, port)
    op.generate_hcom()
    out = op.run_npu_out()
    q.put(out)


@torch.inference_mode()
def test_dispatch_ffn_combine_kernel():
    world_size = 2
    mp.set_start_method("fork", force=True)

    q = mp.SimpleQueue()
    p_list = []
    # Randomize the rendezvous port so concurrent test runs do not collide.
    port = 29501 + random.randint(0, 10000)

    for rank in range(world_size):
        p = mp.Process(target=worker, args=(rank, world_size, port, q))
        p.start()
        p_list.append(p)

    results = [q.get() for _ in range(world_size)]

    for p in p_list:
        p.join()

    assert all(results)
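For orientation, here is a minimal single-device sketch of what a dispatch-FFN-combine pipeline computes, in plain PyTorch. It is illustrative only: the real kernel fuses int8 matmuls with HCCL communication across ranks, and the test does not spell out the expert activation (the n // 2 shape of weight2 hints at a gated unit), so this sketch assumes an unquantized ReLU MLP and a hypothetical helper name.

import torch

def moe_ffn_reference(x, w1, w2, expert_idx, probs):
    # Hypothetical reference: x is (m, k) tokens, w1 is (e, k, n),
    # w2 is (e, n, k), expert_idx and probs are (m, topk).
    m, topk = expert_idx.shape
    out = torch.zeros_like(x)
    for t in range(m):
        for j in range(topk):
            e = int(expert_idx[t, j])
            hidden = torch.relu(x[t] @ w1[e])  # dispatch + expert up-projection
            out[t] += probs[t, j] * (hidden @ w2[e])  # down-projection, weighted combine
    return out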
@@ -1,135 +0,0 @@

import gc
import os

import numpy as np
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch_npu
import torchair

from vllm_ascend.utils import enable_custom_op

# Compile through torchair's NPU backend in reduce-overhead (graph) mode.
config = torchair.CompilerConfig()
config.mode = "reduce-overhead"
npu_backend = torchair.get_npu_backend(compiler_config=config)
torch_npu.npu.config.allow_internal_format = True
enable_custom_op()

global_rank_id = 0

def golden_op_matmul_allreduce_add_rmsnorm(a, b, residual, gamma, epsilon):
    # Eager-mode reference: linear -> all-reduce -> residual add + RMSNorm,
    # the same sequence the fused custom op is expected to compute.
    c_ret = torch.nn.functional.linear(a, b)
    dist.all_reduce(c_ret)
    rmsnorm_ret, _, add_ret = torch_npu.npu_add_rms_norm(
        c_ret, residual, gamma, epsilon)
    return rmsnorm_ret, add_ret

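# Each spawned worker builds one HCCL group over all ranks, runs the eager
# golden path and the torch.compile'd custom-op path, and compares results.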
def worker(rank, ep_world_size, batch_size, m, k, n):
    global global_rank_id
    global_rank_id = rank

    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group(backend="hccl",
                            rank=rank,
                            world_size=ep_world_size)

    ep_ranks_list = list(np.arange(0, ep_world_size))

    ep_group = dist.new_group(backend="hccl", ranks=ep_ranks_list)

    torch_npu.npu.set_device(rank)
    ep_hcomm_info = ep_group._get_backend(
        torch.device("npu")).get_hccl_comm_name(rank)

    torch_npu.npu.synchronize(rank)

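    # Thin nn.Module wrapper so the fused custom op can be traced and
    # compiled through torchair's npu_backend below.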
    class Module(torch.nn.Module):

        def forward(self, x1, x2, residual, gamma, ep_hcomm_info, epsilon,
                    is_trans_b, is_allgather_add_out):
            out1, add_out1 = torch.ops._C_ascend.matmul_allreduce_add_rmsnorm(
                x1, x2, residual, gamma, ep_hcomm_info, ep_world_size,
                global_rank_id, epsilon, is_trans_b, is_allgather_add_out)
            return out1, add_out1

    DTYPE = torch.bfloat16
    USE_ONES = False

    torch.manual_seed(42)

    # USE_ONES switches to constant inputs, which is handy when debugging
    # numerical mismatches; random normal inputs are the default.
    if USE_ONES:
        x1 = torch.ones([m, k], dtype=DTYPE).npu(rank)
        x2 = torch.ones([n, k], dtype=DTYPE).npu(rank)
    else:
        x1 = torch.normal(0, 0.1, [m, k], dtype=DTYPE).npu(rank)
        x2 = torch.normal(0, 0.1, [n, k], dtype=DTYPE).npu(rank)

    if USE_ONES:
        residual = torch.full([m, n], 2048, dtype=DTYPE).npu(rank)
    else:
        residual = torch.full([m, n], 0, dtype=DTYPE).npu(rank)

    gamma = torch.full([n], 1, dtype=DTYPE).npu(rank)

    epsilon = 1e-5
    is_trans_b = True
    is_allgather_add_out = True
    warmup_cnt = 5
    repeat_cnt = 10

    def run_golden_case(loop_cnt):
        for _ in range(loop_cnt):
            golden_out, golden_add_out = golden_op_matmul_allreduce_add_rmsnorm(
                x1, x2, residual, gamma, epsilon)
        torch_npu.npu.synchronize(rank)
        return golden_out, golden_add_out

    # Warm up, then run the eager golden path.
    run_golden_case(warmup_cnt)

    golden_out, golden_add_out = run_golden_case(repeat_cnt)
    golden_out = golden_out.detach().cpu()
    golden_add_out = golden_add_out.detach().cpu()

    mod = Module().npu()
    opt_model = torch.compile(mod, backend=npu_backend)

    def run_custom_case(loop_cnt):
        for _ in range(loop_cnt):
            out, add_out = opt_model(x1, x2, residual, gamma, ep_hcomm_info,
                                     epsilon, is_trans_b, is_allgather_add_out)
        torch_npu.npu.synchronize(rank)
        return out, add_out

    # Warm up, then run the compiled custom-op path.
    run_custom_case(warmup_cnt)

    out, add_out = run_custom_case(repeat_cnt)
    out = out.detach().cpu()
    add_out = add_out.detach().cpu()

    dist.destroy_process_group()

    torch.testing.assert_close(golden_out, out, atol=0.1, rtol=0.005)
    torch.testing.assert_close(golden_add_out, add_out, atol=0.1, rtol=0.005)
    gc.collect()
    torch.npu.empty_cache()
    torch.npu.reset_peak_memory_stats()

@torch.inference_mode()
def test_matmul_allreduce_add_rmsnorm_kernel():
    ep_world_size = 4
    batch_size = 1
    m = 10000
    k = 1024
    n = 5120
    args = (ep_world_size, batch_size, m, k, n)
    mp.spawn(worker, args=args, nprocs=ep_world_size, join=True)
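For reference, a minimal eager sketch of the add + RMSNorm step that torch_npu.npu_add_rms_norm performs in the golden path, assuming the usual RMSNorm definition (the real op also returns a reciprocal standard deviation, which the test discards):

import torch

def add_rms_norm_reference(x, residual, gamma, epsilon=1e-5):
    # Residual add, then RMS normalization over the last dimension.
    added = x + residual
    variance = added.float().pow(2).mean(dim=-1, keepdim=True)
    normed = added.float() * torch.rsqrt(variance + epsilon)
    return (normed * gamma.float()).to(x.dtype), added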