add qwen3
This commit is contained in:
127
vllm-v0.6.2/tests/kernels/bt_torch_ops/test_copy_blocks.py
Normal file
127
vllm-v0.6.2/tests/kernels/bt_torch_ops/test_copy_blocks.py
Normal file
@@ -0,0 +1,127 @@
|
||||
import random
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch_mlu
|
||||
from vllm import _mlu_ops as mlu_ops
|
||||
|
||||
from typing import List, Tuple
|
||||
|
||||
# Parameter sweeps for the copy_blocks kernel test.
DTYPES = [torch.half, torch.float]
# Non-"3" series devices additionally support bfloat16 (mirrors the guard in
# test_swap_blocks.py). The original branch re-assigned the identical list,
# which made the conditional dead code.
if "3" not in torch.mlu.get_device_name(0):
    DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_TOKENS = [83]  # Arbitrary values for testing
NUM_LAYERS = [1]  # Arbitrary values for testing
NUM_HEADS = [8]  # Arbitrary values for testing
HEAD_SIZES = [64, 80, 96, 112, 128, 256, 512]
BLOCK_SIZES = [8, 16, 32]
NUM_BLOCKS = [1024, 3600]  # Arbitrary values for testing
NUM_MAPPINGS = [256]  # Arbitrary values for testing
SEEDS = [0]
# One device when a single MLU is visible, otherwise exercise two.
DEVICES = [i for i in range(1 if torch.mlu.device_count() == 1 else 2)]
||||
def create_kv_caches(
    num_blocks: int,
    block_size: int,
    num_layers: int,
    num_heads: int,
    head_size: int,
    dtype: torch.dtype,
    seed: int
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    """Create randomly initialised per-layer KV caches on the MLU device.

    Returns ``(key_caches, value_caches)``, each a list of ``num_layers``
    tensors. Key caches have shape (num_blocks, num_heads, block_size,
    head_size); value caches have shape (num_blocks, num_heads, head_size,
    block_size).
    """
    torch.random.manual_seed(seed)
    # Bug fix: this helper targets MLU devices, so seed the MLU RNG like the
    # rest of this file does; the original called torch.cuda.manual_seed,
    # which does not affect MLU random state.
    torch.mlu.manual_seed(seed)

    scale = head_size**-0.5
    # vLLM's CUDA layout splits head_size by x = 16 // element_size:
    #   key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
    # The MLU kernels instead use a flat layout.
    key_cache_shape = (num_blocks, num_heads, block_size, head_size)
    key_caches = []
    for _ in range(num_layers):
        key_cache = torch.empty(size=key_cache_shape,
                                dtype=dtype).mlu()
        # Integer dtypes get a bounded random fill; floats use the vLLM scale.
        if dtype == torch.int32 or dtype == torch.int64:
            key_cache.random_(-100, 100)
        else:
            key_cache.uniform_(-scale, scale)
        key_caches.append(key_cache)
    value_cache_shape = (num_blocks, num_heads, head_size, block_size)
    value_caches = []
    for _ in range(num_layers):
        value_cache = torch.empty(size=value_cache_shape,
                                  dtype=dtype).mlu()
        if dtype == torch.int32 or dtype == torch.int64:
            value_cache.random_(-100, 100)
        else:
            value_cache.uniform_(-scale, scale)
        value_caches.append(value_cache)
    return key_caches, value_caches
|
||||
|
||||
@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
@pytest.mark.parametrize("num_layers", NUM_LAYERS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@torch.inference_mode()
def test_copy_blocks(
    num_mappings: int,
    num_layers: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: int,
) -> None:
    """Check mlu_ops.copy_blocks against a clone-and-copy_ CPU-side reference."""
    random.seed(seed)
    torch.random.manual_seed(seed)
    torch.mlu.manual_seed(seed)
    # Reserve disjoint source and destination blocks. Two destination blocks
    # are sampled per source (hence the 3x headroom assert), although only
    # the first of each pair ends up in the mapping below.
    # NOTE(review): dst_blocks[2 * i + 1] is never used — confirm whether a
    # second (src, dst) pair per source was intended, as the comment in the
    # original suggested.
    assert 3 * num_mappings <= num_blocks
    src_blocks = random.sample(range(num_blocks), num_mappings)
    # Typo fix: was "remainig_blocks".
    remaining_blocks = list(set(range(num_blocks)) - set(src_blocks))
    dst_blocks = random.sample(remaining_blocks, 2 * num_mappings)
    block_mapping = torch.empty(num_mappings, 2, dtype=torch.int32)
    for i in range(num_mappings):
        src = src_blocks[i]
        dst = dst_blocks[2 * i]
        block_mapping[i] = torch.tensor([src, dst])
    block_mapping_mlu = block_mapping.mlu()

    # Create the KV caches.
    key_caches, value_caches = create_kv_caches(num_blocks, block_size,
                                                num_layers, num_heads,
                                                head_size, dtype, seed)

    # Clone the KV caches so the reference runs on independent buffers.
    cloned_key_caches = [key_cache.clone() for key_cache in key_caches]
    cloned_value_caches = [value_cache.clone() for value_cache in value_caches]

    # Call the copy blocks kernel.
    mlu_ops.copy_blocks(key_caches, value_caches, block_mapping_mlu)

    # Run the reference implementation.
    for mapping in block_mapping:
        src, dst = mapping[0], mapping[1]
        for cloned_key_cache in cloned_key_caches:
            cloned_key_cache[dst].copy_(cloned_key_cache[src])
        for cloned_value_cache in cloned_value_caches:
            cloned_value_cache[dst].copy_(cloned_value_cache[src])

    # Compare the results.
    for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):
        assert torch.allclose(key_cache, cloned_key_cache)
    for value_cache, cloned_value_cache in zip(value_caches,
                                               cloned_value_caches):
        assert torch.allclose(value_cache, cloned_value_cache)
|
||||
99
vllm-v0.6.2/tests/kernels/bt_torch_ops/test_ffn.py
Normal file
99
vllm-v0.6.2/tests/kernels/bt_torch_ops/test_ffn.py
Normal file
@@ -0,0 +1,99 @@
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch_mlu
|
||||
from vllm import _mlu_ops as mlu_ops
|
||||
|
||||
# Activation functions the FFN kernel supports, keyed by kernel name.
act_dict = {
    "relu": F.relu,
    "gelu": F.gelu,
    "silu": F.silu,
}


def ref_ffn(
        hidden_states,
        up_fc_weight,
        up_fc_bias,
        down_proj_weight,
        down_proj_bias,
        gate_up_proj_weight,
        gate_up_proj_bias,
        layernorm_weight,
        layernorm_bias,
        act_mode):
    """Pure-PyTorch reference for the fused FFN kernel.

    Computes ``down_proj(act(up_fc(x)) * gate(x))`` when a gate projection
    is supplied, otherwise ``down_proj(act(up_fc(x)))``.

    NOTE(review): ``layernorm_weight`` / ``layernorm_bias`` are accepted to
    mirror the kernel signature but are currently unused.
    """
    up_output = F.linear(hidden_states, up_fc_weight, bias=up_fc_bias)
    act_output = act_dict[act_mode](up_output)
    # Idiom fix: "is not None" instead of "not x is None".
    if gate_up_proj_weight is not None:
        gate_output = F.linear(hidden_states, gate_up_proj_weight,
                               bias=gate_up_proj_bias)
        out = F.linear(act_output * gate_output, down_proj_weight,
                       bias=down_proj_bias)
    else:
        out = F.linear(act_output, down_proj_weight, bias=down_proj_bias)
    return out
|
||||
|
||||
# Parameter sweeps for the FFN kernel test.
BATCH_SIZE = [1]
SEQ_LENS = [1, 64, 1024]
HIDDEN_SIZE = [16, 24]
INTER_SIZE = [32]
DTYPES = [torch.half, torch.float]
# Non-"3" series devices additionally support bfloat16 (mirrors the guard in
# test_swap_blocks.py). The original branch re-assigned the identical list,
# which made the conditional dead code.
if "3" not in torch.mlu.get_device_name(0):
    DTYPES = [torch.half, torch.bfloat16, torch.float]
||||
@pytest.mark.parametrize("batch_size", BATCH_SIZE)
@pytest.mark.parametrize("seq_len", SEQ_LENS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZE)
@pytest.mark.parametrize("inter_size", INTER_SIZE)
@pytest.mark.parametrize("act_name", ["relu", "silu"])  # gelu
@pytest.mark.parametrize("use_gate", [True])
@pytest.mark.parametrize("use_gate_bias", [False])
@pytest.mark.parametrize("use_up_bias", [False])
@pytest.mark.parametrize("use_down_bias", [False])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", [0])
def test_attention_project(
    batch_size: int,
    seq_len: int,
    hidden_size: int,
    inter_size: int,
    act_name: str,
    use_gate: bool,
    use_gate_bias: bool,
    use_up_bias: bool,
    use_down_bias: bool,
    dtype: torch.dtype,
    seed: int
) -> None:
    """Compare mlu_ops.ffn against the pure-PyTorch ref_ffn reference."""
    device_id = "mlu:0"
    torch.random.manual_seed(seed)
    torch.mlu.manual_seed(seed)

    def _rand(*shape):
        # All test tensors share one dtype and live on the MLU device.
        return torch.randn(*shape, dtype=dtype, device=device_id)

    # Keep the RNG draw order: hidden states, up weight, gate weight, down
    # weight — so results match run-to-run regardless of use_gate.
    hidden_states = _rand(batch_size, seq_len, hidden_size)
    up_proj_weight = _rand(inter_size, hidden_size)
    gate_proj_weight = _rand(inter_size, hidden_size) if use_gate else None
    down_proj_weight = _rand(hidden_size, inter_size)

    # Kernel under test (all biases disabled in this sweep).
    out = mlu_ops.ffn(hidden_states,
                      up_proj_weight,
                      None,
                      down_proj_weight,
                      None,
                      gate_proj_weight,
                      None,
                      act_name)

    # Reference (layernorm arguments are unused by ref_ffn).
    ref_out = ref_ffn(hidden_states,
                      up_proj_weight,
                      None,
                      down_proj_weight,
                      None,
                      gate_proj_weight,
                      None,
                      None,
                      None,
                      act_name)

    assert torch.allclose(out, ref_out, atol=1e-1, rtol=1e-1)
|
||||
|
||||
185
vllm-v0.6.2/tests/kernels/bt_torch_ops/test_rotary_emb.py
Normal file
185
vllm-v0.6.2/tests/kernels/bt_torch_ops/test_rotary_emb.py
Normal file
@@ -0,0 +1,185 @@
|
||||
import numpy
|
||||
from typing import List, Optional
|
||||
from itertools import accumulate
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope, _ROPE_DICT, LinearScalingRotaryEmbedding
|
||||
from vllm_mlu.model_executor.layers.rotary_embedding import MLURotaryEmbedding
|
||||
|
||||
# Acceptance thresholds for the two relative-error metrics below.
ROPE_THRESHOLD_DIFF1 = 5e-3
ROPE_THRESHOLD_DIFF2 = 5e-3


def compute_diff(baseline: numpy.ndarray, compare: numpy.ndarray):
    """Return the (diff1, diff2) relative-error metrics of *compare* vs *baseline*.

    diff1 is the summed absolute error divided by the summed absolute
    baseline; diff2 is the relative L2 error.
    """
    abs_err = numpy.abs(baseline - compare)
    denom1 = numpy.sum(numpy.abs(baseline))
    diff1 = numpy.sum(abs_err) / denom1
    diff2 = numpy.sqrt(numpy.sum(abs_err**2) / numpy.sum(baseline**2))
    return diff1, diff2
|
||||
|
||||
# Parameter sweeps for the rotary-embedding tests below.
IS_NEOX_STYLE = [True, False]
DTYPES = [torch.half, torch.bfloat16, torch.float]
HEAD_SIZES = [64, 80, 96, 112, 128, 256]
# NOTE(review): only 32 is swept here, so the "None" (rotary dim == head
# size) case described by the comment is never exercised — confirm intent.
ROTARY_DIMS = [32]  # None means rotary dim == head size
NUM_HEADS = [9, 17]  # Arbitrary values for testing
BATCH_SIZES = [1, 5]  # Arbitrary values for testing
SEQ_LENS = [11, 8192]
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("seq_len", SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
@torch.inference_mode()
def test_rotary_embedding(
    is_neox_style: bool,
    batch_size: int,
    seq_len: int,
    num_heads: int,
    head_size: int,
    rotary_dim: Optional[int],
    dtype: torch.dtype,
    max_position: int = 128,
    base: int = 10000,
) -> None:
    """Compare MLURotaryEmbedding.forward against the native reference path."""
    # Fix: the original repeated this None-check twice back to back; once
    # is sufficient (the second was dead code).
    if rotary_dim is None:
        rotary_dim = head_size

    total_seq_len = batch_size * seq_len

    MLURotaryEmbedding.max_seq_len = max_position
    rope = MLURotaryEmbedding(head_size, rotary_dim, max_position, base, is_neox_style, dtype)
    rope = rope.to(dtype=dtype, device=0)

    positions = torch.randint(0,
                              max_position, ([batch_size*seq_len]),
                              device=0).to(dtype=torch.int32)

    # Query and key heads are packed into one (tokens, heads, head_size)
    # tensor; the last head acts as the key head (see num_q_heads below).
    context_shape = (total_seq_len, num_heads, head_size)
    context = torch.randn(size=context_shape, dtype=dtype).mlu()
    qk = context[..., 0 : num_heads, :]
    ref_qk = qk.clone()

    cu_seq_lens = torch.arange(0, batch_size + 1, dtype=torch.int32).mlu() * seq_len
    # Class-level knobs consumed by MLURotaryEmbedding.forward.
    MLURotaryEmbedding.set_cos_sin = False
    MLURotaryEmbedding.cu_seq_lens = cu_seq_lens
    MLURotaryEmbedding.is_prompt = False
    MLURotaryEmbedding.is_chunked = False
    qk_out = rope.forward(positions, qk)
    rope_base = MLURotaryEmbedding(head_size, rotary_dim, max_position, base, is_neox_style, dtype)
    # To simulate a CPU re-initialisation of the cos/sin cache.
    if "cos_sin_cache" in rope_base._buffers:
        del rope_base._buffers["cos_sin_cache"]
    cache = rope_base._compute_cos_sin_cache()
    cache = cache.to(dtype).mlu()
    rope_base.cos_sin_cache: torch.Tensor
    rope_base.register_buffer("cos_sin_cache", cache, persistent=False)

    # All heads except the last are query heads; the last is the key head
    # (matches the reshape to (-1, 1, head_size) below).
    num_q_heads = num_heads - 1
    ref_q = ref_qk[:, :num_q_heads, :].reshape(-1, head_size)
    ref_k = ref_qk[:, num_q_heads:, :].reshape(-1, head_size)
    ref_q_o, ref_k_o = rope_base.forward_native(positions, ref_q, ref_k)
    ref_q_o_reshape = ref_q_o.reshape(-1, num_q_heads, head_size)
    ref_k_o_reshape = ref_k_o.reshape(-1, 1, head_size)
    ref_qk_out = torch.cat((ref_q_o_reshape, ref_k_o_reshape), dim=1).cpu()
    qk_out_cpu = qk_out.cpu()
    # Reset the class-level MLU state so later tests start clean.
    MLURotaryEmbedding.unset_mlu_var()
    diff1, diff2 = compute_diff(baseline=ref_qk_out.float().detach().numpy(),
                                compare=qk_out_cpu.float().detach().numpy())
    assert diff1 <= ROPE_THRESHOLD_DIFF1 and diff2 <= ROPE_THRESHOLD_DIFF2
|
||||
|
||||
|
||||
@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("seq_len", [1, 11, 1024])
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", [64, 80, 128])
@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
@pytest.mark.parametrize("dtype", DTYPES)
@torch.inference_mode()
def test_batched_rotary_embedding_multi_lora(
    is_neox_style: bool,
    batch_size: int,
    seq_len: int,
    num_heads: int,
    head_size: int,
    rotary_dim: Optional[int],
    dtype: torch.dtype,
    device: str = "mlu",
    seed: int = 0,
    max_position: int = 4096,
    base: int = 10000,
) -> None:
    """Test the linear-scaling (multi-LoRA) rope kernel against the native
    LinearScalingRotaryEmbedding reference, using per-token cache offsets."""
    assert device == "mlu"
    assert torch.mlu.is_available()

    torch.random.manual_seed(seed)
    torch.mlu.manual_seed(seed)
    torch.set_default_device(device)

    if rotary_dim is None:
        rotary_dim = head_size
    is_prompt = seq_len == 1
    # One scaling factor per simulated LoRA adapter.
    scaling_factors: List[int] = [1, 2, 4]

    # Class-level knobs consumed by the MLU rope forward; must be set before
    # get_rope() builds the instance.
    MLURotaryEmbedding.max_seq_len = max_position
    MLURotaryEmbedding.set_cos_sin = False
    MLURotaryEmbedding.is_prompt = is_prompt
    MLURotaryEmbedding.is_chunked = False
    MLURotaryEmbedding.positions_ = None
    MLURotaryEmbedding.cu_seq_lens = seq_len * torch.arange(
        0, batch_size + 1, dtype=torch.int32)

    rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style, {
        "rope_type": "linear",
        "factor": tuple(scaling_factors)
    })
    rope = rope.to(dtype=dtype)

    positions = torch.randint(0, max_position, (batch_size * seq_len, ),
                              dtype=torch.int32)
    query = torch.randn(batch_size,
                        seq_len,
                        num_heads * head_size,
                        dtype=dtype)
    key = torch.randn_like(query)

    # Start offset of each scaling factor's cos/sin cache segment; each
    # segment holds max_position * factor entries of (cos, sin) pairs,
    # hence the * 2.
    offset_map = torch.tensor(
        list(
            accumulate([0] + [
                max_position * scaling_factor * 2
                for scaling_factor in scaling_factors[:-1]
            ])), dtype=torch.int32)
    # Randomly assign each token to one of the scaling factors.
    query_types = torch.randint(0,
                                len(scaling_factors), (batch_size * seq_len, ))
    query_offsets = offset_map[query_types]

    # Pack query and key heads into a single (tokens, 2*num_heads, head_size)
    # tensor, as the MLU kernel expects.
    qk = torch.cat([query, key], dim=-1)
    qk = qk.view(batch_size * seq_len, num_heads + num_heads, head_size)

    out_qk = rope.forward(positions, qk, query_offsets)

    scaling_factor = tuple(scaling_factors)
    rope_base = LinearScalingRotaryEmbedding(head_size, rotary_dim,
                                             max_position, base,
                                             is_neox_style,
                                             scaling_factor,
                                             torch.get_default_dtype())

    ref_query, ref_key = rope_base.forward_native(positions, query, key,
                                                  query_offsets)
    ref_qk = torch.cat([ref_query, ref_key], dim=-1)
    ref_qk = ref_qk.view(batch_size * seq_len, num_heads + num_heads, head_size)

    # delete rope cache to init rope instance every time
    _ROPE_DICT.clear()

    # compare the results
    diff1, diff2 = compute_diff(baseline=ref_qk.cpu().float().detach().numpy(),
                                compare=out_qk.cpu().float().detach().numpy())
    assert diff1 <= ROPE_THRESHOLD_DIFF1 and diff2 <= ROPE_THRESHOLD_DIFF2
|
||||
94
vllm-v0.6.2/tests/kernels/bt_torch_ops/test_swap_blocks.py
Normal file
94
vllm-v0.6.2/tests/kernels/bt_torch_ops/test_swap_blocks.py
Normal file
@@ -0,0 +1,94 @@
|
||||
import random
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch_mlu
|
||||
# Backend selection flags; only the MLU path is exercised in this file.
USE_CUDA=False
USE_MLU=True
if USE_CUDA:
    from vllm._C import cache_ops
if USE_MLU:
    from vllm import _mlu_ops as mlu_ops

from typing import List, Tuple

DTYPES = [torch.half, torch.float]
# Non-"3" series devices additionally support bfloat16.
if "3" not in torch.mlu.get_device_name(0):
    DTYPES = [torch.half, torch.bfloat16, torch.float]
SEEDS = [0]
# One device when a single MLU is visible, otherwise exercise two.
DEVICES = [i for i in range(1 if torch.mlu.device_count() == 1 else 2)]
# (n, c, h, w) dimensions of the test tensors fed to torch.randn below.
num_N = [3600]
num_C = [8]
num_H = [32,128]
num_W = [16]
# Number of (src, dst) block pairs to swap.
num_pairs = [3,256]
# Copy directions to exercise.
cpys = ["mlu to mlu", "mlu to cpu", "cpu to mlu"]
|
||||
class SwapBlocks(torch.nn.Module):
    """CPU reference for the swap-blocks kernel.

    For every ``(src_idx, dst_idx)`` pair in the mapping, copies row
    ``src[src_idx]`` into ``dst[dst_idx]`` in place.
    """

    def __init__(self):
        super().__init__()

    def forward(self,
                dst: torch.Tensor,
                src: torch.Tensor,
                src_to_dst: dict):
        # Plain row-wise scatter copy driven by the mapping; mutates dst.
        for src_idx, dst_idx in src_to_dst.items():
            dst[dst_idx] = src[src_idx]
|
||||
|
||||
@pytest.mark.parametrize("n", num_N)
@pytest.mark.parametrize("c", num_C)
@pytest.mark.parametrize("h", num_H)
@pytest.mark.parametrize("w", num_W)
@pytest.mark.parametrize("num_pair", num_pairs)
@pytest.mark.parametrize("cpy", cpys)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@torch.inference_mode()
def test_copy_blocks(
    n: int,
    c: int,
    h: int,
    w: int,
    num_pair: int,
    cpy: str,
    dtype: torch.dtype,
    seed: int,
    device: int,
) -> None:
    """Check mlu_ops.swap_blocks against the SwapBlocks CPU reference for
    every copy direction.

    NOTE(review): despite its name this exercises swap_blocks (the file is
    test_swap_blocks.py); the name is kept so collected test IDs don't change.
    """
    random.seed(seed)
    torch.random.manual_seed(seed)
    torch.mlu.manual_seed(seed)

    if cpy == "mlu to mlu":
        src = torch.randn(n, c, h, w, dtype=dtype).mlu()
        dst = torch.randn(n, c, h, w, dtype=dtype).mlu()
    elif cpy == "mlu to cpu":
        src = torch.randn(n, c, h, w, dtype=dtype).mlu()
        dst = torch.randn(n, c, h, w, dtype=dtype).cpu()
    elif cpy == "cpu to mlu":
        src = torch.randn(n, c, h, w, dtype=dtype).cpu()
        dst = torch.randn(n, c, h, w, dtype=dtype).mlu()
    else:
        # Bug fix: the original print()-ed a typo'd message ("unkown") and
        # called exit(1), which would abort the entire pytest session.
        raise ValueError(f"unknown copy direction: {cpy!r}")

    # Random permutation: src block i -> dst block values[i].
    values = list(range(num_pair))
    random.shuffle(values)
    src_to_dst = dict(zip(range(num_pair), values))

    # Kernel input: the same mapping as an int32 (num_pair, 2) tensor.
    mapping_data = [[k, v] for k, v in src_to_dst.items()]
    src_to_dst_tensor = torch.tensor(mapping_data, dtype=torch.int32).mlu()

    ref_src, ref_dst = src.clone(), dst.clone()
    swap_blocks = SwapBlocks()
    # Call the swap blocks kernel.
    # cpu reference
    swap_blocks(ref_dst, ref_src, src_to_dst)
    # mlu kernel under test
    mlu_ops.swap_blocks(dst, src, src_to_dst_tensor)
    # diff: the source must be untouched, the destination must match.
    assert torch.allclose(src, ref_src)
    assert torch.allclose(dst, ref_dst)
|
||||
Reference in New Issue
Block a user