// SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project #pragma once #include #include #include namespace vllm_mlu { torch::Tensor weak_ref_tensor(torch::Tensor& tensor) { // Ensure tensor is on MLU if (!tensor.is_privateuseone()) { throw std::runtime_error("Tensor must be on MLU device"); } // Get the raw data pointer void* data_ptr = tensor.data_ptr(); // Get tensor sizes and strides std::vector sizes = tensor.sizes().vec(); std::vector strides = tensor.strides().vec(); // Get tensor options (dtype, device) auto options = tensor.options(); // Create a new tensor from the raw data pointer auto new_tensor = torch::from_blob(data_ptr, sizes, strides, options); return new_tensor; } }