diff --git a/csrc/apply_top_k_top_p_custom/op_host/aclnn_apply_top_k_top_p_custom.cpp b/csrc/apply_top_k_top_p_custom/op_host/aclnn_apply_top_k_top_p_custom.cpp
index d9683524..8a977f2e 100644
--- a/csrc/apply_top_k_top_p_custom/op_host/aclnn_apply_top_k_top_p_custom.cpp
+++ b/csrc/apply_top_k_top_p_custom/op_host/aclnn_apply_top_k_top_p_custom.cpp
@@ -38,7 +38,7 @@
 static const int64_t EXPECTED_DIM_ONE = 1;
 static const int64_t EXPECTED_DIM_TWO = 2;
 static constexpr size_t DIM_ONE = 1;
-// 根据API定义,需要列出所能支持的所有dtype
+// According to the API definition, all supported dtypes must be enumerated.
 static const std::initializer_list<op::DataType> DTYPE_SUPPORT_LIST = {
     op::DataType::DT_FLOAT, op::DataType::DT_FLOAT16, op::DataType::DT_BF16};
 
@@ -57,7 +57,7 @@ static bool CheckNotNull(const aclTensor* logits, const aclTensor* p, const aclT
 
 static bool CheckDtypeValid(const aclTensor* logits, const aclTensor* p, const aclTensor *k, const aclTensor* out)
 {
-    // 检查数据类型是否在支持列表内
+    // Check whether the data type is within the supported list.
     OP_CHECK_DTYPE_NOT_SUPPORT(logits, DTYPE_SUPPORT_LIST, return false);
     if (p != nullptr) {
         OP_CHECK_DTYPE_NOT_SUPPORT(p, DTYPE_SUPPORT_LIST, return false);
@@ -67,7 +67,7 @@ static bool CheckDtypeValid(const aclTensor* logits, const aclTensor* p, const a
     }
     OP_CHECK_DTYPE_NOT_SUPPORT(out, DTYPE_SUPPORT_LIST, return false);
 
-    // 检查数据类型是否相同
+    // Check whether the data types are identical.
     if (p != nullptr) {
         OP_CHECK_DTYPE_NOT_MATCH(p, logits->GetDataType(), return false);
     }
@@ -121,17 +121,17 @@ static bool CheckFormatValid(const aclTensor* logits, const aclTensor* p, const
 static aclnnStatus CheckParams(const aclTensor* logits, const aclTensor* p, const aclTensor *k, const aclTensor* out)
 {
-    // 错误码等DFX方案细化后刷新,错误日志在check接口内打印
-    // 1. 检查参数是否为空指针
+    // To be refreshed once the DFX scheme (error codes, etc.) is finalized; error logs are printed inside the check interfaces.
+    // 1. Check whether any parameters are null pointers.
     CHECK_RET(CheckNotNull(logits, p, k, out), ACLNN_ERR_PARAM_NULLPTR);
 
-    // 2. 检查输入的数据类型是否在API支持的数据类型范围之内,需要根据api定义校验
+    // 2. Check whether the input data types are within the range supported by the API; validate against the API definition.
     CHECK_RET(CheckDtypeValid(logits, p, k, out), ACLNN_ERR_PARAM_INVALID);
 
-    // 3. 检查shape是否满足约束
+    // 3. Check whether the shapes satisfy the constraints.
     CHECK_RET(CheckShapeValid(logits, p, k, out), ACLNN_ERR_PARAM_INVALID);
 
-    // 4. 检查format是否满足约束
+    // 4. Check whether the formats satisfy the constraints.
     CHECK_RET(CheckFormatValid(logits, p, k, out), ACLNN_ERR_PARAM_INVALID);
 
     return ACLNN_SUCCESS;
@@ -144,11 +144,11 @@ aclnnStatus aclnnApplyTopKTopPCustomGetWorkspaceSize(
 {
     OP_CHECK_COMM_INPUT(workspaceSize, executor);
     L2_DFX_PHASE_1(aclnnApplyTopKTopPCustom, DFX_IN(logits, p, k), DFX_OUT(out));
-    // 固定写法,创建OpExecutor
+    // Fixed boilerplate: create the OpExecutor.
     auto uniqueExecutor = CREATE_EXECUTOR();
     CHECK_RET(uniqueExecutor.get() != nullptr, ACLNN_ERR_INNER_CREATE_EXECUTOR);
 
-    // 固定写法,参数检查
+    // Fixed boilerplate: parameter validation.
     auto ret = CheckParams(logits, p, k, out);
     CHECK_RET(ret == ACLNN_SUCCESS, ret);
     bool pIsEmpty = false;
@@ -160,12 +160,12 @@
         kIsEmpty = k->IsEmpty();
     }
     if (logits->IsEmpty() || pIsEmpty || kIsEmpty) {
-        // 根据实际支持情况补充
+        // To be filled in according to what is actually supported.
         *workspaceSize = 0;
         uniqueExecutor.ReleaseTo(executor);
         return ACLNN_SUCCESS;
    }
-    // 固定写法,将输入selfRef转换成连续的tensor
+    // Fixed boilerplate: convert the input selfRef to a contiguous tensor.
     auto logitsContiguous = l0op::Contiguous(logits, uniqueExecutor.get());
     CHECK_RET(logitsContiguous != nullptr, ACLNN_ERR_INNER_NULLPTR);
     const aclTensor* pContiguous = nullptr;
@@ -190,24 +190,23 @@
         CHECK_RET(sortedIndices != nullptr, ACLNN_ERR_INNER_NULLPTR);
         auto res = l0op::ApplyTopKTopPCustom(sortedValue, sortedIndices, pContiguous, kContiguous, uniqueExecutor.get());
         CHECK_RET(res != nullptr, ACLNN_ERR_INNER_NULLPTR);
-        // 固定写法,将计算结果拷贝到输出out上,out可能是非连续的tensor
+        // Fixed boilerplate: copy the computed result to the output 'out'; 'out' may be a non-contiguous tensor.
         viewCopyResult = l0op::ViewCopy(res, out, uniqueExecutor.get());
     }
     CHECK_RET(viewCopyResult != nullptr, ACLNN_ERR_INNER_NULLPTR);
-    // 固定写法,获取计算过程中需要使用的workspace大小
+    // Fixed boilerplate: obtain the workspace size required during computation.
     *workspaceSize = uniqueExecutor->GetWorkspaceSize();
-    uniqueExecutor.ReleaseTo(executor); // 需要把 uniqueExecutor持有executor转移给executor
+    uniqueExecutor.ReleaseTo(executor); // Transfer ownership of the executor held by uniqueExecutor to the caller's executor.
     return ACLNN_SUCCESS;
 }
 
 aclnnStatus aclnnApplyTopKTopPCustom(void* workspace, uint64_t workspaceSize, aclOpExecutor* executor, aclrtStream stream)
 {
     L2_DFX_PHASE_2(aclnnApplyTopKTopPCustom);
-    // 固定写法,调用框架能力,完成计算
+    // Fixed boilerplate: invoke the framework to complete the computation.
     return CommonOpExecutorRun(workspace, workspaceSize, executor, stream);
 }
 
 #ifdef __cplusplus
 }
 #endif
-
diff --git a/tests/e2e/nightly/single_node/ops/singlecard_ops/test_apply_top_k_top_p_custom.py b/tests/e2e/nightly/single_node/ops/singlecard_ops/test_apply_top_k_top_p_custom.py
new file mode 100644
index 00000000..3a153618
--- /dev/null
+++ b/tests/e2e/nightly/single_node/ops/singlecard_ops/test_apply_top_k_top_p_custom.py
@@ -0,0 +1,142 @@
+import numpy as np
+import pytest
+import torch
+from vllm_ascend.utils import enable_custom_op
+
+enable_custom_op()
+
+
+def cpu_op_exec(logits, p, k):
+    """
+    Apply top-k and top-p sampling filtering.
+    """
+    # Sort logits in ascending order
+    logits_sort, logits_idx = logits.sort(dim=-1, descending=False, stable=True)
+
+    # 1. Apply top-k filtering
+    if k is not None:
+        # Ensure k does not exceed vocab_size
+        k = torch.minimum(k, torch.tensor(logits.size(-1), device=k.device))
+        top_k_mask_idx = logits_sort.size(1) - k.to(torch.long)
+        top_k_threshold = logits_sort.gather(1, top_k_mask_idx.unsqueeze(dim=1))
+        top_k_mask = logits_sort < top_k_threshold
+        logits_sort.masked_fill_(top_k_mask, -float("inf"))
+
+    # 2. Apply top-p (nucleus) filtering
+    if p is not None:
+        probs_sort = logits_sort.to(torch.float32).softmax(dim=-1)
+        probs_sum = probs_sort.cumsum(dim=-1)
+        top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
+        top_p_mask[:, -1] = False
+        logits_sort.masked_fill_(top_p_mask, -float("inf"))
+
+    # 3. Restore original order
+    logits = torch.empty_like(logits_sort).scatter_(dim=-1, index=logits_idx, src=logits_sort)
+    return logits
+
+
+def cpu_op_exec_top_k(logits, p, k):
+    return cpu_op_exec(logits, None, k)
+
+
+def cpu_op_exec_top_p(logits, p, k):
+    return cpu_op_exec(logits, p, None)
+
+
+def ascendc_op_exec(logits, p, k):
+    """
+    Execute the custom Ascend NPU operator.
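+
+    Moves logits/p/k to the NPU, invokes the registered custom op
+    torch.ops._C_ascend.npu_apply_top_k_top_p, and returns the result on the CPU.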
+ """ + logits_npu = logits.npu() + p_npu = p.npu() if p is not None else None + k_npu = k.npu() if k is not None else None + + return torch.ops._C_ascend.npu_apply_top_k_top_p(logits_npu, k=k_npu, p=p_npu).cpu() + + +def assert_output_close(out_cpu, out_npu, rtol=1e-4, atol=1e-4): + """ + Custom assertion to handle Top-P boundary precision issues. + """ + # 1. Check mask consistency (inf vs finite) + mask_cpu = torch.isinf(out_cpu) & (out_cpu < 0) + mask_npu = torch.isinf(out_npu) & (out_npu < 0) + + mismatch_mask = mask_cpu ^ mask_npu + mismatch_count = mismatch_mask.sum().item() + total_elements = out_cpu.numel() + + # Allow 0.1% mismatch for boundary floating point precision differences + mismatch_ratio = mismatch_count / total_elements + if mismatch_ratio > 0.001: + pytest.fail(f"Mask mismatch ratio too high: {mismatch_ratio:.6f} ({mismatch_count}/{total_elements})") + + # 2. Check value consistency for valid elements + valid_mask = (~mask_cpu) & (~mask_npu) + if valid_mask.any(): + torch.testing.assert_close( + out_cpu[valid_mask], + out_npu[valid_mask], + rtol=rtol, + atol=atol + ) + + +# ----------------------------------------------------------------------------- +# Tests +# ----------------------------------------------------------------------------- + + +@pytest.mark.parametrize('vocab_size', [15206, 152064]) +@pytest.mark.parametrize('batch_size', [4, 8, 16, 32, 64, 96, 128, 256]) +@pytest.mark.parametrize('p_val', [0.5, 0.9, 0.99]) +@pytest.mark.parametrize('k_val', [50, 200, 1024, 4096, 8192]) +def test_npu_apply_top_k_top_p(vocab_size, batch_size, p_val, k_val): + shape = [batch_size, vocab_size] + dtype = torch.float32 + + logits = torch.from_numpy(np.random.uniform(-5, 5, shape)).to(dtype) + p = torch.full((batch_size,), p_val, dtype=dtype) + k = torch.full((batch_size,), k_val, dtype=torch.int32) + + out_cpu = cpu_op_exec(logits.clone(), p, k) + out_npu = ascendc_op_exec(logits, p, k) + + assert_output_close(out_cpu, out_npu) + + +@pytest.mark.parametrize('vocab_size', [15206, 152064]) +@pytest.mark.parametrize('batch_size', [4, 8, 16, 32, 64, 96, 128, 256]) +@pytest.mark.parametrize('k_val', [50, 200, 1024, 4096, 8192]) +def test_npu_apply_top_k(vocab_size, batch_size, k_val): + shape = [batch_size, vocab_size] + dtype = torch.float32 + + logits = torch.from_numpy(np.random.uniform(-5, 5, shape)).to(dtype) + p = None + k = torch.full((batch_size,), k_val, dtype=torch.int32) + + out_cpu = cpu_op_exec_top_k(logits.clone(), p, k) + out_npu = ascendc_op_exec(logits, p, k) + + assert_output_close(out_cpu, out_npu) + + +@pytest.mark.parametrize('vocab_size', [15206, 152064]) +@pytest.mark.parametrize('batch_size', [4, 8, 16, 32, 64, 96, 128, 256]) +@pytest.mark.parametrize('p_val', [0.5, 0.9, 0.99]) +def test_npu_apply_top_p(vocab_size, batch_size, p_val): + shape = [batch_size, vocab_size] + dtype = torch.float32 + + logits = torch.from_numpy(np.random.uniform(-5, 5, shape)).to(dtype) + p = torch.full((batch_size,), p_val, dtype=dtype) + k = None + + out_cpu = cpu_op_exec_top_p(logits.clone(), p, k) + out_npu = ascendc_op_exec(logits, p, k) + + assert_output_close(out_cpu, out_npu)