diff --git a/csrc/torch_binding.cpp b/csrc/torch_binding.cpp index e25d57b..7affe83 100644 --- a/csrc/torch_binding.cpp +++ b/csrc/torch_binding.cpp @@ -22,8 +22,6 @@ #include #include #include "acl/acl.h" -#include "tiling/platform/platform_ascendc.h" -#include "aclnn/opdev/platform.h" #include "ops.h" #include "utils.h" @@ -85,14 +83,13 @@ std::tuple rotary_embedding(at::Tensor &positions, at::T query_ptr, key_ptr, cos_sin_cache_ptr, rot_dim, query_stride, key_stride, dst_query_stride, dst_key_stride, num_heads, num_kv_heads, head_size]() -> int { auto dtype_num = get_dtype_from_torch(scalar_type); - fe::PlatFormInfos platform_infos; int device_id = 0; - fe::PlatformInfoManager::GeInstance().GetRuntimePlatformInfosByDevice(device_id, platform_infos); - uint32_t aivNum = platform_infos.GetCoreNumByType("aiv"); - uint32_t loop_cnt = (num_tokens + aivNum - 1) / aivNum; + int64_t aiv_num = 0; + TORCH_CHECK(aclGetDeviceCapability(device_id, ACL_DEVICE_INFO_VECTOR_CORE_NUM, &aiv_num) == ACL_SUCCESS); + uint32_t loop_cnt = (num_tokens + aiv_num - 1) / aiv_num; rotary_embedding_impl(dtype_num, is_neox, stream, position_ids_ptr, query_dst_ptr, key_dst_ptr, query_ptr, key_ptr, cos_sin_cache_ptr, rot_dim, query_stride, key_stride, dst_query_stride, - dst_key_stride, num_heads, num_kv_heads, head_size, num_tokens, loop_cnt, aivNum); + dst_key_stride, num_heads, num_kv_heads, head_size, num_tokens, loop_cnt, aiv_num); return 0; }); cmd.Run(); @@ -177,13 +174,11 @@ std::tuple get_masked_input_and_mask( org_vocab_start_index, org_vocab_end_index, num_org_vocab_padding, added_vocab_start_index, added_vocab_end_index]() -> int { - // Get platform info - fe::PlatFormInfos platform_infos; int device_id = 0; - fe::PlatformInfoManager::GeInstance().GetRuntimePlatformInfosByDevice(device_id, platform_infos); - uint32_t aivNum = platform_infos.GetCoreNumByType("aiv"); - uint32_t loop_cnt = (size + aivNum - 1) / aivNum; - + int64_t aiv_num = 0; + TORCH_CHECK(aclGetDeviceCapability(device_id, ACL_DEVICE_INFO_VECTOR_CORE_NUM, &aiv_num) == ACL_SUCCESS); + uint32_t loop_cnt = (size + aiv_num - 1) / aiv_num; + // Call implementation get_masked_input_and_mask_impl( stream, @@ -197,7 +192,7 @@ std::tuple get_masked_input_and_mask( added_vocab_end_index, size, loop_cnt, - aivNum); + aiv_num); return 0; }); diff --git a/tests/e2e/singlecard/ops/test_vocabparallelembedding.py b/tests/e2e/singlecard/ops/test_vocabparallelembedding.py index c533080..a8d7071 100644 --- a/tests/e2e/singlecard/ops/test_vocabparallelembedding.py +++ b/tests/e2e/singlecard/ops/test_vocabparallelembedding.py @@ -5,6 +5,9 @@ import torch import torch_npu # noqa: F401 import vllm_ascend.platform # noqa: F401 +from vllm_ascend.utils import enable_custom_op + +enable_custom_op() # Test parameters DTYPES = [torch.int32]