/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/

#include "kernels/preload.mluh"
#include "torch_ops_api.h"
namespace tmo {
|
|
namespace torch_api {
|
|
void preload(const torch::Tensor &weight, const int64_t size) {
|
|
const torch_mlu::mlu::MLUGuard device_guard(weight.device());
|
|
auto queue = torch_mlu::getCurMLUStream();
|
|
|
|
invokePreload(queue, weight.data_ptr(), weight.element_size() * weight.numel(), size);
|
|
}
|
|
} // namespace torch_api
|
|
} // namespace tmo
|