import os
import warnings
import logging as logger
from functools import wraps

import torch
import torch_vacc

'''
try:
    import torchair
except ImportError:
    IS_TORCHAIR_INSTALLED = False
else:
    IS_TORCHAIR_INSTALLED = True
'''

warnings.filterwarnings(action="once")

# torch.* factory and utility functions whose device arguments are rewritten.
torch_fn_white_list = [
    "_cudnn_init_dropout_state",
    "_empty_affine_quantized",
    "_empty_per_channel_affine_quantized",
    "_pin_memory",
    "_sparse_coo_tensor_unsafe",
    "_sparse_csr_tensor_unsafe",
    "logspace",
    "randint",
    "hann_window",
    "rand",
    "full_like",
    "ones_like",
    "rand_like",
    "randperm",
    "arange",
    "frombuffer",
    "normal",
    "empty_strided",
    "empty_like",
    "scalar_tensor",
    "tril_indices",
    "bartlett_window",
    "ones",
    "sparse_coo_tensor",
    "randn",
    "kaiser_window",
    "tensor",
    "triu_indices",
    "as_tensor",
    "zeros",
    "randint_like",
    "full",
    "eye",
    "empty",
    "blackman_window",
    "zeros_like",
    "range",
    "sparse_csr_tensor",
    "randn_like",
    "from_file",
    "linspace",
    "hamming_window",
    "empty_quantized",
    "autocast",
    "load",
]

torch_tensor_fn_white_list = [
    "new_empty",
    "new_empty_strided",
    "new_full",
    "new_ones",
    "new_tensor",
    "new_zeros",
    "to",
]

torch_module_fn_white_list = ["to", "to_empty"]

torch_cuda_fn_white_list = [
    "get_device_properties",
    "get_device_name",
    "get_device_capability",
    "list_gpu_processes",
    "set_device",
    "synchronize",
    "mem_get_info",
    "memory_stats",
    "memory_summary",
    "memory_allocated",
    "max_memory_allocated",
    "reset_max_memory_allocated",
    "memory_reserved",
    "max_memory_reserved",
    "reset_max_memory_cached",
    "reset_peak_memory_stats",
    "current_stream",
    "default_stream",
]

torch_profiler_fn_white_list = ["profile"]
torch_distributed_fn_white_list = ["__init__"]

# Keyword arguments that may carry a device specification.
device_kwargs_list = ["device", "device_type", "map_location"]


def wrapper_cuda(fn):
    """Rewrite cuda device arguments (positional and keyword) to vacc before calling fn."""

    @wraps(fn)
    def decorated(*args, **kwargs):
        # to/to_empty also accept a bare device index, so plain ints are rewritten too.
        replace_int = fn.__name__ in ["to", "to_empty"]
        if args:
            args_new = list(args)
            args = replace_cuda_to_vacc_in_list(args_new, replace_int)
        if kwargs:
            for device_arg in device_kwargs_list:
                device = kwargs.get(device_arg, None)
                if device is not None:
                    replace_cuda_to_vacc_in_kwargs(kwargs, device_arg, device)
            device_ids = kwargs.get("device_ids", None)
            if type(device_ids) == list:
                kwargs["device_ids"] = replace_cuda_to_vacc_in_list(
                    device_ids, replace_int
                )
        return fn(*args, **kwargs)

    return decorated


def replace_cuda_to_vacc_in_kwargs(kwargs, device_arg, device):
    if type(device) == str and "cuda" in device:
        kwargs[device_arg] = device.replace("cuda", "vacc")
    elif type(device) == torch.device and "cuda" in device.type:
        device_info = (
            "vacc:{}".format(device.index) if device.index is not None else "vacc"
        )
        kwargs[device_arg] = torch.device(device_info)
    elif type(device) == int:
        kwargs[device_arg] = f"vacc:{device}"
    elif type(device) == dict:
        kwargs[device_arg] = replace_cuda_to_vacc_in_dict(device)


def replace_cuda_to_vacc_in_list(args_list, replace_int):
    for idx, arg in enumerate(args_list):
        if isinstance(arg, str) and "cuda" in arg:
            args_list[idx] = arg.replace("cuda", "vacc")
        elif isinstance(arg, torch.device) and "cuda" in arg.type:
            device_info = (
                "vacc:{}".format(arg.index) if arg.index is not None else "vacc"
            )
            args_list[idx] = torch.device(device_info)
        elif replace_int and not isinstance(arg, bool) and isinstance(arg, int):
            args_list[idx] = f"vacc:{arg}"
        elif isinstance(arg, dict):
            args_list[idx] = replace_cuda_to_vacc_in_dict(arg)
    return args_list


def replace_cuda_to_vacc_in_dict(device_dict):
    new_dict = {}
    for key, value in device_dict.items():
        if isinstance(key, str):
            key = key.replace("cuda", "vacc")
        if isinstance(value, str):
            value = value.replace("cuda", "vacc")
        new_dict[key] = value
    return new_dict
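# Illustrative sketch (comments only, not executed): how the helpers above
# rewrite device arguments, assuming torch_vacc has registered the "vacc"
# device type with torch:
#     replace_cuda_to_vacc_in_list(["cuda:1"], replace_int=False)  -> ["vacc:1"]
#     replace_cuda_to_vacc_in_list([1], replace_int=True)          -> ["vacc:1"]
#     replace_cuda_to_vacc_in_dict({"cuda:0": "cuda:1"})           -> {"vacc:0": "vacc:1"}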
"vacc") new_dict[key] = value return new_dict def device_wrapper(enter_fn, white_list): for fn_name in white_list: fn = getattr(enter_fn, fn_name, None) if fn: setattr(enter_fn, fn_name, wrapper_cuda(fn)) def wrapper_vccl(fn): @wraps(fn) def decorated(*args, **kwargs): if args: args_new = list(args) for idx, arg in enumerate(args_new): if type(arg) == str and "nccl" in arg: args_new[idx] = arg.replace("nccl", "vccl") args = args_new if kwargs: if type(kwargs.get("backend", None)) == str: kwargs["backend"] = "vccl" return fn(*args, **kwargs) return decorated def wrapper_data_loader(fn): @wraps(fn) def decorated(*args, **kwargs): if kwargs: pin_memory = kwargs.get("pin_memory", False) pin_memory_device = kwargs.get("pin_memory_device", None) if pin_memory and not pin_memory_device: kwargs["pin_memory_device"] = "vacc" if ( pin_memory and type(pin_memory_device) == str and "cuda" in pin_memory_device ): kwargs["pin_memory_device"] = pin_memory_device.replace("cuda", "vacc") return fn(*args, **kwargs) return decorated def wrapper_get_available_device_type(fn): @wraps(fn) def decorated(*args, **kwargs): try: if (torch.vacc.is_available()): return 'vacc' except Exception as e: msg = "vacc device is not available." warnings.warn(msg, RuntimeWarning) return fn(*args, **kwargs) return decorated ''' def wrapper_profiler(fn): @wraps(fn) def decorated(*args, **kwargs): if kwargs: if ( "experimental_config" in kwargs.keys() and type(kwargs.get("experimental_config")) != torch_vacc.profiler._ExperimentalConfig ): logger.warning( "The parameter experimental_config of torch.profiler.profile has been deleted by the tool " "because it can only be used in cuda, please manually modify the code " "and use the experimental_config parameter adapted to vacc." ) del kwargs["experimental_config"] return fn(*args, **kwargs) return decorated def wrapper_compile(fn): @wraps(fn) def decorated(*args, **kwargs): vacc_backend = torchair.get_vacc_backend() if kwargs: backend = kwargs.get("backend", None) if ( not backend or not isinstance(backend, functools.partial) or not isinstance(backend.func, type(vacc_backend.func)) ): kwargs["backend"] = vacc_backend else: kwargs["backend"] = vacc_backend return fn(*args, **kwargs) return decorated ''' def jit_script(obj, optimize=None, _frames_up=0, _rcb=None, example_inputs=None): msg = "torch.jit.script will be disabled by transfer_to_vacc, which currently does not support it." 
def patch_cuda():
    patches = [
        ["cuda", torch_vacc.vacc],
        ["cuda.amp", torch_vacc.vacc.amp],
        ["cuda.random", torch_vacc.vacc.random],
        ["cuda.amp.autocast_mode", torch_vacc.vacc.amp.autocast_mode],
        ["cuda.amp.common", torch_vacc.vacc.amp.common],
        ["cuda.amp.grad_scaler", torch_vacc.vacc.amp.grad_scaler],
    ]
    torch_vacc._apply_patches(patches)


'''
def patch_profiler():
    patches = [
        ["profiler.profile", torch_vacc.profiler.profile],
        ["profiler.schedule", torch_vacc.profiler.schedule],
        [
            "profiler.tensorboard_trace_handler",
            torch_vacc.profiler.tensorboard_trace_handler,
        ],
        ["profiler.ProfilerAction", torch_vacc.profiler.ProfilerAction],
        ["profiler.ProfilerActivity.CUDA", torch_vacc.profiler.ProfilerActivity.VACC],
        ["profiler.ProfilerActivity.CPU", torch_vacc.profiler.ProfilerActivity.CPU],
    ]
    torch_vacc._apply_patches(patches)
'''


def warning_fn(msg, rank0=True):
    """Emit msg once, and only on rank 0 when running distributed."""
    is_distributed = (
        torch.distributed.is_available()
        and torch.distributed.is_initialized()
        and torch.distributed.get_world_size() > 1
    )
    env_rank = os.getenv("RANK", None)
    if rank0 and is_distributed:
        if torch.distributed.get_rank() == 0:
            warnings.warn(msg, ImportWarning)
    elif rank0 and env_rank:
        if env_rank == "0":
            warnings.warn(msg, ImportWarning)
    else:
        warnings.warn(msg, ImportWarning)


def init():
    warning_fn(
        """
*************************************************************************************************************
The torch.Tensor.cuda and torch.nn.Module.cuda are replaced with torch.Tensor.vacc and torch.nn.Module.vacc now.
The torch.cuda.DoubleTensor is replaced with torch.vacc.FloatTensor because the double type is not supported now.
The backend in torch.distributed.init_process_group is set to vccl now.
The torch.cuda.* and torch.cuda.amp.* are replaced with torch.vacc.* and torch.vacc.amp.* now.
The device parameters have been replaced with vacc in the functions below:
{}
If you notice any function you use that is not included in the above list, feel free to contact the torch-vacc development team.
*************************************************************************************************************
""".format(
            ", ".join(
                ["torch." + i for i in torch_fn_white_list]
                + ["torch.Tensor." + i for i in torch_tensor_fn_white_list]
                + ["torch.nn.Module." + i for i in torch_module_fn_white_list]
            )
        )
    )

    # torch.cuda.*
    patch_cuda()
    device_wrapper(torch.cuda, torch_cuda_fn_white_list)

    # torch.profiler.*
    # TODO(qingsong): profiler not implemented yet
    # patch_profiler()
    # device_wrapper(torch.profiler, torch_profiler_fn_white_list)

    # torch.*
    device_wrapper(torch, torch_fn_white_list)

    # torch.Tensor.*
    device_wrapper(torch.Tensor, torch_tensor_fn_white_list)
    torch.Tensor.cuda = torch.Tensor.vacc
    torch.Tensor.is_cuda = torch.Tensor.is_vacc

    for dtype_tensor in [
        "ByteTensor",
        "CharTensor",
        "DoubleTensor",
        "FloatTensor",
        "IntTensor",
        "LongTensor",
        "ShortTensor",
        "HalfTensor",
        "BoolTensor",
    ]:
        setattr(
            torch.cuda,
            dtype_tensor,
            getattr(torch.vacc, dtype_tensor),
        )
    # TODO(qingsong): do we need this? should we add LongTensor=IntTensor?
    torch.cuda.DoubleTensor = torch.vacc.FloatTensor

    # torch.nn.Module.*
    device_wrapper(torch.nn.Module, torch_module_fn_white_list)
    torch.nn.Module.cuda = torch.nn.Module.vacc

    # torch.distributed.init_process_group
    torch.distributed.init_process_group = wrapper_vccl(
        torch.distributed.init_process_group
    )
    torch.distributed.is_nccl_available = torch.distributed.is_vccl_available

    # torch.nn.parallel.DistributedDataParallel
    device_wrapper(
        torch.nn.parallel.DistributedDataParallel, torch_distributed_fn_white_list
    )

    # torch.utils.data.DataLoader
    torch.utils.data.DataLoader.__init__ = wrapper_data_loader(
        torch.utils.data.DataLoader.__init__
    )

    torch.jit.script = jit_script

    torch._utils._get_available_device_type = wrapper_get_available_device_type(
        torch._utils._get_available_device_type
    )

    '''
    if IS_TORCHAIR_INSTALLED:
        torch.compile = wrapper_compile(torch.compile)
    '''


init()
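# Illustrative usage sketch (comments only; the module name below is an
# assumption, the real import path depends on how the package ships this file).
# Importing the module is enough: init() runs at import time and patches torch
# in place, so existing CUDA-oriented code is redirected to the vacc device:
#
#     import torch
#     import transfer_to_vacc  # hypothetical import path
#
#     model = MyModel().cuda()                 # dispatched to torch.nn.Module.vacc
#     x = torch.randn(4, 4, device="cuda:0")   # device rewritten to "vacc:0"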