init
This commit is contained in:
398
torch_vacc/contrib/transfer_to_vacc.py
Normal file
398
torch_vacc/contrib/transfer_to_vacc.py
Normal file
@@ -0,0 +1,398 @@
|
||||
import os
|
||||
import warnings
|
||||
import logging as logger
|
||||
from functools import wraps
|
||||
import torch
|
||||
import torch_vacc
|
||||
|
||||
'''
|
||||
try:
|
||||
import torchair
|
||||
except ImportError:
|
||||
IS_TORCHAIR_INSTALLED = False
|
||||
else:
|
||||
IS_TORCHAIR_INSTALLED = True
|
||||
'''
|
||||
|
||||
# Emit each distinct warning from the wrappers below only once per process.
warnings.filterwarnings(action="once")
|
||||
|
||||
|
||||
# torch.* factory / device-taking functions whose ``device``-style arguments
# are rewritten from "cuda" to "vacc" by ``device_wrapper``.
torch_fn_white_list = [
    "_cudnn_init_dropout_state", "_empty_affine_quantized",
    "_empty_per_channel_affine_quantized", "_pin_memory",
    "_sparse_coo_tensor_unsafe", "_sparse_csr_tensor_unsafe",
    "logspace", "randint", "hann_window", "rand",
    "full_like", "ones_like", "rand_like", "randperm",
    "arange", "frombuffer", "normal", "empty_strided",
    "empty_like", "scalar_tensor", "tril_indices", "bartlett_window",
    "ones", "sparse_coo_tensor", "randn", "kaiser_window",
    "tensor", "triu_indices", "as_tensor", "zeros",
    "randint_like", "full", "eye", "empty",
    "blackman_window", "zeros_like", "range", "sparse_csr_tensor",
    "randn_like", "from_file", "linspace", "hamming_window",
    "empty_quantized", "autocast", "load",
]

# torch.Tensor methods that accept a device argument.
torch_tensor_fn_white_list = [
    "new_empty", "new_empty_strided", "new_full", "new_ones",
    "new_tensor", "new_zeros", "to",
]

# torch.nn.Module methods that accept a device argument.
torch_module_fn_white_list = ["to", "to_empty"]

# torch.cuda functions whose device arguments are rewritten.
torch_cuda_fn_white_list = [
    "get_device_properties", "get_device_name", "get_device_capability",
    "list_gpu_processes", "set_device", "synchronize",
    "mem_get_info", "memory_stats", "memory_summary",
    "memory_allocated", "max_memory_allocated", "reset_max_memory_allocated",
    "memory_reserved", "max_memory_reserved", "reset_max_memory_cached",
    "reset_peak_memory_stats", "current_stream", "default_stream",
]

# torch.profiler entry points to wrap (profiler support still pending).
torch_profiler_fn_white_list = ["profile"]

# DistributedDataParallel methods whose device arguments are rewritten.
torch_distributed_fn_white_list = ["__init__"]

# Keyword-argument names that may carry a device specification.
device_kwargs_list = ["device", "device_type", "map_location"]
|
||||
|
||||
|
||||
def wrapper_cuda(fn):
    """Decorator that rewrites CUDA device arguments to VACC before calling *fn*.

    Positional arguments, the device-ish keyword arguments listed in
    ``device_kwargs_list``, and a list-valued ``device_ids`` keyword are all
    translated ("cuda" -> "vacc") before delegating to the wrapped callable.
    """

    @wraps(fn)
    def decorated(*args, **kwargs):
        # For Tensor/Module ``to``/``to_empty`` a bare int means a device
        # index, so plain ints must also be rewritten to "vacc:<idx>".
        replace_int = fn.__name__ in ("to", "to_empty")
        if args:
            args = replace_cuda_to_vacc_in_list(list(args), replace_int)
        if kwargs:
            for device_arg in device_kwargs_list:
                device = kwargs.get(device_arg, None)
                if device is not None:
                    replace_cuda_to_vacc_in_kwargs(kwargs, device_arg, device)
            device_ids = kwargs.get("device_ids", None)
            # isinstance instead of type(...) == list; the helper mutates the
            # list in place, so kwargs["device_ids"] is updated as a side
            # effect and no re-assignment is needed.
            if isinstance(device_ids, list):
                replace_cuda_to_vacc_in_list(device_ids, replace_int)
        return fn(*args, **kwargs)

    return decorated
|
||||
|
||||
|
||||
def replace_cuda_to_vacc_in_kwargs(kwargs, device_arg, device):
    """Rewrite ``kwargs[device_arg]`` from a CUDA device spec to the VACC form.

    Supported specs: str ("cuda:0" -> "vacc:0"), torch.device (rebuilt with
    the "vacc" type, preserving the index), a plain int device index
    (-> "vacc:<idx>"), and dict (e.g. a ``map_location`` remap table).
    Any other type is left untouched. *kwargs* is mutated in place.
    """
    if isinstance(device, str) and "cuda" in device:
        kwargs[device_arg] = device.replace("cuda", "vacc")
    elif isinstance(device, torch.device) and "cuda" in device.type:
        # Preserve an explicit device index when one was given.
        device_info = (
            "vacc:{}".format(device.index) if device.index is not None else "vacc"
        )
        kwargs[device_arg] = torch.device(device_info)
    elif isinstance(device, int) and not isinstance(device, bool):
        # bool is an int subclass; exclude it, matching
        # replace_cuda_to_vacc_in_list (and the original type() == int check).
        kwargs[device_arg] = f"vacc:{device}"
    elif isinstance(device, dict):
        kwargs[device_arg] = replace_cuda_to_vacc_in_dict(device)
|
||||
|
||||
|
||||
def replace_cuda_to_vacc_in_list(args_list, replace_int):
    """Translate CUDA device specs inside *args_list* to their VACC form.

    The list is modified in place and also returned. Strings containing
    "cuda" get the substring replaced, torch.device objects are rebuilt
    with the "vacc" type (keeping the index), plain ints become
    "vacc:<idx>" strings when *replace_int* is true, and dicts are
    converted entry-wise. Anything else passes through unchanged.
    """

    def convert(item):
        if isinstance(item, str):
            return item.replace("cuda", "vacc") if "cuda" in item else item
        if isinstance(item, torch.device):
            if "cuda" not in item.type:
                return item
            index = item.index
            return torch.device("vacc" if index is None else "vacc:{}".format(index))
        # bool is an int subclass and must not be treated as a device index.
        if replace_int and isinstance(item, int) and not isinstance(item, bool):
            return f"vacc:{item}"
        if isinstance(item, dict):
            return replace_cuda_to_vacc_in_dict(item)
        return item

    for position, item in enumerate(args_list):
        args_list[position] = convert(item)
    return args_list
|
||||
|
||||
|
||||
def replace_cuda_to_vacc_in_dict(device_dict):
    """Return a copy of *device_dict* with "cuda" -> "vacc" in str keys/values.

    Non-string keys and values are carried over unchanged; the input dict
    itself is never modified.
    """

    def swap(text):
        return text.replace("cuda", "vacc") if isinstance(text, str) else text

    return {swap(key): swap(value) for key, value in device_dict.items()}
|
||||
|
||||
|
||||
def device_wrapper(enter_fn, white_list):
    """Wrap every function named in *white_list* on *enter_fn* with wrapper_cuda.

    Names that are missing from the target object (or resolve to a falsy
    attribute) are skipped silently.
    """
    for name in white_list:
        original = getattr(enter_fn, name, None)
        if not original:
            continue
        setattr(enter_fn, name, wrapper_cuda(original))
|
||||
|
||||
|
||||
def wrapper_vccl(fn):
    """Decorator that redirects the NCCL collective backend to VCCL.

    Positional string arguments containing "nccl" get the substring
    replaced; a string ``backend`` keyword is forced to "vccl" regardless
    of its value (only the VCCL backend is supported on VACC).
    """

    @wraps(fn)
    def decorated(*args, **kwargs):
        if args:
            # isinstance instead of type(...) == str; non-matching args pass
            # through untouched.
            args = tuple(
                arg.replace("nccl", "vccl")
                if isinstance(arg, str) and "nccl" in arg
                else arg
                for arg in args
            )
        if isinstance(kwargs.get("backend", None), str):
            kwargs["backend"] = "vccl"
        return fn(*args, **kwargs)

    return decorated
|
||||
|
||||
|
||||
def wrapper_data_loader(fn):
    """Decorator for DataLoader.__init__ that points pinned memory at VACC.

    When ``pin_memory=True`` and no ``pin_memory_device`` was given, the
    pin device defaults to "vacc"; an explicit CUDA pin device is
    translated to its VACC equivalent. All other calls pass through.
    """

    @wraps(fn)
    def decorated(*args, **kwargs):
        if kwargs:
            pin_memory = kwargs.get("pin_memory", False)
            pin_memory_device = kwargs.get("pin_memory_device", None)
            if pin_memory:
                # The branches are mutually exclusive (unset vs. explicit
                # CUDA device), so an elif replaces the original second if.
                if not pin_memory_device:
                    kwargs["pin_memory_device"] = "vacc"
                elif isinstance(pin_memory_device, str) and "cuda" in pin_memory_device:
                    kwargs["pin_memory_device"] = pin_memory_device.replace(
                        "cuda", "vacc"
                    )
        return fn(*args, **kwargs)

    return decorated
|
||||
|
||||
def wrapper_get_available_device_type(fn):
    """Decorator for torch._utils._get_available_device_type preferring VACC.

    Returns "vacc" when ``torch.vacc`` reports availability. Any failure
    while probing (e.g. the VACC backend is not importable) emits a
    RuntimeWarning and falls back to the wrapped probe.
    """

    @wraps(fn)
    def decorated(*args, **kwargs):
        try:
            if torch.vacc.is_available():
                return "vacc"
        except Exception:
            # Broad on purpose: probing an absent backend may raise
            # AttributeError, ImportError, RuntimeError, ...
            warnings.warn("vacc device is not available.", RuntimeWarning)
        return fn(*args, **kwargs)

    return decorated
|
||||
|
||||
'''
|
||||
def wrapper_profiler(fn):
|
||||
@wraps(fn)
|
||||
def decorated(*args, **kwargs):
|
||||
if kwargs:
|
||||
if (
|
||||
"experimental_config" in kwargs.keys()
|
||||
and type(kwargs.get("experimental_config"))
|
||||
!= torch_vacc.profiler._ExperimentalConfig
|
||||
):
|
||||
logger.warning(
|
||||
"The parameter experimental_config of torch.profiler.profile has been deleted by the tool "
|
||||
"because it can only be used in cuda, please manually modify the code "
|
||||
"and use the experimental_config parameter adapted to vacc."
|
||||
)
|
||||
del kwargs["experimental_config"]
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
|
||||
def wrapper_compile(fn):
|
||||
@wraps(fn)
|
||||
def decorated(*args, **kwargs):
|
||||
vacc_backend = torchair.get_vacc_backend()
|
||||
if kwargs:
|
||||
backend = kwargs.get("backend", None)
|
||||
if (
|
||||
not backend
|
||||
or not isinstance(backend, functools.partial)
|
||||
or not isinstance(backend.func, type(vacc_backend.func))
|
||||
):
|
||||
kwargs["backend"] = vacc_backend
|
||||
else:
|
||||
kwargs["backend"] = vacc_backend
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
return decorated
|
||||
'''
|
||||
|
||||
|
||||
def jit_script(obj, optimize=None, _frames_up=0, _rcb=None, example_inputs=None):
    """Stand-in for torch.jit.script: warn and hand *obj* back unscripted.

    TorchScript compilation is not supported on VACC, so scripting becomes
    an identity operation; all other parameters are accepted and ignored.
    """
    warnings.warn(
        "torch.jit.script will be disabled by transfer_to_vacc, which currently does not support it.",
        RuntimeWarning,
    )
    return obj
|
||||
|
||||
|
||||
def patch_cuda():
    """Swap the torch.cuda module tree for its torch_vacc.vacc counterpart.

    Delegates the actual attribute surgery to torch_vacc._apply_patches,
    which receives [dotted-name-under-torch, replacement-object] pairs.
    """
    vacc_ns = torch_vacc.vacc
    replacements = [
        ["cuda", vacc_ns],
        ["cuda.amp", vacc_ns.amp],
        ["cuda.random", vacc_ns.random],
        ["cuda.amp.autocast_mode", vacc_ns.amp.autocast_mode],
        ["cuda.amp.common", vacc_ns.amp.common],
        ["cuda.amp.grad_scaler", vacc_ns.amp.grad_scaler],
    ]
    torch_vacc._apply_patches(replacements)
|
||||
|
||||
|
||||
'''
|
||||
def patch_profiler():
|
||||
patchs = [
|
||||
["profiler.profile", torch_vacc.profiler.profile],
|
||||
["profiler.schedule", torch_vacc.profiler.schedule],
|
||||
[
|
||||
"profiler.tensorboard_trace_handler",
|
||||
torch_vacc.profiler.tensorboard_trace_handler,
|
||||
],
|
||||
["profiler.ProfilerAction", torch_vacc.profiler.ProfilerAction],
|
||||
["profiler.ProfilerActivity.CUDA", torch_vacc.profiler.ProfilerActivity.VACC],
|
||||
["profiler.ProfilerActivity.CPU", torch_vacc.profiler.ProfilerActivity.CPU],
|
||||
]
|
||||
torch_vacc._apply_patches(patchs)
|
||||
'''
|
||||
|
||||
|
||||
def warning_fn(msg, rank0=True):
    """Emit *msg* as an ImportWarning, restricted to rank 0 when requested.

    With ``rank0=True`` the warning is emitted only on global rank 0,
    detected via an initialized torch.distributed process group or,
    failing that, a non-empty ``RANK`` environment variable. With
    ``rank0=False`` (or no rank information at all) every process warns.
    """
    if not rank0:
        warnings.warn(msg, ImportWarning)
        return

    dist = torch.distributed
    if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
        if dist.get_rank() == 0:
            warnings.warn(msg, ImportWarning)
        return

    # Truthiness (not an is-None check) on purpose: an empty RANK string is
    # treated the same as an unset one.
    env_rank = os.getenv("RANK", None)
    if env_rank:
        if env_rank == "0":
            warnings.warn(msg, ImportWarning)
        return

    warnings.warn(msg, ImportWarning)
|
||||
|
||||
|
||||
def init():
    """Monkey-patch torch so existing CUDA-targeted code runs on VACC devices.

    Replaces the torch.cuda module tree with torch_vacc.vacc, wraps the
    whitelisted factory / device-taking functions so "cuda" device
    arguments become "vacc", forces the distributed backend to VCCL,
    redirects DataLoader pinned memory, and disables torch.jit.script.
    Invoked once at import time via the module-level ``init()`` call.
    Patch order matters: patch_cuda() must run before the torch.cuda
    whitelist is wrapped, so the wrappers capture the VACC attributes.
    """
    # Announce (once, on rank 0 in distributed jobs) every entry point
    # whose device arguments are being rewritten.
    warning_fn(
        """
*************************************************************************************************************
The torch.Tensor.cuda and torch.nn.Module.cuda are replaced with torch.Tensor.vacc and torch.nn.Module.vacc now..
The torch.cuda.DoubleTensor is replaced with torch.vacc.FloatTensor cause the double type is not supported now..
The backend in torch.distributed.init_process_group set to vccl now..
The torch.cuda.* and torch.cuda.amp.* are replaced with torch.vacc.* and torch.vacc.amp.* now..
The device parameters have been replaced with vacc in the function below:
{}
If you notices any functions you use is not included in the above list, feel free to contact torch-vacc development team.
*************************************************************************************************************
""".format(
            ", ".join(
                ["torch." + i for i in torch_fn_white_list]
                + ["torch.Tensor." + i for i in torch_tensor_fn_white_list]
                + ["torch.nn.Module." + i for i in torch_module_fn_white_list]
            )
        )
    )

    # torch.cuda.* -> torch_vacc.vacc.* modules, then device-argument
    # rewriting for the whitelisted torch.cuda functions.
    patch_cuda()
    device_wrapper(torch.cuda, torch_cuda_fn_white_list)

    # torch.profiler.*
    # TODO(qingsong): profiler not implemented yet
    # patch_profiler()
    # device_wrapper(torch.profiler, torch_profiler_fn_white_list)

    # torch.* factory functions (tensor, zeros, randn, ...).
    device_wrapper(torch, torch_fn_white_list)

    # torch.Tensor.* device-taking methods, plus .cuda()/.is_cuda aliases.
    device_wrapper(torch.Tensor, torch_tensor_fn_white_list)
    torch.Tensor.cuda = torch.Tensor.vacc
    torch.Tensor.is_cuda = torch.Tensor.is_vacc

    # Legacy torch.cuda.<Dtype>Tensor constructors -> VACC equivalents.
    for dtype_tensor in [
        "ByteTensor",
        "CharTensor",
        "DoubleTensor",
        "FloatTensor",
        "IntTensor",
        "LongTensor",
        "ShortTensor",
        "HalfTensor",
        "BoolTensor",
    ]:
        setattr(
            torch.cuda,
            dtype_tensor,
            getattr(torch.vacc, dtype_tensor),
        )
    # TODO(qingsong): do we need this? should we add LongTensor=IntTensor?
    # NOTE(review): double is unsupported on VACC, so DoubleTensor silently
    # downgrades to float32 — confirm callers tolerate the precision loss.
    torch.cuda.DoubleTensor = torch.vacc.FloatTensor

    # torch.nn.Module.* device-taking methods, plus the .cuda() alias.
    device_wrapper(torch.nn.Module, torch_module_fn_white_list)
    torch.nn.Module.cuda = torch.nn.Module.vacc

    # torch.distributed.init_process_group: force the VCCL backend.
    torch.distributed.init_process_group = wrapper_vccl(
        torch.distributed.init_process_group
    )
    torch.distributed.is_nccl_available = torch.distributed.is_vccl_available

    # torch.nn.parallel.DistributedDataParallel: rewrite device_ids etc.
    device_wrapper(
        torch.nn.parallel.DistributedDataParallel, torch_distributed_fn_white_list
    )
    # torch.utils.data.DataLoader: pin host memory against the VACC device.
    torch.utils.data.DataLoader.__init__ = wrapper_data_loader(
        torch.utils.data.DataLoader.__init__
    )

    # TorchScript is unsupported: scripting becomes a warning + identity.
    torch.jit.script = jit_script
    torch._utils._get_available_device_type = wrapper_get_available_device_type(
        torch._utils._get_available_device_type
    )

    # Disabled until torchair integration lands (see the commented-out
    # import probe and wrapper_compile above).
    '''
    if IS_TORCHAIR_INSTALLED:
        torch.compile = wrapper_compile(torch.compile)
    '''


# Apply all patches as a side effect of importing this module.
init()
|
||||
Reference in New Issue
Block a user