Fuse sorted_token_ids padding to moe_align_block_size kernel (#7437)
This commit is contained in:
@@ -5,7 +5,11 @@ import torch
|
|||||||
import triton
|
import triton
|
||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
|
from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
|
||||||
from vllm import _custom_ops as ops
|
|
||||||
|
try:
|
||||||
|
from vllm import _custom_ops as ops
|
||||||
|
except ImportError:
|
||||||
|
ops = None
|
||||||
|
|
||||||
USE_RANDOM_PERM = False
|
USE_RANDOM_PERM = False
|
||||||
|
|
||||||
@@ -208,7 +212,7 @@ def calculate_diff(num_tokens, num_experts=256, block_size=128, topk=8):
|
|||||||
)
|
)
|
||||||
print(f"✅ VLLM implementation works with {num_experts} experts!")
|
print(f"✅ VLLM implementation works with {num_experts} experts!")
|
||||||
vllm_works = True
|
vllm_works = True
|
||||||
except RuntimeError as e:
|
except Exception as e:
|
||||||
print(f"❌ VLLM implementation failed with {num_experts} experts: {e}")
|
print(f"❌ VLLM implementation failed with {num_experts} experts: {e}")
|
||||||
vllm_works = False
|
vllm_works = False
|
||||||
|
|
||||||
@@ -257,13 +261,47 @@ def get_topk_ids(num_tokens: int, num_experts: int, topk: int) -> torch.Tensor:
|
|||||||
return topk_ids
|
return topk_ids
|
||||||
|
|
||||||
|
|
||||||
|
def sgl_moe_align_block_size_with_empty(
|
||||||
|
topk_ids,
|
||||||
|
num_experts,
|
||||||
|
block_size,
|
||||||
|
sorted_ids,
|
||||||
|
expert_ids,
|
||||||
|
num_tokens_post_pad,
|
||||||
|
pad_sorted_token_ids=False,
|
||||||
|
):
|
||||||
|
if not pad_sorted_token_ids:
|
||||||
|
sorted_ids.fill_(topk_ids.numel())
|
||||||
|
|
||||||
|
token_cnts_buffer = torch.empty(
|
||||||
|
(num_experts + 1) * num_experts,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=topk_ids.device,
|
||||||
|
)
|
||||||
|
cumsum_buffer = torch.empty(
|
||||||
|
num_experts + 1, dtype=torch.int32, device=topk_ids.device
|
||||||
|
)
|
||||||
|
|
||||||
|
sgl_moe_align_block_size(
|
||||||
|
topk_ids,
|
||||||
|
num_experts,
|
||||||
|
block_size,
|
||||||
|
sorted_ids.clone(),
|
||||||
|
expert_ids.clone(),
|
||||||
|
num_tokens_post_pad.clone(),
|
||||||
|
token_cnts_buffer,
|
||||||
|
cumsum_buffer,
|
||||||
|
pad_sorted_token_ids,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@triton.testing.perf_report(
|
@triton.testing.perf_report(
|
||||||
triton.testing.Benchmark(
|
triton.testing.Benchmark(
|
||||||
x_names=["num_tokens", "num_experts", "topk"],
|
x_names=["num_tokens", "num_experts", "topk"],
|
||||||
x_vals=configs,
|
x_vals=configs,
|
||||||
line_arg="provider",
|
line_arg="provider",
|
||||||
line_vals=["sgl", "triton", "vllm"],
|
line_vals=["sgl", "sgl_fusion", "triton"],
|
||||||
line_names=["SGL", "Triton", "VLLM"],
|
line_names=["SGL", "SGL Fusion", "Triton"],
|
||||||
styles=[("blue", "-"), ("red", "-"), ("green", "-")],
|
styles=[("blue", "-"), ("red", "-"), ("green", "-")],
|
||||||
ylabel="us",
|
ylabel="us",
|
||||||
plot_name="moe-align-block-size-performance",
|
plot_name="moe-align-block-size-performance",
|
||||||
@@ -288,7 +326,6 @@ def benchmark(num_tokens, num_experts, topk, provider):
|
|||||||
sorted_ids = torch.empty(
|
sorted_ids = torch.empty(
|
||||||
(max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
|
(max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
|
||||||
)
|
)
|
||||||
sorted_ids.fill_(topk_ids.numel())
|
|
||||||
max_num_m_blocks = max_num_tokens_padded // block_size
|
max_num_m_blocks = max_num_tokens_padded // block_size
|
||||||
expert_ids = torch.empty(
|
expert_ids = torch.empty(
|
||||||
(max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
|
(max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
|
||||||
@@ -297,35 +334,6 @@ def benchmark(num_tokens, num_experts, topk, provider):
|
|||||||
|
|
||||||
quantiles = [0.5, 0.2, 0.8]
|
quantiles = [0.5, 0.2, 0.8]
|
||||||
if provider == "sgl":
|
if provider == "sgl":
|
||||||
|
|
||||||
def sgl_moe_align_block_size_with_empty(
|
|
||||||
topk_ids,
|
|
||||||
num_experts,
|
|
||||||
block_size,
|
|
||||||
sorted_ids,
|
|
||||||
expert_ids,
|
|
||||||
num_tokens_post_pad,
|
|
||||||
):
|
|
||||||
token_cnts_buffer = torch.empty(
|
|
||||||
(num_experts + 1) * num_experts,
|
|
||||||
dtype=torch.int32,
|
|
||||||
device=topk_ids.device,
|
|
||||||
)
|
|
||||||
cumsum_buffer = torch.empty(
|
|
||||||
num_experts + 1, dtype=torch.int32, device=topk_ids.device
|
|
||||||
)
|
|
||||||
|
|
||||||
sgl_moe_align_block_size(
|
|
||||||
topk_ids,
|
|
||||||
num_experts,
|
|
||||||
block_size,
|
|
||||||
sorted_ids.clone(),
|
|
||||||
expert_ids.clone(),
|
|
||||||
num_tokens_post_pad.clone(),
|
|
||||||
token_cnts_buffer,
|
|
||||||
cumsum_buffer,
|
|
||||||
)
|
|
||||||
|
|
||||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||||
lambda: sgl_moe_align_block_size_with_empty(
|
lambda: sgl_moe_align_block_size_with_empty(
|
||||||
topk_ids,
|
topk_ids,
|
||||||
@@ -337,7 +345,21 @@ def benchmark(num_tokens, num_experts, topk, provider):
|
|||||||
),
|
),
|
||||||
quantiles=quantiles,
|
quantiles=quantiles,
|
||||||
)
|
)
|
||||||
|
elif provider == "sgl_fusion":
|
||||||
|
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||||
|
lambda: sgl_moe_align_block_size_with_empty(
|
||||||
|
topk_ids,
|
||||||
|
num_experts,
|
||||||
|
block_size,
|
||||||
|
sorted_ids,
|
||||||
|
expert_ids,
|
||||||
|
num_tokens_post_pad,
|
||||||
|
pad_sorted_token_ids=True,
|
||||||
|
),
|
||||||
|
quantiles=quantiles,
|
||||||
|
)
|
||||||
elif provider == "triton":
|
elif provider == "triton":
|
||||||
|
sorted_ids.fill_(topk_ids.numel())
|
||||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||||
lambda: moe_align_block_size_triton(
|
lambda: moe_align_block_size_triton(
|
||||||
topk_ids,
|
topk_ids,
|
||||||
@@ -349,23 +371,6 @@ def benchmark(num_tokens, num_experts, topk, provider):
|
|||||||
),
|
),
|
||||||
quantiles=quantiles,
|
quantiles=quantiles,
|
||||||
)
|
)
|
||||||
else: # vllm
|
|
||||||
try:
|
|
||||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
|
||||||
lambda: ops.moe_align_block_size(
|
|
||||||
topk_ids,
|
|
||||||
num_experts,
|
|
||||||
block_size,
|
|
||||||
sorted_ids.clone(),
|
|
||||||
expert_ids.clone(),
|
|
||||||
num_tokens_post_pad.clone(),
|
|
||||||
),
|
|
||||||
quantiles=quantiles,
|
|
||||||
)
|
|
||||||
except RuntimeError as e:
|
|
||||||
print(f"❌ VLLM benchmark failed with {num_experts} experts: {e}")
|
|
||||||
# Return extreme values to indicate failure in the chart
|
|
||||||
return float("inf"), float("inf"), float("inf")
|
|
||||||
|
|
||||||
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
|
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
|
||||||
|
|
||||||
|
|||||||
@@ -160,7 +160,8 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
|||||||
*/
|
*/
|
||||||
m.def(
|
m.def(
|
||||||
"moe_align_block_size(Tensor topk_ids, int num_experts, int block_size, Tensor! sorted_token_ids, Tensor! "
|
"moe_align_block_size(Tensor topk_ids, int num_experts, int block_size, Tensor! sorted_token_ids, Tensor! "
|
||||||
"experts_ids, Tensor! num_tokens_post_pad, Tensor! token_cnts_buffer, Tensor! cumsum_buffer) -> ()");
|
"experts_ids, Tensor! num_tokens_post_pad, Tensor! token_cnts_buffer, Tensor! cumsum_buffer, bool "
|
||||||
|
"pad_sorted_token_ids) -> ()");
|
||||||
m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
|
m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
|
||||||
|
|
||||||
m.def(
|
m.def(
|
||||||
|
|||||||
@@ -21,8 +21,17 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "utils.h"
|
#include "utils.h"
|
||||||
|
|
||||||
|
template <typename T, int N, int Alignment = sizeof(T) * N>
|
||||||
|
class alignas(Alignment) AlignedArray {
|
||||||
|
public:
|
||||||
|
T data[N];
|
||||||
|
};
|
||||||
|
|
||||||
#define WARP_SIZE 32
|
#define WARP_SIZE 32
|
||||||
|
|
||||||
|
#define VEC_SIZE 4
|
||||||
|
using Vec = AlignedArray<int32_t, VEC_SIZE>;
|
||||||
|
|
||||||
template <typename scalar_t>
|
template <typename scalar_t>
|
||||||
__global__ void count_and_sort_expert_tokens_kernel(
|
__global__ void count_and_sort_expert_tokens_kernel(
|
||||||
const scalar_t* __restrict__ topk_ids,
|
const scalar_t* __restrict__ topk_ids,
|
||||||
@@ -50,7 +59,8 @@ __global__ void moe_align_block_size_kernel(
|
|||||||
int32_t experts_per_warp,
|
int32_t experts_per_warp,
|
||||||
int32_t block_size,
|
int32_t block_size,
|
||||||
size_t numel,
|
size_t numel,
|
||||||
int32_t* __restrict__ cumsum) {
|
int32_t* __restrict__ cumsum,
|
||||||
|
bool pad_sorted_token_ids) {
|
||||||
extern __shared__ int32_t shared_counts[];
|
extern __shared__ int32_t shared_counts[];
|
||||||
|
|
||||||
const int warp_id = threadIdx.x / WARP_SIZE;
|
const int warp_id = threadIdx.x / WARP_SIZE;
|
||||||
@@ -96,6 +106,24 @@ __global__ void moe_align_block_size_kernel(
|
|||||||
expert_ids[i / block_size] = threadIdx.x;
|
expert_ids[i / block_size] = threadIdx.x;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (pad_sorted_token_ids) {
|
||||||
|
int32_t fill_val = static_cast<int32_t>(numel);
|
||||||
|
int32_t total = *total_tokens_post_pad;
|
||||||
|
|
||||||
|
Vec fill_vec;
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < VEC_SIZE; ++i) {
|
||||||
|
fill_vec.data[i] = fill_val;
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t total_vec_count = (total + VEC_SIZE - 1) / VEC_SIZE;
|
||||||
|
Vec* out_ptr = reinterpret_cast<Vec*>(sorted_token_ids);
|
||||||
|
|
||||||
|
for (int32_t idx = tid; idx < total_vec_count; idx += stride) {
|
||||||
|
out_ptr[idx] = fill_vec;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename scalar_t>
|
template <typename scalar_t>
|
||||||
@@ -106,7 +134,8 @@ __global__ void moe_align_block_size_small_batch_expert_kernel(
|
|||||||
int32_t* __restrict__ total_tokens_post_pad,
|
int32_t* __restrict__ total_tokens_post_pad,
|
||||||
int32_t num_experts,
|
int32_t num_experts,
|
||||||
int32_t block_size,
|
int32_t block_size,
|
||||||
size_t numel) {
|
size_t numel,
|
||||||
|
bool pad_sorted_token_ids) {
|
||||||
const size_t tid = threadIdx.x;
|
const size_t tid = threadIdx.x;
|
||||||
const size_t stride = blockDim.x;
|
const size_t stride = blockDim.x;
|
||||||
|
|
||||||
@@ -149,6 +178,26 @@ __global__ void moe_align_block_size_small_batch_expert_kernel(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (pad_sorted_token_ids) {
|
||||||
|
int32_t fill_val = static_cast<int32_t>(numel);
|
||||||
|
int32_t total = *total_tokens_post_pad;
|
||||||
|
|
||||||
|
Vec fill_vec;
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < VEC_SIZE; ++i) {
|
||||||
|
fill_vec.data[i] = fill_val;
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t total_vec_count = (total + VEC_SIZE - 1) / VEC_SIZE;
|
||||||
|
Vec* out_ptr = reinterpret_cast<Vec*>(sorted_token_ids);
|
||||||
|
|
||||||
|
for (int32_t idx = tid; idx < total_vec_count; idx += stride) {
|
||||||
|
out_ptr[idx] = fill_vec;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
for (size_t i = tid; i < numel; i += stride) {
|
for (size_t i = tid; i < numel; i += stride) {
|
||||||
int32_t expert_id = topk_ids[i];
|
int32_t expert_id = topk_ids[i];
|
||||||
int32_t rank_post_pad = tokens_cnts[threadIdx.x * num_experts + expert_id] + cumsum[expert_id];
|
int32_t rank_post_pad = tokens_cnts[threadIdx.x * num_experts + expert_id] + cumsum[expert_id];
|
||||||
@@ -165,7 +214,8 @@ void moe_align_block_size(
|
|||||||
torch::Tensor experts_ids,
|
torch::Tensor experts_ids,
|
||||||
torch::Tensor num_tokens_post_pad,
|
torch::Tensor num_tokens_post_pad,
|
||||||
torch::Tensor token_cnts_buffer,
|
torch::Tensor token_cnts_buffer,
|
||||||
torch::Tensor cumsum_buffer) {
|
torch::Tensor cumsum_buffer,
|
||||||
|
bool pad_sorted_token_ids) {
|
||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
|
|
||||||
int64_t padded_num_experts = ((num_experts + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
|
int64_t padded_num_experts = ((num_experts + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
|
||||||
@@ -190,7 +240,8 @@ void moe_align_block_size(
|
|||||||
num_tokens_post_pad.data_ptr<int32_t>(),
|
num_tokens_post_pad.data_ptr<int32_t>(),
|
||||||
num_experts,
|
num_experts,
|
||||||
block_size,
|
block_size,
|
||||||
topk_ids.numel());
|
topk_ids.numel(),
|
||||||
|
pad_sorted_token_ids);
|
||||||
} else {
|
} else {
|
||||||
auto align_kernel = moe_align_block_size_kernel<scalar_t>;
|
auto align_kernel = moe_align_block_size_kernel<scalar_t>;
|
||||||
|
|
||||||
@@ -207,7 +258,8 @@ void moe_align_block_size(
|
|||||||
experts_per_warp,
|
experts_per_warp,
|
||||||
block_size,
|
block_size,
|
||||||
topk_ids.numel(),
|
topk_ids.numel(),
|
||||||
cumsum_buffer.data_ptr<int32_t>());
|
cumsum_buffer.data_ptr<int32_t>(),
|
||||||
|
pad_sorted_token_ids);
|
||||||
|
|
||||||
const int block_threads = std::min(256, (int)threads);
|
const int block_threads = std::min(256, (int)threads);
|
||||||
const int num_blocks = (topk_ids.numel() + block_threads - 1) / block_threads;
|
const int num_blocks = (topk_ids.numel() + block_threads - 1) / block_threads;
|
||||||
|
|||||||
@@ -59,7 +59,8 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
|
|||||||
*/
|
*/
|
||||||
m.def(
|
m.def(
|
||||||
"moe_align_block_size(Tensor topk_ids, int num_experts, int block_size, Tensor! sorted_token_ids, Tensor! "
|
"moe_align_block_size(Tensor topk_ids, int num_experts, int block_size, Tensor! sorted_token_ids, Tensor! "
|
||||||
"experts_ids, Tensor! num_tokens_post_pad, Tensor! token_cnts_buffer, Tensor! cumsum_buffer) -> ()");
|
"experts_ids, Tensor! num_tokens_post_pad, Tensor! token_cnts_buffer, Tensor! cumsum_buffer, bool "
|
||||||
|
"pad_sorted_token_ids) -> ()");
|
||||||
m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
|
m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
|
||||||
|
|
||||||
m.def(
|
m.def(
|
||||||
|
|||||||
@@ -212,7 +212,8 @@ void moe_align_block_size(
|
|||||||
torch::Tensor experts_ids,
|
torch::Tensor experts_ids,
|
||||||
torch::Tensor num_tokens_post_pad,
|
torch::Tensor num_tokens_post_pad,
|
||||||
torch::Tensor token_cnts_buffer,
|
torch::Tensor token_cnts_buffer,
|
||||||
torch::Tensor cumsum_buffer);
|
torch::Tensor cumsum_buffer,
|
||||||
|
bool pad_sorted_token_ids);
|
||||||
|
|
||||||
void topk_softmax(
|
void topk_softmax(
|
||||||
torch::Tensor& topk_weights,
|
torch::Tensor& topk_weights,
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ def moe_align_block_size(
|
|||||||
num_tokens_post_pad,
|
num_tokens_post_pad,
|
||||||
token_cnts_buffer,
|
token_cnts_buffer,
|
||||||
cumsum_buffer,
|
cumsum_buffer,
|
||||||
|
pad_sorted_token_ids=False,
|
||||||
):
|
):
|
||||||
torch.ops.sgl_kernel.moe_align_block_size.default(
|
torch.ops.sgl_kernel.moe_align_block_size.default(
|
||||||
topk_ids,
|
topk_ids,
|
||||||
@@ -22,6 +23,7 @@ def moe_align_block_size(
|
|||||||
num_tokens_post_pad,
|
num_tokens_post_pad,
|
||||||
token_cnts_buffer,
|
token_cnts_buffer,
|
||||||
cumsum_buffer,
|
cumsum_buffer,
|
||||||
|
pad_sorted_token_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -138,33 +138,32 @@ def moe_align_block_size_triton(
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"block_size,num_tokens,topk,num_experts",
|
"block_size,num_tokens,topk,num_experts,pad_sorted_token_ids",
|
||||||
list(
|
list(
|
||||||
itertools.product(
|
itertools.product(
|
||||||
[32, 64, 128, 256], # block_size
|
[32, 64, 128, 256], # block_size
|
||||||
[1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096], # num_tokens
|
[1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096], # num_tokens
|
||||||
[1, 2, 4, 8, 16, 32, 64], # topk
|
[1, 2, 4, 8, 16, 32, 64], # topk
|
||||||
[64, 160, 256, 257, 260, 264], # num_experts
|
[64, 160, 256, 257, 260, 264], # num_experts
|
||||||
|
[True, False], # pad_sorted_token_ids
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
def test_moe_align_block_size_compare_implementations(
|
def test_moe_align_block_size_compare_implementations(
|
||||||
block_size, num_tokens, topk, num_experts
|
block_size, num_tokens, topk, num_experts, pad_sorted_token_ids
|
||||||
):
|
):
|
||||||
|
|
||||||
topk_ids = torch.stack(
|
topk_ids = torch.argsort(torch.rand(num_tokens, num_experts, device="cuda"), dim=1)[
|
||||||
[
|
:, :topk
|
||||||
torch.randperm(num_experts, dtype=torch.int32, device="cuda")[:topk]
|
]
|
||||||
for _ in range(num_tokens)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
|
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
|
||||||
|
|
||||||
sorted_ids_cuda = torch.empty(
|
sorted_ids_cuda = torch.empty(
|
||||||
(max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
|
(max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
|
||||||
)
|
)
|
||||||
sorted_ids_cuda.fill_(topk_ids.numel())
|
if not pad_sorted_token_ids:
|
||||||
|
sorted_ids_cuda.fill_(topk_ids.numel())
|
||||||
max_num_m_blocks = max_num_tokens_padded // block_size
|
max_num_m_blocks = max_num_tokens_padded // block_size
|
||||||
expert_ids_cuda = torch.zeros(
|
expert_ids_cuda = torch.zeros(
|
||||||
(max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
|
(max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
|
||||||
@@ -195,6 +194,7 @@ def test_moe_align_block_size_compare_implementations(
|
|||||||
num_tokens_post_pad_cuda,
|
num_tokens_post_pad_cuda,
|
||||||
token_cnts_buffer,
|
token_cnts_buffer,
|
||||||
cumsum_buffer,
|
cumsum_buffer,
|
||||||
|
pad_sorted_token_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
moe_align_block_size_triton(
|
moe_align_block_size_triton(
|
||||||
@@ -206,20 +206,51 @@ def test_moe_align_block_size_compare_implementations(
|
|||||||
num_tokens_post_pad_triton,
|
num_tokens_post_pad_triton,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert torch.allclose(expert_ids_cuda, expert_ids_triton), (
|
assert torch.allclose(expert_ids_cuda, expert_ids_triton, atol=0, rtol=0), (
|
||||||
f"Expert IDs mismatch for block_size={block_size}, "
|
f"Expert IDs mismatch for block_size={block_size}, "
|
||||||
f"num_tokens={num_tokens}, topk={topk}\n"
|
f"num_tokens={num_tokens}, topk={topk}\n"
|
||||||
f"CUDA expert_ids: {expert_ids_cuda}\n"
|
f"CUDA expert_ids: {expert_ids_cuda}\n"
|
||||||
f"Triton expert_ids: {expert_ids_triton}"
|
f"Triton expert_ids: {expert_ids_triton}"
|
||||||
)
|
)
|
||||||
|
|
||||||
assert torch.allclose(num_tokens_post_pad_cuda, num_tokens_post_pad_triton), (
|
assert torch.allclose(
|
||||||
|
num_tokens_post_pad_cuda, num_tokens_post_pad_triton, atol=0, rtol=0
|
||||||
|
), (
|
||||||
f"Num tokens post pad mismatch for block_size={block_size}, "
|
f"Num tokens post pad mismatch for block_size={block_size}, "
|
||||||
f"num_tokens={num_tokens}, topk={topk}\n"
|
f"num_tokens={num_tokens}, topk={topk}\n"
|
||||||
f"CUDA num_tokens_post_pad: {num_tokens_post_pad_cuda}\n"
|
f"CUDA num_tokens_post_pad: {num_tokens_post_pad_cuda}\n"
|
||||||
f"Triton num_tokens_post_pad: {num_tokens_post_pad_triton}"
|
f"Triton num_tokens_post_pad: {num_tokens_post_pad_triton}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Select an expert to check
|
||||||
|
expert_idx = expert_ids_cuda.max().item()
|
||||||
|
|
||||||
|
# Get the first and last block id where expert_ids_cuda == expert_idx
|
||||||
|
matching_indices = torch.where(expert_ids_cuda == expert_idx)[0]
|
||||||
|
block_sorted_start = matching_indices[0].item() * block_size
|
||||||
|
block_sorted_end = min(
|
||||||
|
(matching_indices[-1].item() + 1) * block_size, max_num_tokens_padded
|
||||||
|
)
|
||||||
|
|
||||||
|
selected_sorted_ids_cuda = sorted_ids_cuda[
|
||||||
|
block_sorted_start:block_sorted_end
|
||||||
|
].sort()[0]
|
||||||
|
selected_sorted_ids_triton = sorted_ids_triton[
|
||||||
|
block_sorted_start:block_sorted_end
|
||||||
|
].sort()[0]
|
||||||
|
|
||||||
|
assert torch.allclose(
|
||||||
|
selected_sorted_ids_cuda,
|
||||||
|
selected_sorted_ids_triton,
|
||||||
|
atol=0,
|
||||||
|
rtol=0,
|
||||||
|
), (
|
||||||
|
f"Sorted IDs mismatch for block_size={block_size}, "
|
||||||
|
f"num_tokens={num_tokens}, topk={topk}\n"
|
||||||
|
f"CUDA sorted_ids: {selected_sorted_ids_cuda}\n"
|
||||||
|
f"Triton sorted_ids: {selected_sorted_ids_triton}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
pytest.main([__file__])
|
pytest.main([__file__])
|
||||||
|
|||||||
Reference in New Issue
Block a user