Sync from v0.13
This commit is contained in:
96
tests/lora/test_moe_lora_align_sum.py
Normal file
96
tests/lora/test_moe_lora_align_sum.py
Normal file
@@ -0,0 +1,96 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import random
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
|
||||
|
||||
def round_up(x, base):
|
||||
return ((x + base - 1) // base) * base
|
||||
|
||||
|
||||
def CEILDIV(x, y):
|
||||
return (x + y - 1) // y
|
||||
|
||||
|
||||
def sample_data(num_experts, max_loras, num_tokens, topk_num):
|
||||
topk_ids = torch.zeros((num_tokens, topk_num), dtype=torch.int32)
|
||||
token_lora_mapping = torch.zeros((num_tokens,), dtype=torch.int32)
|
||||
|
||||
for i in range(num_tokens):
|
||||
pool = list(range(num_experts))
|
||||
random.shuffle(pool)
|
||||
for j in range(topk_num):
|
||||
topk_ids[i, j] = pool[j]
|
||||
token_lora_mapping[i] = random.randint(0, max_loras - 1)
|
||||
|
||||
return topk_ids.to("cuda"), token_lora_mapping.to("cuda")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", [100, 200, 1024, 4096]) # 81920
|
||||
@pytest.mark.parametrize("topk_num", [6])
|
||||
@pytest.mark.parametrize("num_experts", [64, 128, 256, 512])
|
||||
@pytest.mark.parametrize("max_loras", [2, 32])
|
||||
@pytest.mark.parametrize("block_size", [16])
|
||||
def test_moe_lora_align_block_size(
|
||||
num_tokens, topk_num, num_experts, max_loras, block_size
|
||||
):
|
||||
# sample data
|
||||
random.seed(1)
|
||||
topk_ids, token_lora_mapping = sample_data(
|
||||
num_experts, max_loras, num_tokens, topk_num
|
||||
)
|
||||
|
||||
# compute paddings
|
||||
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
|
||||
max_num_tokens_padded = round_up(max_num_tokens_padded, block_size)
|
||||
max_num_m_blocks = CEILDIV(max_num_tokens_padded, block_size)
|
||||
|
||||
# init output tensors
|
||||
sorted_token_ids = torch.full(
|
||||
(max_loras * max_num_tokens_padded,),
|
||||
topk_ids.numel(),
|
||||
dtype=torch.int32,
|
||||
device="cuda",
|
||||
)
|
||||
expert_ids = torch.full(
|
||||
(max_loras * max_num_m_blocks,), num_experts, dtype=torch.int32, device="cuda"
|
||||
)
|
||||
num_tokens_post_pad = torch.zeros((max_loras,), dtype=torch.int32, device="cuda")
|
||||
adapter_enabled = torch.ones((max_loras + 1,), dtype=torch.int32, device="cuda")
|
||||
lora_ids = torch.arange(max_loras + 2, dtype=torch.int32, device="cuda")
|
||||
|
||||
# call kernel
|
||||
ops.moe_lora_align_block_size(
|
||||
topk_ids,
|
||||
token_lora_mapping,
|
||||
num_experts,
|
||||
block_size,
|
||||
max_loras,
|
||||
max_num_tokens_padded,
|
||||
max_num_m_blocks,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_pad,
|
||||
adapter_enabled,
|
||||
lora_ids,
|
||||
)
|
||||
|
||||
# verify values
|
||||
expert_ids = expert_ids.view(max_loras, -1)
|
||||
sorted_token_ids = sorted_token_ids.view(max_loras, -1, block_size)
|
||||
|
||||
for lora_idx in range(max_loras):
|
||||
for token_idx in range(sorted_token_ids.size(1)):
|
||||
block = sorted_token_ids[lora_idx][token_idx]
|
||||
indices = block[block != topk_ids.numel()]
|
||||
if indices.numel() > 0:
|
||||
expert_id = expert_ids[lora_idx][token_idx]
|
||||
assert torch.all(topk_ids.view(-1)[indices] == expert_id)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
||||
Reference in New Issue
Block a user