enginex-vastai-va16-vllm/torch_vacc/contrib/transfer_to_vacc.py
2026-04-02 04:55:00 +00:00


import os
import warnings
import logging as logger
from functools import wraps

import torch
import torch_vacc

# torchair integration is currently disabled; the import guard is kept for reference.
'''
try:
    import torchair
except ImportError:
    IS_TORCHAIR_INSTALLED = False
else:
    IS_TORCHAIR_INSTALLED = True
'''

warnings.filterwarnings(action="once")

# torch.* factory and utility functions whose device-style arguments are rewritten
# from "cuda" to "vacc".
torch_fn_white_list = [
    "_cudnn_init_dropout_state",
    "_empty_affine_quantized",
    "_empty_per_channel_affine_quantized",
    "_pin_memory",
    "_sparse_coo_tensor_unsafe",
    "_sparse_csr_tensor_unsafe",
    "logspace",
    "randint",
    "hann_window",
    "rand",
    "full_like",
    "ones_like",
    "rand_like",
    "randperm",
    "arange",
    "frombuffer",
    "normal",
    "empty_strided",
    "empty_like",
    "scalar_tensor",
    "tril_indices",
    "bartlett_window",
    "ones",
    "sparse_coo_tensor",
    "randn",
    "kaiser_window",
    "tensor",
    "triu_indices",
    "as_tensor",
    "zeros",
    "randint_like",
    "full",
    "eye",
    "empty",
    "blackman_window",
    "zeros_like",
    "range",
    "sparse_csr_tensor",
    "randn_like",
    "from_file",
    "linspace",
    "hamming_window",
    "empty_quantized",
    "autocast",
    "load",
]
# torch.Tensor methods that accept a device argument.
torch_tensor_fn_white_list = [
    "new_empty",
    "new_empty_strided",
    "new_full",
    "new_ones",
    "new_tensor",
    "new_zeros",
    "to",
]
# torch.nn.Module methods that accept a device argument.
torch_module_fn_white_list = ["to", "to_empty"]
# torch.cuda functions whose device arguments are rewritten.
torch_cuda_fn_white_list = [
    "get_device_properties",
    "get_device_name",
    "get_device_capability",
    "list_gpu_processes",
    "set_device",
    "synchronize",
    "mem_get_info",
    "memory_stats",
    "memory_summary",
    "memory_allocated",
    "max_memory_allocated",
    "reset_max_memory_allocated",
    "memory_reserved",
    "max_memory_reserved",
    "reset_max_memory_cached",
    "reset_peak_memory_stats",
    "current_stream",
    "default_stream",
]
torch_profiler_fn_white_list = ["profile"]
torch_distributed_fn_white_list = ["__init__"]
# Keyword names that may carry a device specification.
device_kwargs_list = ["device", "device_type", "map_location"]

def wrapper_cuda(fn):
    """Rewrite any "cuda" device arguments to "vacc" before calling ``fn``."""

    @wraps(fn)
    def decorated(*args, **kwargs):
        # Tensor.to / Module.to / Module.to_empty also accept a bare device index (int).
        replace_int = fn.__name__ in ["to", "to_empty"]
        if args:
            args_new = list(args)
            args = replace_cuda_to_vacc_in_list(args_new, replace_int)
        if kwargs:
            for device_arg in device_kwargs_list:
                device = kwargs.get(device_arg, None)
                if device is not None:
                    replace_cuda_to_vacc_in_kwargs(kwargs, device_arg, device)
            device_ids = kwargs.get("device_ids", None)
            if type(device_ids) == list:
                # The list is rewritten in place, so kwargs["device_ids"] is updated.
                replace_cuda_to_vacc_in_list(device_ids, replace_int)
        return fn(*args, **kwargs)

    return decorated


def replace_cuda_to_vacc_in_kwargs(kwargs, device_arg, device):
    if type(device) == str and "cuda" in device:
        kwargs[device_arg] = device.replace("cuda", "vacc")
    elif type(device) == torch.device and "cuda" in device.type:
        device_info = (
            "vacc:{}".format(device.index) if device.index is not None else "vacc"
        )
        kwargs[device_arg] = torch.device(device_info)
    elif type(device) == int:
        kwargs[device_arg] = f"vacc:{device}"
    elif type(device) == dict:
        kwargs[device_arg] = replace_cuda_to_vacc_in_dict(device)


def replace_cuda_to_vacc_in_list(args_list, replace_int):
    for idx, arg in enumerate(args_list):
        if isinstance(arg, str) and "cuda" in arg:
            args_list[idx] = arg.replace("cuda", "vacc")
        elif isinstance(arg, torch.device) and "cuda" in arg.type:
            device_info = (
                "vacc:{}".format(arg.index) if arg.index is not None else "vacc"
            )
            args_list[idx] = torch.device(device_info)
        elif replace_int and not isinstance(arg, bool) and isinstance(arg, int):
            args_list[idx] = f"vacc:{arg}"
        elif isinstance(arg, dict):
            args_list[idx] = replace_cuda_to_vacc_in_dict(arg)
    return args_list


def replace_cuda_to_vacc_in_dict(device_dict):
    new_dict = {}
    for key, value in device_dict.items():
        if isinstance(key, str):
            key = key.replace("cuda", "vacc")
        if isinstance(value, str):
            value = value.replace("cuda", "vacc")
        new_dict[key] = value
    return new_dict

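
# Illustrative sketch (not part of the original file): once device_wrapper() below has
# patched the whitelisted functions, device specifications are rewritten transparently.
# The device indices are assumptions for the example.
#
#     torch.zeros(4, device="cuda:1")             # dispatched with device="vacc:1"
#     torch.ones(2, device=torch.device("cuda"))  # dispatched with torch.device("vacc")
#     t.to(0)                                     # Tensor.to: bare index 0 becomes "vacc:0"
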
def device_wrapper(enter_fn, white_list):
    """Wrap every whitelisted attribute of ``enter_fn`` with wrapper_cuda."""
    for fn_name in white_list:
        fn = getattr(enter_fn, fn_name, None)
        if fn:
            setattr(enter_fn, fn_name, wrapper_cuda(fn))


def wrapper_vccl(fn):
    """Rewrite the distributed backend from "nccl" to "vccl" before calling ``fn``."""

    @wraps(fn)
    def decorated(*args, **kwargs):
        if args:
            args_new = list(args)
            for idx, arg in enumerate(args_new):
                if type(arg) == str and "nccl" in arg:
                    args_new[idx] = arg.replace("nccl", "vccl")
            args = args_new
        if kwargs:
            # Any string backend keyword is forced to "vccl".
            if type(kwargs.get("backend", None)) == str:
                kwargs["backend"] = "vccl"
        return fn(*args, **kwargs)

    return decorated

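
# Illustrative sketch (not part of the original file): after init() wraps
# torch.distributed.init_process_group with wrapper_vccl,
#
#     torch.distributed.init_process_group(backend="nccl", rank=0, world_size=1)
#
# is forwarded with backend="vccl".  Note that any string backend keyword, not just
# "nccl", is forced to "vccl".
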
def wrapper_data_loader(fn):
    """Point DataLoader pinned-memory staging at the vacc device."""

    @wraps(fn)
    def decorated(*args, **kwargs):
        if kwargs:
            pin_memory = kwargs.get("pin_memory", False)
            pin_memory_device = kwargs.get("pin_memory_device", None)
            if pin_memory and not pin_memory_device:
                kwargs["pin_memory_device"] = "vacc"
            if (
                pin_memory
                and type(pin_memory_device) == str
                and "cuda" in pin_memory_device
            ):
                kwargs["pin_memory_device"] = pin_memory_device.replace("cuda", "vacc")
        return fn(*args, **kwargs)

    return decorated

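
# Illustrative sketch (not part of the original file): with the DataLoader constructor
# wrapped, pin_memory=True without an explicit device gains pin_memory_device="vacc",
# and an explicit "cuda" value is rewritten.  The dataset name is a placeholder.
#
#     torch.utils.data.DataLoader(dataset, pin_memory=True)
#     # behaves like DataLoader(dataset, pin_memory=True, pin_memory_device="vacc")
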
def wrapper_get_available_device_type(fn):
    """Report "vacc" as the available device type when a vacc device is present."""

    @wraps(fn)
    def decorated(*args, **kwargs):
        try:
            if torch.vacc.is_available():
                return "vacc"
        except Exception:
            msg = "vacc device is not available."
            warnings.warn(msg, RuntimeWarning)
        return fn(*args, **kwargs)

    return decorated

# Profiler and torch.compile adaptation are currently disabled; the wrappers below are
# kept for reference.  wrapper_compile would also require ``import functools`` and
# torchair if re-enabled.
'''
def wrapper_profiler(fn):
    @wraps(fn)
    def decorated(*args, **kwargs):
        if kwargs:
            if (
                "experimental_config" in kwargs.keys()
                and type(kwargs.get("experimental_config"))
                != torch_vacc.profiler._ExperimentalConfig
            ):
                logger.warning(
                    "The parameter experimental_config of torch.profiler.profile has been deleted by the tool "
                    "because it can only be used in cuda, please manually modify the code "
                    "and use the experimental_config parameter adapted to vacc."
                )
                del kwargs["experimental_config"]
        return fn(*args, **kwargs)
    return decorated


def wrapper_compile(fn):
    @wraps(fn)
    def decorated(*args, **kwargs):
        vacc_backend = torchair.get_vacc_backend()
        if kwargs:
            backend = kwargs.get("backend", None)
            if (
                not backend
                or not isinstance(backend, functools.partial)
                or not isinstance(backend.func, type(vacc_backend.func))
            ):
                kwargs["backend"] = vacc_backend
        else:
            kwargs["backend"] = vacc_backend
        return fn(*args, **kwargs)
    return decorated
'''

def jit_script(obj, optimize=None, _frames_up=0, _rcb=None, example_inputs=None):
    """Replacement for torch.jit.script that returns the object unscripted."""
    msg = "torch.jit.script will be disabled by transfer_to_vacc, which currently does not support it."
    warnings.warn(msg, RuntimeWarning)
    return obj

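
# Illustrative sketch (not part of the original file): after init() assigns this
# function to torch.jit.script, scripting becomes a pass-through.  my_module is a
# placeholder name.
#
#     scripted = torch.jit.script(my_module)  # returns my_module unchanged, with a RuntimeWarning
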
def patch_cuda():
    """Redirect the torch.cuda module hierarchy to torch_vacc.vacc."""
    patches = [
        ["cuda", torch_vacc.vacc],
        ["cuda.amp", torch_vacc.vacc.amp],
        ["cuda.random", torch_vacc.vacc.random],
        ["cuda.amp.autocast_mode", torch_vacc.vacc.amp.autocast_mode],
        ["cuda.amp.common", torch_vacc.vacc.amp.common],
        ["cuda.amp.grad_scaler", torch_vacc.vacc.amp.grad_scaler],
    ]
    torch_vacc._apply_patches(patches)

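
# Illustrative note (not part of the original file): after patch_cuda(), attributes
# such as torch.cuda.amp.autocast are expected to resolve to their torch_vacc.vacc.amp
# counterparts, assuming torch_vacc._apply_patches rebinds the listed torch submodule
# paths to the given replacements.
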
'''
def patch_profiler():
    patches = [
        ["profiler.profile", torch_vacc.profiler.profile],
        ["profiler.schedule", torch_vacc.profiler.schedule],
        [
            "profiler.tensorboard_trace_handler",
            torch_vacc.profiler.tensorboard_trace_handler,
        ],
        ["profiler.ProfilerAction", torch_vacc.profiler.ProfilerAction],
        ["profiler.ProfilerActivity.CUDA", torch_vacc.profiler.ProfilerActivity.VACC],
        ["profiler.ProfilerActivity.CPU", torch_vacc.profiler.ProfilerActivity.CPU],
    ]
    torch_vacc._apply_patches(patches)
'''

def warning_fn(msg, rank0=True):
    """Emit ``msg`` once, only from rank 0 when running distributed."""
    is_distributed = (
        torch.distributed.is_available()
        and torch.distributed.is_initialized()
        and torch.distributed.get_world_size() > 1
    )
    env_rank = os.getenv("RANK", None)
    if rank0 and is_distributed:
        if torch.distributed.get_rank() == 0:
            warnings.warn(msg, ImportWarning)
    elif rank0 and env_rank:
        if env_rank == "0":
            warnings.warn(msg, ImportWarning)
    else:
        warnings.warn(msg, ImportWarning)

def init():
    warning_fn(
        """
    *************************************************************************************************************
    The torch.Tensor.cuda and torch.nn.Module.cuda are replaced with torch.Tensor.vacc and torch.nn.Module.vacc now.
    The torch.cuda.DoubleTensor is replaced with torch.vacc.FloatTensor because the double type is not supported now.
    The backend in torch.distributed.init_process_group is set to vccl now.
    The torch.cuda.* and torch.cuda.amp.* are replaced with torch.vacc.* and torch.vacc.amp.* now.
    The device parameters have been replaced with vacc in the functions below:
    {}
    If you notice any function you use is not included in the above list, feel free to contact the torch-vacc development team.
    *************************************************************************************************************
    """.format(
            ", ".join(
                ["torch." + i for i in torch_fn_white_list]
                + ["torch.Tensor." + i for i in torch_tensor_fn_white_list]
                + ["torch.nn.Module." + i for i in torch_module_fn_white_list]
            )
        )
    )

    # torch.cuda.*
    patch_cuda()
    device_wrapper(torch.cuda, torch_cuda_fn_white_list)

    # torch.profiler.*
    # TODO(qingsong): profiler not implemented yet
    # patch_profiler()
    # device_wrapper(torch.profiler, torch_profiler_fn_white_list)

    # torch.*
    device_wrapper(torch, torch_fn_white_list)

    # torch.Tensor.*
    device_wrapper(torch.Tensor, torch_tensor_fn_white_list)
    torch.Tensor.cuda = torch.Tensor.vacc
    torch.Tensor.is_cuda = torch.Tensor.is_vacc
    for dtype_tensor in [
        "ByteTensor",
        "CharTensor",
        "DoubleTensor",
        "FloatTensor",
        "IntTensor",
        "LongTensor",
        "ShortTensor",
        "HalfTensor",
        "BoolTensor",
    ]:
        setattr(
            torch.cuda,
            dtype_tensor,
            getattr(torch.vacc, dtype_tensor),
        )
    # TODO(qingsong): do we need this? should we add LongTensor=IntTensor?
    torch.cuda.DoubleTensor = torch.vacc.FloatTensor

    # torch.nn.Module.*
    device_wrapper(torch.nn.Module, torch_module_fn_white_list)
    torch.nn.Module.cuda = torch.nn.Module.vacc

    # torch.distributed.init_process_group
    torch.distributed.init_process_group = wrapper_vccl(
        torch.distributed.init_process_group
    )
    torch.distributed.is_nccl_available = torch.distributed.is_vccl_available

    # torch.nn.parallel.DistributedDataParallel
    device_wrapper(
        torch.nn.parallel.DistributedDataParallel, torch_distributed_fn_white_list
    )

    # torch.utils.data.DataLoader
    torch.utils.data.DataLoader.__init__ = wrapper_data_loader(
        torch.utils.data.DataLoader.__init__
    )

    torch.jit.script = jit_script
    torch._utils._get_available_device_type = wrapper_get_available_device_type(
        torch._utils._get_available_device_type
    )

'''
if IS_TORCHAIR_INSTALLED:
    torch.compile = wrapper_compile(torch.compile)
'''


init()
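
# Example usage (a sketch, not part of the original module).  init() runs at import
# time, so existing CUDA-targeted code only needs to import this module first; the
# device index below is illustrative and assumes a visible VACC device.
#
#     import torch
#     import torch_vacc.contrib.transfer_to_vacc  # noqa: F401  (patches torch on import)
#
#     x = torch.randn(2, 2, device="cuda:0")   # transparently allocated on vacc:0
#     model = torch.nn.Linear(2, 2).cuda()     # Module.cuda now maps to Module.vacc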