diff --git a/benchmarks/ops/ben_vocabparallelembedding.py b/benchmarks/ops/ben_vocabparallelembedding.py
index b3ef7ec..5590c73 100644
--- a/benchmarks/ops/ben_vocabparallelembedding.py
+++ b/benchmarks/ops/ben_vocabparallelembedding.py
@@ -112,7 +112,7 @@ def test_get_masked_input_and_mask(
 
     # Define custom function
     def custom_fn():
-        return torch.ops._C.get_masked_input_and_mask(
+        return torch.ops._C_ascend.get_masked_input_and_mask(
             input_tensor,
             test_case["org_start"],
             test_case["org_end"],
diff --git a/csrc/torch_binding.cpp b/csrc/torch_binding.cpp
index 1291a39..5dd6988 100644
--- a/csrc/torch_binding.cpp
+++ b/csrc/torch_binding.cpp
@@ -141,7 +141,7 @@ std::tuple<at::Tensor, at::Tensor> get_masked_input_and_mask(
     TP2, rank 1:
                             |< -----------BASE----------- >|< -BASE PADDING- >|< -----------LORA PADDING----------- >|
     corresponding token_id: | 512 | 513 | 514 | ... | 1009 | -1  | ... | -1   | -1  | ... | -1  | -1  | ... | -1  |
-    index:                  |  0  |  1  |  2  | ... | 497  | 498 | ... | 511  | 512 | ... | 519 | 520 | ... | 543 | 
+    index:                  |  0  |  1  |  2  | ... | 497  | 498 | ... | 511  | 512 | ... | 519 | 520 | ... | 543 |
     Parameters:
     org_vocab_start_index //base embeddings start
     org_vocab_end_index //base embeddings end
@@ -164,22 +164,22 @@ std::tuple<at::Tensor, at::Tensor> get_masked_input_and_mask(
     // Create output tensors
     at::Tensor masked_input = at::empty_like(input);
     at::Tensor mask = at::empty_like(input).to(at::kBool);
-    
+
     // Get data pointers
     void *input_ptr = input.data_ptr();
     void *masked_input_ptr = masked_input.data_ptr();
     void *mask_ptr = mask.data_ptr();
-    
+
     // Get current stream
     aclrtStream stream = c10_npu::getCurrentNPUStream().stream();
-    
+
     // Get scalar type
     at::ScalarType scalar_type = input.scalar_type();
-    
+
     // Create and configure OpCommand
     at_npu::native::OpCommand cmd;
     cmd.Name("get_masked_input_and_mask");
-    cmd.SetCustomHandler([scalar_type, size, stream, 
+    cmd.SetCustomHandler([scalar_type, size, stream,
                           input_ptr, masked_input_ptr, mask_ptr,
                           org_vocab_start_index, org_vocab_end_index,
                           num_org_vocab_padding, added_vocab_start_index,
@@ -193,7 +193,7 @@ std::tuple<at::Tensor, at::Tensor> get_masked_input_and_mask(
         get_masked_input_and_mask_impl(
             stream,
             input_ptr,
-            masked_input_ptr, 
+            masked_input_ptr,
             mask_ptr,
             org_vocab_start_index,
             org_vocab_end_index,
@@ -203,7 +203,7 @@ std::tuple<at::Tensor, at::Tensor> get_masked_input_and_mask(
             size,
             loop_cnt,
             aiv_num);
-        
+
         return 0;
     });
     cmd.Run();
@@ -320,8 +320,8 @@ void sgmv_shrink(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_indices, at
     aclrtStream stream = c10_npu::getCurrentNPUStream().stream();
     at_npu::native::OpCommand cmd;
     cmd.Name("sgmv_shrink");
-    cmd.SetCustomHandler([scalar_type, stream, x_ptr, weight_ptr, lora_indices_ptr, lora_indices_size, 
-                          seq_len_ptr, seq_len_size, y_ptr, 
+    cmd.SetCustomHandler([scalar_type, stream, x_ptr, weight_ptr, lora_indices_ptr, lora_indices_size,
+                          seq_len_ptr, seq_len_size, y_ptr,
                           batch_size, input_hidden_token, lora_rank, scale_f]() -> int {
         auto dtype = get_dtype_from_torch(scalar_type);
         int device_id = 0;
@@ -330,7 +330,7 @@ void sgmv_shrink(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_indices, at
         int num_tokens_per_core = (batch_size + aiv_num - 1) / aiv_num;
         TORCH_CHECK("num_tokens_per_core != 0", "num_tokens_per_core should not be 0");
         sgmv_shrink_impl(dtype, stream, x_ptr, weight_ptr, lora_indices_ptr, lora_indices_size, seq_len_ptr, seq_len_size,
-                         y_ptr, batch_size, 
+                         y_ptr, batch_size,
                          num_tokens_per_core, input_hidden_token, lora_rank, scale_f);
         return 0;
     });
@@ -367,7 +367,7 @@ at::Tensor sgmv_expand(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_indic
     aclrtStream stream = c10_npu::getCurrentNPUStream().stream();
     at_npu::native::OpCommand cmd;
     cmd.Name("sgmv_expand");
-    cmd.SetCustomHandler([scalar_type, stream, x_ptr, weight_ptr, lora_indices_ptr, lora_indices_size, seq_len_ptr, seq_len_size, y_ptr, y_out_ptr, 
+    cmd.SetCustomHandler([scalar_type, stream, x_ptr, weight_ptr, lora_indices_ptr, lora_indices_size, seq_len_ptr, seq_len_size, y_ptr, y_out_ptr,
                           batch_size, lora_rank, slice_offset, slice_size, output_full_dim]() -> int {
         auto dtype = get_dtype_from_torch(scalar_type);
         int device_id = 0;
@@ -375,7 +375,7 @@ at::Tensor sgmv_expand(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_indic
         TORCH_CHECK(aclGetDeviceCapability(device_id, ACL_DEVICE_INFO_VECTOR_CORE_NUM, &aiv_num) == ACL_SUCCESS);
         int num_tokens_per_core = (batch_size + aiv_num - 1) / aiv_num;
         TORCH_CHECK("num_tokens_per_core != 0", "num_tokens_per_core should not be 0");
-        sgmv_expand_impl(dtype, stream, x_ptr, weight_ptr, lora_indices_ptr, lora_indices_size, seq_len_ptr, seq_len_size, y_ptr, y_out_ptr, 
+        sgmv_expand_impl(dtype, stream, x_ptr, weight_ptr, lora_indices_ptr, lora_indices_size, seq_len_ptr, seq_len_size, y_ptr, y_out_ptr,
                          batch_size, num_tokens_per_core, lora_rank, slice_size, slice_offset, output_full_dim);
         return 0;
     });
@@ -384,7 +384,7 @@ at::Tensor sgmv_expand(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_indic
 }
 } // namespace vllm_ascend
 
-TORCH_LIBRARY_EXPAND(_C, ops)
+TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
 {
     // vLLM-Ascend custom ops
     ops.def("weak_ref_tensor(Tensor input) -> Tensor");
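The doc comment in `get_masked_input_and_mask` above describes how each TP rank remaps global token ids onto its local embedding rows. As a reading aid, here is a minimal pure-PyTorch sketch of that mapping, modeled on vLLM's reference implementation of the same op; the standalone function name is ours, and the NPU kernel remains the authoritative version:

```python
import torch


def get_masked_input_and_mask_ref(
        input_: torch.Tensor, org_vocab_start_index: int,
        org_vocab_end_index: int, num_org_vocab_padding: int,
        added_vocab_start_index: int, added_vocab_end_index: int):
    # Tokens owned by this rank's slice of the base vocabulary.
    org_mask = (input_ >= org_vocab_start_index) & (
        input_ < org_vocab_end_index)
    # Tokens owned by this rank's slice of the added (LoRA) vocabulary.
    added_mask = (input_ >= added_vocab_start_index) & (
        input_ < added_vocab_end_index)
    # Added tokens land right after the base slice plus its padding.
    added_offset = (added_vocab_start_index -
                    (org_vocab_end_index - org_vocab_start_index) -
                    num_org_vocab_padding)
    valid_offset = (org_vocab_start_index * org_mask +
                    added_offset * added_mask)
    valid = org_mask | added_mask
    # Out-of-partition tokens map to row 0 and are flagged True in the mask.
    return valid * (input_ - valid_offset), ~valid
```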
diff --git a/csrc/torch_binding_meta.cpp b/csrc/torch_binding_meta.cpp
index d69254b..4101ee7 100644
--- a/csrc/torch_binding_meta.cpp
+++ b/csrc/torch_binding_meta.cpp
@@ -40,7 +40,7 @@ std::tuple<at::Tensor, at::Tensor> rotary_embedding_meta(
     at::Tensor &positions,
     at::Tensor &query,
     at::Tensor &key,
-    int64_t head_size, 
+    int64_t head_size,
     at::Tensor &cos_sin_cache,
     bool is_neox) {
     auto num_tokens = positions.sym_numel();
@@ -86,9 +86,9 @@ at::Tensor sgmv_expand_meta(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_
 } // namespace vllm_ascend
 
 namespace {
-    // Register the meta implementations of the custom kernels for symbolic tracing, this will also 
+    // Register the meta implementations of the custom kernels for symbolic tracing, this will also
    // the custom kernel been captured into aclgraph
-    TORCH_LIBRARY_IMPL_EXPAND(_C, Meta, ops) {
+    TORCH_LIBRARY_IMPL_EXPAND(CONCAT(_C, _ascend), Meta, ops) {
        // Rotary embedding meta implementation
        ops.impl("rotary_embedding", &vllm_ascend::meta::rotary_embedding_meta);
        // Masked input and mask meta implementation
@@ -99,4 +99,4 @@ namespace {
        ops.impl("sgmv_expand", &vllm_ascend::meta::sgmv_expand_meta);
     }
-}
\ No newline at end of file
+}
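The comment touched here says these Meta registrations let the custom kernels be traced symbolically and captured into aclgraph. A toy, self-contained illustration of what a Meta kernel provides, using a built-in op rather than the Ascend ones: under fake-tensor tracing, only shapes and dtypes flow, and no device memory is touched.

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# Built-in ops already ship Meta kernels, so they run under fake tensors.
with FakeTensorMode():
    q = torch.empty(8, 128)
    k = torch.empty(8, 128)
    out = torch.add(q, k)  # shape/dtype propagation only, no real compute
    assert out.shape == (8, 128)
```

Without an equivalent Meta kernel for the `_C_ascend` ops, tracing would fail at the first custom-op call; the registrations in this file (and in `vllm_ascend/meta_registration.py` below) fill that gap.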
diff --git a/tests/e2e/singlecard/ops/test_bgmv_expand.py b/tests/e2e/singlecard/ops/test_bgmv_expand.py
index 0aca9ca..9d82ab8 100644
--- a/tests/e2e/singlecard/ops/test_bgmv_expand.py
+++ b/tests/e2e/singlecard/ops/test_bgmv_expand.py
@@ -33,8 +33,8 @@ def test_bgmv_expand():
     y_npu = y.npu()
 
     y_out = bgmv_expand_cpu_impl(x, w, indices, y, 0, 128)
-    y_out_npu = torch.ops._C.bgmv_expand(x_npu, w_npu, indices_npu, y_npu, 0,
-                                         128)
+    y_out_npu = torch.ops._C_ascend.bgmv_expand(x_npu, w_npu, indices_npu,
+                                                y_npu, 0, 128)
 
     # Compare the results.
     torch.testing.assert_close(y_out_npu.cpu(),
diff --git a/tests/e2e/singlecard/ops/test_bgmv_shrink.py b/tests/e2e/singlecard/ops/test_bgmv_shrink.py
index 99bb8e8..6cb8127 100644
--- a/tests/e2e/singlecard/ops/test_bgmv_shrink.py
+++ b/tests/e2e/singlecard/ops/test_bgmv_shrink.py
@@ -33,7 +33,7 @@ def test_bgmv_shrink():
     y_npu = y.npu()
 
     y = bgmv_shrink_cpu_impl(x, w, indices, y, 0.5)
-    torch.ops._C.bgmv_shrink(x_npu, w_npu, indices_npu, y_npu, 0.5)
+    torch.ops._C_ascend.bgmv_shrink(x_npu, w_npu, indices_npu, y_npu, 0.5)
 
     # Compare the results.
     torch.testing.assert_close(y_npu.cpu(),
diff --git a/tests/e2e/singlecard/ops/test_rotary_embedding.py b/tests/e2e/singlecard/ops/test_rotary_embedding.py
index 6f513b2..27e9b3b 100644
--- a/tests/e2e/singlecard/ops/test_rotary_embedding.py
+++ b/tests/e2e/singlecard/ops/test_rotary_embedding.py
@@ -182,7 +182,7 @@ def test_rotary_embedding_quant_with_leading_dim(
     )
 
     ref_query, ref_key = rope.forward_native(positions, query, key)
-    query, key = torch.ops._C.rotary_embedding(
+    query, key = torch.ops._C_ascend.rotary_embedding(
         positions,
         query,
         key,
@@ -239,7 +239,7 @@ class ModelwithRotaryEmbedding(nn.Module):
         # we simulated a simple attention layer to test if it can be seamlessly captured into aclgraph
         qkv = self.qkv_proj(hidden_states)
         q, k, v = qkv.chunk(3, dim=-1)
-        query, key = torch.ops._C.rotary_embedding(
+        query, key = torch.ops._C_ascend.rotary_embedding(
             positions,
             q,
             k,
@@ -299,7 +299,7 @@ def test_capture_rotary_embedding_in_aclgraph(
         # Validate if the rotary_embedding custom kernel is indeed inside the graph by
         # string match
         graph = str(gm.graph)
-        assert "_C.rotary_embedding" in graph
+        assert "_C_ascend.rotary_embedding" in graph
         return gm
 
     static_positions = torch.randint(0, max_position_embeddings,
diff --git a/tests/e2e/singlecard/ops/test_vocabparallelembedding.py b/tests/e2e/singlecard/ops/test_vocabparallelembedding.py
index 54d1127..64b974d 100644
--- a/tests/e2e/singlecard/ops/test_vocabparallelembedding.py
+++ b/tests/e2e/singlecard/ops/test_vocabparallelembedding.py
@@ -72,7 +72,7 @@ def test_get_masked_input_and_mask(
 
     # Get custom op result
     print("input_tensor:", input_tensor)
-    custom_masked_input, custom_mask = torch.ops._C.get_masked_input_and_mask(
+    custom_masked_input, custom_mask = torch.ops._C_ascend.get_masked_input_and_mask(
         input_tensor, test_case["org_start"], test_case["org_end"],
         test_case["padding"], test_case["added_start"], test_case["added_end"])
diff --git a/tests/ut/ops/test_rotary_embedding.py b/tests/ut/ops/test_rotary_embedding.py
index de6f4ef..21d95bb 100644
--- a/tests/ut/ops/test_rotary_embedding.py
+++ b/tests/ut/ops/test_rotary_embedding.py
@@ -94,7 +94,7 @@ class TestAscendRotaryEmbedding(unittest.TestCase):
         self.mock_self.cos_sin_cache = self.cos_sin_cache
         self.mock_self.is_neox_style = self.is_neox_style
 
-    @patch('torch.ops._C')
+    @patch('torch.ops._C_ascend')
     @patch('vllm_ascend.ops.rotary_embedding.is_310p', return_value=False)
     @patch('vllm_ascend.ops.rotary_embedding._custom_rotary_embedding_enabled',
            return_value=True)
diff --git a/tests/ut/torchair/ops/test_torchair_rotary_embedding.py b/tests/ut/torchair/ops/test_torchair_rotary_embedding.py
index ce74dee..4adb598 100644
--- a/tests/ut/torchair/ops/test_torchair_rotary_embedding.py
+++ b/tests/ut/torchair/ops/test_torchair_rotary_embedding.py
@@ -104,7 +104,7 @@ class TestRopeForwardOot(TestBase):
         self.assertTrue(torch.equal(result_q, self.query))
         self.assertTrue(torch.equal(result_k, self.key))
 
-    @patch('torch.ops._C')
+    @patch('torch.ops._C_ascend')
     @patch(
         'vllm_ascend.torchair.ops.torchair_rotary_embedding.get_ascend_config')
     @patch('vllm_ascend.torchair.ops.torchair_rotary_embedding.is_310p',
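All of the unit-test changes above follow one pattern: stub out the entire renamed namespace with a single `@patch`. A self-contained version of that pattern (`create=True` is needed when the compiled extension is not loaded, since `torch.ops._C_ascend` only exists after it is):

```python
import torch
from unittest.mock import patch


@patch('torch.ops._C_ascend', create=True)
def check_rope_dispatch(mock_ops):
    # Any attribute access on the mocked namespace yields a configurable stub.
    mock_ops.rotary_embedding.return_value = ('q', 'k')
    assert torch.ops._C_ascend.rotary_embedding() == ('q', 'k')


check_rope_dispatch()
```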
diff --git a/vllm_ascend/compilation/acl_graph.py b/vllm_ascend/compilation/acl_graph.py
index f8dfc24..cc12448 100644
--- a/vllm_ascend/compilation/acl_graph.py
+++ b/vllm_ascend/compilation/acl_graph.py
@@ -15,7 +15,8 @@ from vllm.config import CUDAGraphMode, VllmConfig
 from vllm.forward_context import BatchDescriptor, get_forward_context
 from vllm.logger import logger
 from vllm.platforms import current_platform
-from vllm.utils import weak_ref_tensors
+
+from ..utils import weak_ref_tensors
 
 
 @dataclasses.dataclass
@@ -35,10 +36,10 @@ class ACLGraphWrapper:
 
     The workflow of this wrapper in the aclgraph dispatching is as follows:
     1. At initialization, a runtime mode is assigned to the wrapper (FULL or
-    PIECEWISE). 
-    2. At runtime, the wrapper receives a runtime_mode and a 
+    PIECEWISE).
+    2. At runtime, the wrapper receives a runtime_mode and a
     batch_descriptor(key) from the forward context and blindly trust them
-    for aclgraph dispatching. 
+    for aclgraph dispatching.
     3. If runtime_mode is NONE or runtime_mode does not match the mode of the
     wrapper, just call the runnable directly.
     4. Otherwise, i.e., the runtime_mode matches the mode of the wrapper,
@@ -47,9 +48,9 @@ class ACLGraphWrapper:
 
     Note: ACLGraphWrapper does not store persistent buffers or copy any
     runtime inputs into that buffers for replay. We assume implementing them
-    is done outside of the wrapper. That is because we do not make any 
+    is done outside of the wrapper. That is because we do not make any
     assumption on the dynamic shape (batch size) of the runtime inputs, as a
-    trade-off for staying orthogonal to compilation logic. Nevertheless, 
+    trade-off for staying orthogonal to compilation logic. Nevertheless,
     tracing and checking the input addresses to be consistent during replay
     is guaranteed when VLLM_LOGGING_LEVEL == "DEBUG".
     """
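The docstring's four-step workflow reduces to a small dispatch rule. A condensed sketch with illustrative attribute names (not the wrapper's real fields):

```python
def dispatch(wrapper, runtime_mode, batch_descriptor, *args):
    # Steps 2-3: NONE, or a mode other than the one this wrapper was built
    # with, bypasses graphs entirely and runs the wrapped callable eagerly.
    if runtime_mode == "NONE" or runtime_mode != wrapper.mode:
        return wrapper.runnable(*args)
    # Step 4: the first time a batch_descriptor key is seen, capture an
    # aclgraph for it; on later sightings, replay the captured graph.
    entry = wrapper.graphs.get(batch_descriptor)
    if entry is None:
        entry = wrapper.capture(batch_descriptor, *args)
        wrapper.graphs[batch_descriptor] = entry
    return entry.replay()
```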
""" diff --git a/vllm_ascend/lora/lora_ops.py b/vllm_ascend/lora/lora_ops.py index e8bf8ad..58d0ea6 100644 --- a/vllm_ascend/lora/lora_ops.py +++ b/vllm_ascend/lora/lora_ops.py @@ -21,7 +21,7 @@ def bgmv_shrink(inputs: torch.Tensor, output_tensor: torch.Tensor, lora_indices_tensor: torch.Tensor, scaling: float = 1.0): - return torch.ops._C.bgmv_shrink( + return torch.ops._C_ascend.bgmv_shrink( inputs, lora_a_weights, lora_indices_tensor, @@ -35,7 +35,7 @@ def bgmv_expand(inputs: torch.Tensor, output_tensor: torch.Tensor, lora_indices_tensor: torch.Tensor, add_inputs: bool = True): - return torch.ops._C.bgmv_expand( + return torch.ops._C_ascend.bgmv_expand( inputs, lora_b_weights, lora_indices_tensor, @@ -52,9 +52,9 @@ def bgmv_expand_slice(inputs: torch.Tensor, slice_offset: int, slice_size: int, add_inputs: bool = True): - return torch.ops._C.bgmv_expand(inputs, lora_b_weights, - lora_indices_tensor, output_tensor, - slice_offset, slice_size) + return torch.ops._C_ascend.bgmv_expand(inputs, lora_b_weights, + lora_indices_tensor, output_tensor, + slice_offset, slice_size) def sgmv_shrink( @@ -69,9 +69,9 @@ def sgmv_shrink( token_nums: int, scaling: float, ): - return torch.ops._C.sgmv_shrink(inputs, lora_a_weights, - lora_indices_tensor, seq_len_tensor, - output_tensor, scaling) + return torch.ops._C_ascend.sgmv_shrink(inputs, lora_a_weights, + lora_indices_tensor, seq_len_tensor, + output_tensor, scaling) def sgmv_expand(inputs: torch.Tensor, @@ -84,7 +84,7 @@ def sgmv_expand(inputs: torch.Tensor, max_seq_length: int, token_nums: int, add_inputs: bool = False): - return torch.ops._C.sgmv_expand( + return torch.ops._C_ascend.sgmv_expand( inputs, lora_b_weights, lora_indices_tensor, @@ -107,6 +107,7 @@ def sgmv_expand_slice(inputs: torch.Tensor, slice_offset: int, slice_size: int, add_inputs: bool = False): - return torch.ops._C.sgmv_expand(inputs, lora_b_weights, - lora_indices_tensor, seq_len_tensor, - output_tensor, slice_offset, slice_size) + return torch.ops._C_ascend.sgmv_expand(inputs, lora_b_weights, + lora_indices_tensor, seq_len_tensor, + output_tensor, slice_offset, + slice_size) diff --git a/vllm_ascend/meta_registration.py b/vllm_ascend/meta_registration.py index 47c7758..9a58afd 100644 --- a/vllm_ascend/meta_registration.py +++ b/vllm_ascend/meta_registration.py @@ -23,7 +23,7 @@ from torch.library import Library # Do NOT perform any real computation or allocate device memory. # # 2. 
diff --git a/vllm_ascend/ops/__init__.py b/vllm_ascend/ops/__init__.py
index 5c8a798..381c1b6 100644
--- a/vllm_ascend/ops/__init__.py
+++ b/vllm_ascend/ops/__init__.py
@@ -35,19 +35,20 @@ class dummyFusionOp:
 
 
 def register_dummy_fusion_op() -> None:
-    torch.ops._C.rms_norm = dummyFusionOp(name="rms_norm")
-    torch.ops._C.fused_add_rms_norm = dummyFusionOp(name="fused_add_rms_norm")
-    torch.ops._C.static_scaled_fp8_quant = dummyFusionOp(
+    torch.ops._C_ascend.rms_norm = dummyFusionOp(name="rms_norm")
+    torch.ops._C_ascend.fused_add_rms_norm = dummyFusionOp(
+        name="fused_add_rms_norm")
+    torch.ops._C_ascend.static_scaled_fp8_quant = dummyFusionOp(
         name="static_scaled_fp8_quant")
-    torch.ops._C.dynamic_scaled_fp8_quant = dummyFusionOp(
+    torch.ops._C_ascend.dynamic_scaled_fp8_quant = dummyFusionOp(
         name="dynamic_scaled_fp8_quant")
-    torch.ops._C.dynamic_per_token_scaled_fp8_quant = dummyFusionOp(
+    torch.ops._C_ascend.dynamic_per_token_scaled_fp8_quant = dummyFusionOp(
         name="dynamic_per_token_scaled_fp8_quant")
-    torch.ops._C.rms_norm_static_fp8_quant = dummyFusionOp(
+    torch.ops._C_ascend.rms_norm_static_fp8_quant = dummyFusionOp(
         name="rms_norm_static_fp8_quant")
-    torch.ops._C.fused_add_rms_norm_static_fp8_quant = dummyFusionOp(
+    torch.ops._C_ascend.fused_add_rms_norm_static_fp8_quant = dummyFusionOp(
         name="fused_add_rms_norm_static_fp8_quant")
-    torch.ops._C.rms_norm_dynamic_per_token_quant = dummyFusionOp(
+    torch.ops._C_ascend.rms_norm_dynamic_per_token_quant = dummyFusionOp(
         name="rms_norm_dynamic_per_token_quant")
diff --git a/vllm_ascend/ops/rotary_embedding.py b/vllm_ascend/ops/rotary_embedding.py
index 4b76dce..9ddf280 100644
--- a/vllm_ascend/ops/rotary_embedding.py
+++ b/vllm_ascend/ops/rotary_embedding.py
@@ -49,7 +49,7 @@ def _rope_forward_oot(
     # adopt custom kernel path for rotary_embedding
     if _custom_rotary_embedding_enabled(query, is_neox_style,
                                         self.head_size) and not is_310p():
-        query, key = torch.ops._C.rotary_embedding(
+        query, key = torch.ops._C_ascend.rotary_embedding(
             positions,
             query,
             key,
diff --git a/vllm_ascend/torchair/ops/torchair_rotary_embedding.py b/vllm_ascend/torchair/ops/torchair_rotary_embedding.py
index 766ae5f..e64bd6f 100644
--- a/vllm_ascend/torchair/ops/torchair_rotary_embedding.py
+++ b/vllm_ascend/torchair/ops/torchair_rotary_embedding.py
@@ -62,7 +62,7 @@ def rope_forward_oot(
     # adopt custom kernel path for rotary_embedding
     if custom_rotary_embedding_enabled(query, neox_style,
                                        self.head_size) and not is_310p():
-        query, key = torch.ops._C.rotary_embedding(
+        query, key = torch.ops._C_ascend.rotary_embedding(
             positions,
             query,
             key,
diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py
index ca51327..a166061 100644
--- a/vllm_ascend/utils.py
+++ b/vllm_ascend/utils.py
@@ -24,7 +24,7 @@ import os
 from contextlib import contextmanager
 from enum import Enum
 from threading import Lock
-from typing import TYPE_CHECKING, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union
 
 import torch
 import torch_npu  # noqa: F401  # noqa: F401
@@ -188,7 +188,7 @@ def try_register_lib(lib_name: str, lib_info: str = ""):
 
 def enable_custom_op():
     """
-    Enable lazy init for vllm_ascend_C to avoid early initialization of CANN's RTS component. 
+    Enable lazy init for vllm_ascend_C to avoid early initialization of CANN's RTS component.
     Ensure that ASCEND_RT_VISIBLE_DEVICES can be dynamically modified before torch.npu.set_device().
     """
     global _CUSTOM_OP_ENABLED
@@ -486,7 +486,7 @@ def get_all_reduce_merge_state(ep_size: int, is_deepseek_v3_r1: bool):
 def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
     """Register Ascend CustomOP
 
-    NOTE: if the register branch requires model type, please use `vllm.config.get_current_vllm_config`, 
+    NOTE: if the register branch requires model type, please use `vllm.config.get_current_vllm_config`,
     and ensure this will execute after model config is initilazed.
     """
     global _ASCEND_CUSTOMOP_IS_REIGISTERED
@@ -589,3 +589,31 @@ def dense_optim_enable() -> bool:
 def is_moe_model(vllm_config: VllmConfig):
     config = vllm_config.model_config.hf_config
     return any('experts' in key.lower() for key in config.to_dict())
+
+
+def weak_ref_tensor(tensor: Any) -> Any:
+    """
+    Create a weak reference to a tensor.
+    The new tensor will share the same data as the original tensor,
+    but will not keep the original tensor alive.
+    """
+    if isinstance(tensor, torch.Tensor):
+        return torch.ops._C_ascend.weak_ref_tensor(tensor)
+    else:
+        return tensor
+
+
+def weak_ref_tensors(
+    tensors: Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor]]
+) -> Union[torch.Tensor, list[Any], tuple[Any], Any]:
+    """
+    Convenience function to create weak references to tensors,
+    for single tensor, list of tensors or tuple of tensors.
+    """
+    if isinstance(tensors, torch.Tensor):
+        return weak_ref_tensor(tensors)
+    if isinstance(tensors, list):
+        return [weak_ref_tensor(t) for t in tensors]
+    if isinstance(tensors, tuple):
+        return tuple(weak_ref_tensor(t) for t in tensors)
+    raise ValueError("Invalid type for tensors")
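These two helpers replace the `from vllm.utils import weak_ref_tensors` import that `acl_graph.py` dropped above, routing through the renamed `weak_ref_tensor` custom op. A usage sketch (requires the compiled vllm-ascend extension and an NPU device; shapes are illustrative):

```python
import torch
from vllm_ascend.utils import weak_ref_tensors

# Hold on to graph outputs without extending their lifetime: the weak-ref
# tensors alias the same storage but do not own it.
out = torch.randn(4, 16).npu()
ref = weak_ref_tensors(out)          # single tensor in, weak-ref tensor out
refs = weak_ref_tensors([out, out])  # lists and tuples map elementwise
assert ref.shape == out.shape
```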