init

This commit is contained in:

124
torch_vacc/__init__.py
Normal file
@@ -0,0 +1,124 @@
import atexit
import ctypes
import os
import sys
import types

import torch
import torch.distributed

from .version import __version__


def register_runtime_libraries() -> None:
    try:
        libpython_so = f"libpython{sys.version_info.major}.{sys.version_info.minor}.so"
        base_prefix = getattr(sys, "base_prefix", sys.prefix)
        if not base_prefix.startswith("/usr"):  # like conda or virtualenv
            ctypes.CDLL(os.path.join(base_prefix, "lib", libpython_so))

        this_path = os.path.dirname(os.path.realpath(__file__))
        rt_dll_dpath = os.path.join(this_path, "_vacc_libs")
        ctypes.CDLL(os.path.join(rt_dll_dpath, "libodsp.so"))
        ctypes.CDLL(os.path.join(rt_dll_dpath, "libvaccrt.so"))
        ctypes.CDLL(os.path.join(rt_dll_dpath, "libvnnl.so"))
        ctypes.CDLL(os.path.join(rt_dll_dpath, "libvccl.so"))
        ctypes.CDLL(os.path.join(rt_dll_dpath, "libvacc_core.so"))
    except Exception as e:
        raise RuntimeError("Vastai runtime library not loaded.") from e


register_runtime_libraries()

from ._vacc_libs import _torch_vacc as _C

try:
    _C._init_torch_vacc_module()
except Exception as e:
    raise RuntimeError("Failed to init torch_vacc.") from e


def _apply_patches(monkey_patches):
    def _getattr(module_list, root_module=torch):
        if len(module_list) <= 1:
            return root_module

        if hasattr(root_module, module_list[0]):
            return _getattr(module_list[1:], getattr(root_module, module_list[0]))
        else:
            empty_module_name = f"{root_module.__name__}.{module_list[0]}"
            sys.modules[empty_module_name] = types.ModuleType(empty_module_name)
            setattr(root_module, module_list[0], sys.modules.get(empty_module_name))
            return _getattr(module_list[1:], getattr(root_module, module_list[0]))

    for patch_pair in monkey_patches:
        dest, patch = patch_pair
        dest_module = _getattr(dest.split("."), root_module=torch)
        last_module_level = dest.split(".")[-1]
        if not isinstance(patch, types.ModuleType):
            setattr(dest_module, last_module_level, patch)
            continue

        if not hasattr(dest_module, last_module_level) or not hasattr(patch, "__all__"):
            setattr(dest_module, last_module_level, patch)
            sys.modules[f"{dest_module.__name__}.{last_module_level}"] = patch
            continue

        assert hasattr(patch, "__all__"), "Patch module must have __all__ definition."
        dest_module = getattr(dest_module, last_module_level)
        for attr in patch.__all__:
            setattr(dest_module, attr, getattr(patch, attr))
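
# Illustrative sketch (not part of the original commit): _apply_patches takes
# (dotted-path, replacement) pairs rooted at `torch`. A hypothetical module
# patch with __all__ has only its listed attributes copied over:
#
#   fake = types.ModuleType("fake_amp")      # hypothetical replacement module
#   fake.autocast = my_autocast
#   fake.__all__ = ["autocast"]
#   _apply_patches([["cuda.amp", fake]])     # copies fake.autocast onto torch.cuda.amp
#
# whereas a non-module patch, e.g. _apply_patches([["cuda.is_available", lambda: True]]),
# is assigned directly onto the parent (here torch.cuda).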

import torch_vacc.vacc as vacc

# register "vacc" module/functions to torch
torch._register_device_module("vacc", vacc)
unsupported_dtype = [
    torch.quint8,
    torch.quint4x2,
    torch.quint2x4,
    torch.qint32,
    torch.qint8,
]
torch.utils.generate_methods_for_privateuse1_backend(
    for_tensor=True,
    for_module=True,
    for_storage=True,  # TODO(qingsong): do we support storage?
    unsupported_dtype=unsupported_dtype,
)

# register legacy *DtypeTensor into torch.vacc
_C._initialize_python_bindings()

# init seed generators, vacc default generator
vacc.init()


def is_vccl_available() -> bool:
    return True


torch.distributed.is_vccl_available = is_vccl_available


def set_global_log_level(log_level):
    _C.set_global_log_level(log_level.upper())


def print_vacc_ops():
    _C._print_vacc_ops()


def vacc_ops_list():
    return _C._vacc_ops_list().split(",")


def print_vacc_selective_ops():
    _C._print_vacc_selective_ops()


def _vacc_shutdown():
    _C._vacc_module_shutdown()


atexit.register(_vacc_shutdown)
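
A quick smoke test of the module above (an illustrative sketch, not part of the commit; assumes at least one VACC device is present and the runtime libraries load):

    import torch
    import torch_vacc

    if torch.vacc.is_available():
        x = torch.ones(2, 2).vacc()   # Tensor.vacc() comes from generate_methods_for_privateuse1_backend
        y = x + x                     # dispatched to the vacc backend
        print(y.device)               # e.g. vacc:0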
BIN
torch_vacc/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
torch_vacc/__pycache__/version.cpython-312.pyc
Normal file
Binary file not shown.
BIN
torch_vacc/_vacc_libs/_torch_vacc.cpython-312-x86_64-linux-gnu.so
Executable file
Binary file not shown.
BIN
torch_vacc/_vacc_libs/libfn-log.so
Executable file
Binary file not shown.
BIN
torch_vacc/_vacc_libs/libodsp.so
Executable file
Binary file not shown.
BIN
torch_vacc/_vacc_libs/libvacc_core.so
Executable file
Binary file not shown.
BIN
torch_vacc/_vacc_libs/libvaccrt.so
Executable file
Binary file not shown.
BIN
torch_vacc/_vacc_libs/libvccl.so
Executable file
Binary file not shown.
BIN
torch_vacc/_vacc_libs/libvnnl.so
Executable file
Binary file not shown.
0
torch_vacc/contrib/__init__.py
Normal file
BIN
torch_vacc/contrib/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
torch_vacc/contrib/__pycache__/transfer_to_vacc.cpython-312.pyc
Normal file
Binary file not shown.
398
torch_vacc/contrib/transfer_to_vacc.py
Normal file
@@ -0,0 +1,398 @@
import os
import warnings
import logging as logger
from functools import wraps
import torch
import torch_vacc

'''
try:
    import torchair
except ImportError:
    IS_TORCHAIR_INSTALLED = False
else:
    IS_TORCHAIR_INSTALLED = True
'''

warnings.filterwarnings(action="once")


torch_fn_white_list = [
    "_cudnn_init_dropout_state",
    "_empty_affine_quantized",
    "_empty_per_channel_affine_quantized",
    "_pin_memory",
    "_sparse_coo_tensor_unsafe",
    "_sparse_csr_tensor_unsafe",
    "logspace",
    "randint",
    "hann_window",
    "rand",
    "full_like",
    "ones_like",
    "rand_like",
    "randperm",
    "arange",
    "frombuffer",
    "normal",
    "empty_strided",
    "empty_like",
    "scalar_tensor",
    "tril_indices",
    "bartlett_window",
    "ones",
    "sparse_coo_tensor",
    "randn",
    "kaiser_window",
    "tensor",
    "triu_indices",
    "as_tensor",
    "zeros",
    "randint_like",
    "full",
    "eye",
    "empty",
    "blackman_window",
    "zeros_like",
    "range",
    "sparse_csr_tensor",
    "randn_like",
    "from_file",
    "linspace",
    "hamming_window",
    "empty_quantized",
    "autocast",
    "load",
]
torch_tensor_fn_white_list = [
    "new_empty",
    "new_empty_strided",
    "new_full",
    "new_ones",
    "new_tensor",
    "new_zeros",
    "to",
]
torch_module_fn_white_list = ["to", "to_empty"]
torch_cuda_fn_white_list = [
    "get_device_properties",
    "get_device_name",
    "get_device_capability",
    "list_gpu_processes",
    "set_device",
    "synchronize",
    "mem_get_info",
    "memory_stats",
    "memory_summary",
    "memory_allocated",
    "max_memory_allocated",
    "reset_max_memory_allocated",
    "memory_reserved",
    "max_memory_reserved",
    "reset_max_memory_cached",
    "reset_peak_memory_stats",
    "current_stream",
    "default_stream",
]
torch_profiler_fn_white_list = ["profile"]
torch_distributed_fn_white_list = ["__init__"]
device_kwargs_list = ["device", "device_type", "map_location"]


def wrapper_cuda(fn):
    @wraps(fn)
    def decorated(*args, **kwargs):
        replace_int = fn.__name__ in ["to", "to_empty"]
        if args:
            args_new = list(args)
            args = replace_cuda_to_vacc_in_list(args_new, replace_int)
        if kwargs:
            for device_arg in device_kwargs_list:
                device = kwargs.get(device_arg, None)
                if device is not None:
                    replace_cuda_to_vacc_in_kwargs(kwargs, device_arg, device)
            device_ids = kwargs.get("device_ids", None)
            if type(device_ids) == list:
                device_ids = replace_cuda_to_vacc_in_list(device_ids, replace_int)
        return fn(*args, **kwargs)

    return decorated


def replace_cuda_to_vacc_in_kwargs(kwargs, device_arg, device):
    if type(device) == str and "cuda" in device:
        kwargs[device_arg] = device.replace("cuda", "vacc")
    elif type(device) == torch.device and "cuda" in device.type:
        device_info = (
            "vacc:{}".format(device.index) if device.index is not None else "vacc"
        )
        kwargs[device_arg] = torch.device(device_info)
    elif type(device) == int:
        kwargs[device_arg] = f"vacc:{device}"
    elif type(device) == dict:
        kwargs[device_arg] = replace_cuda_to_vacc_in_dict(device)


def replace_cuda_to_vacc_in_list(args_list, replace_int):
    for idx, arg in enumerate(args_list):
        if isinstance(arg, str) and "cuda" in arg:
            args_list[idx] = arg.replace("cuda", "vacc")
        elif isinstance(arg, torch.device) and "cuda" in arg.type:
            device_info = (
                "vacc:{}".format(arg.index) if arg.index is not None else "vacc"
            )
            args_list[idx] = torch.device(device_info)
        elif replace_int and not isinstance(arg, bool) and isinstance(arg, int):
            args_list[idx] = f"vacc:{arg}"
        elif isinstance(arg, dict):
            args_list[idx] = replace_cuda_to_vacc_in_dict(arg)
    return args_list


def replace_cuda_to_vacc_in_dict(device_dict):
    new_dict = {}
    for key, value in device_dict.items():
        if isinstance(key, str):
            key = key.replace("cuda", "vacc")
        if isinstance(value, str):
            value = value.replace("cuda", "vacc")
        new_dict[key] = value
    return new_dict


def device_wrapper(enter_fn, white_list):
    for fn_name in white_list:
        fn = getattr(enter_fn, fn_name, None)
        if fn:
            setattr(enter_fn, fn_name, wrapper_cuda(fn))
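
# Illustrative example (not part of the original commit): after
# device_wrapper(torch, torch_fn_white_list), CUDA-targeted factory calls are
# transparently redirected, e.g.
#
#   t = torch.zeros(4, device="cuda:0")   # actually allocated on "vacc:0"
#   m = model.to(torch.device("cuda"))    # moved to "vacc"
#
# Bare integer arguments are rewritten only for Tensor.to / Module.to_empty
# (replace_int above), presumably because an int positionally means a device
# only for those functions.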

def wrapper_vccl(fn):
    @wraps(fn)
    def decorated(*args, **kwargs):
        if args:
            args_new = list(args)
            for idx, arg in enumerate(args_new):
                if type(arg) == str and "nccl" in arg:
                    args_new[idx] = arg.replace("nccl", "vccl")
            args = args_new
        if kwargs:
            if type(kwargs.get("backend", None)) == str:
                kwargs["backend"] = "vccl"
        return fn(*args, **kwargs)

    return decorated


def wrapper_data_loader(fn):
    @wraps(fn)
    def decorated(*args, **kwargs):
        if kwargs:
            pin_memory = kwargs.get("pin_memory", False)
            pin_memory_device = kwargs.get("pin_memory_device", None)
            if pin_memory and not pin_memory_device:
                kwargs["pin_memory_device"] = "vacc"
            if (
                pin_memory
                and type(pin_memory_device) == str
                and "cuda" in pin_memory_device
            ):
                kwargs["pin_memory_device"] = pin_memory_device.replace("cuda", "vacc")
        return fn(*args, **kwargs)

    return decorated


def wrapper_get_available_device_type(fn):
    @wraps(fn)
    def decorated(*args, **kwargs):
        try:
            if torch.vacc.is_available():
                return "vacc"
        except Exception:
            msg = "vacc device is not available."
            warnings.warn(msg, RuntimeWarning)
        return fn(*args, **kwargs)

    return decorated


'''
def wrapper_profiler(fn):
    @wraps(fn)
    def decorated(*args, **kwargs):
        if kwargs:
            if (
                "experimental_config" in kwargs.keys()
                and type(kwargs.get("experimental_config"))
                != torch_vacc.profiler._ExperimentalConfig
            ):
                logger.warning(
                    "The experimental_config parameter of torch.profiler.profile has been removed by the tool "
                    "because it can only be used with cuda; please manually modify the code "
                    "and use the experimental_config parameter adapted to vacc."
                )
                del kwargs["experimental_config"]
        return fn(*args, **kwargs)

    return decorated


def wrapper_compile(fn):
    @wraps(fn)
    def decorated(*args, **kwargs):
        vacc_backend = torchair.get_vacc_backend()
        if kwargs:
            backend = kwargs.get("backend", None)
            if (
                not backend
                or not isinstance(backend, functools.partial)
                or not isinstance(backend.func, type(vacc_backend.func))
            ):
                kwargs["backend"] = vacc_backend
            else:
                kwargs["backend"] = vacc_backend
        return fn(*args, **kwargs)

    return decorated
'''


def jit_script(obj, optimize=None, _frames_up=0, _rcb=None, example_inputs=None):
    msg = "torch.jit.script will be disabled by transfer_to_vacc, which currently does not support it."
    warnings.warn(msg, RuntimeWarning)
    return obj


def patch_cuda():
    patchs = [
        ["cuda", torch_vacc.vacc],
        ["cuda.amp", torch_vacc.vacc.amp],
        ["cuda.random", torch_vacc.vacc.random],
        ["cuda.amp.autocast_mode", torch_vacc.vacc.amp.autocast_mode],
        ["cuda.amp.common", torch_vacc.vacc.amp.common],
        ["cuda.amp.grad_scaler", torch_vacc.vacc.amp.grad_scaler],
    ]
    torch_vacc._apply_patches(patchs)


'''
def patch_profiler():
    patchs = [
        ["profiler.profile", torch_vacc.profiler.profile],
        ["profiler.schedule", torch_vacc.profiler.schedule],
        [
            "profiler.tensorboard_trace_handler",
            torch_vacc.profiler.tensorboard_trace_handler,
        ],
        ["profiler.ProfilerAction", torch_vacc.profiler.ProfilerAction],
        ["profiler.ProfilerActivity.CUDA", torch_vacc.profiler.ProfilerActivity.VACC],
        ["profiler.ProfilerActivity.CPU", torch_vacc.profiler.ProfilerActivity.CPU],
    ]
    torch_vacc._apply_patches(patchs)
'''


def warning_fn(msg, rank0=True):
    is_distributed = (
        torch.distributed.is_available()
        and torch.distributed.is_initialized()
        and torch.distributed.get_world_size() > 1
    )
    env_rank = os.getenv("RANK", None)

    if rank0 and is_distributed:
        if torch.distributed.get_rank() == 0:
            warnings.warn(msg, ImportWarning)
    elif rank0 and env_rank:
        if env_rank == "0":
            warnings.warn(msg, ImportWarning)
    else:
        warnings.warn(msg, ImportWarning)


def init():
    warning_fn(
        """
*************************************************************************************************************
torch.Tensor.cuda and torch.nn.Module.cuda are now replaced with torch.Tensor.vacc and torch.nn.Module.vacc.
torch.cuda.DoubleTensor is replaced with torch.vacc.FloatTensor because the double type is not supported yet.
The backend in torch.distributed.init_process_group is now set to vccl.
torch.cuda.* and torch.cuda.amp.* are replaced with torch.vacc.* and torch.vacc.amp.*.
The device parameters have been replaced with vacc in the functions below:
{}
If you notice that any function you use is not included in the list above, feel free to contact the torch-vacc development team.
*************************************************************************************************************
""".format(
            ", ".join(
                ["torch." + i for i in torch_fn_white_list]
                + ["torch.Tensor." + i for i in torch_tensor_fn_white_list]
                + ["torch.nn.Module." + i for i in torch_module_fn_white_list]
            )
        )
    )

    # torch.cuda.*
    patch_cuda()
    device_wrapper(torch.cuda, torch_cuda_fn_white_list)

    # torch.profiler.*
    # TODO(qingsong): profiler not implemented yet
    # patch_profiler()
    # device_wrapper(torch.profiler, torch_profiler_fn_white_list)

    # torch.*
    device_wrapper(torch, torch_fn_white_list)

    # torch.Tensor.*
    device_wrapper(torch.Tensor, torch_tensor_fn_white_list)
    torch.Tensor.cuda = torch.Tensor.vacc
    torch.Tensor.is_cuda = torch.Tensor.is_vacc

    for dtype_tensor in [
        "ByteTensor",
        "CharTensor",
        "DoubleTensor",
        "FloatTensor",
        "IntTensor",
        "LongTensor",
        "ShortTensor",
        "HalfTensor",
        "BoolTensor",
    ]:
        setattr(
            torch.cuda,
            dtype_tensor,
            getattr(torch.vacc, dtype_tensor),
        )
    # TODO(qingsong): do we need this? should we add LongTensor=IntTensor?
    torch.cuda.DoubleTensor = torch.vacc.FloatTensor

    # torch.nn.Module.*
    device_wrapper(torch.nn.Module, torch_module_fn_white_list)
    torch.nn.Module.cuda = torch.nn.Module.vacc

    # torch.distributed.init_process_group
    torch.distributed.init_process_group = wrapper_vccl(
        torch.distributed.init_process_group
    )
    torch.distributed.is_nccl_available = torch.distributed.is_vccl_available

    # torch.nn.parallel.DistributedDataParallel
    device_wrapper(
        torch.nn.parallel.DistributedDataParallel, torch_distributed_fn_white_list
    )
    # torch.utils.data.DataLoader
    torch.utils.data.DataLoader.__init__ = wrapper_data_loader(
        torch.utils.data.DataLoader.__init__
    )

    torch.jit.script = jit_script
    torch._utils._get_available_device_type = wrapper_get_available_device_type(
        torch._utils._get_available_device_type
    )


'''
if IS_TORCHAIR_INSTALLED:
    torch.compile = wrapper_compile(torch.compile)
'''


init()
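
Because the module calls init() at import time, adopting it in an existing CUDA script is a one-line change. A minimal sketch (illustrative, not part of the commit; assumes a VACC device is available):

    import torch
    import torch_vacc
    import torch_vacc.contrib.transfer_to_vacc  # patches torch.cuda.* on import

    x = torch.randn(8, device="cuda")  # transparently allocated on "vacc"
    torch.cuda.synchronize()           # forwarded to the vacc backend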
0
torch_vacc/fused_ops/__init__.py
Normal file
BIN
torch_vacc/fused_ops/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
torch_vacc/fused_ops/__pycache__/rms_norm.cpython-312.pyc
Normal file
Binary file not shown.
BIN
torch_vacc/fused_ops/__pycache__/rope_emb.cpython-312.pyc
Normal file
Binary file not shown.
42
torch_vacc/fused_ops/rms_norm.py
Normal file
@@ -0,0 +1,42 @@
import torch
import torch_vacc
from torch_vacc._vacc_libs import _torch_vacc


class FusedRMSNormFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6):
        output, rsigma, var = torch.ops.vacc.rms_norm_forward(input, weight, eps)
        ctx.save_for_backward(input, weight, rsigma, var)
        ctx.eps = eps
        return output

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        input, weight, rsigma, var = ctx.saved_tensors
        grad_input, grad_weight = _torch_vacc.rms_norm_backward(
            grad_output, input, weight, rsigma, var, ctx.eps
        )
        return grad_input, grad_weight, None


def rms_norm(input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6):
    return FusedRMSNormFunction.apply(input, weight, eps)


class FusedRMSNorm(torch.nn.Module):
    def __init__(self, hidden_size, eps: float = 1e-6):
        super(FusedRMSNorm, self).__init__()
        self.eps = eps

        self.weight = torch.nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)

        output = FusedRMSNormFunction.apply(hidden_states, self.weight, self.eps)

        output = output.to(dtype)
        return output
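
For reference, RMSNorm scales the input by the reciprocal root-mean-square over the hidden dimension. A sketch comparing the fused module against a plain PyTorch implementation (illustrative, not part of the commit; assumes a VACC device and the registered vacc rms_norm kernels):

    import torch
    from torch_vacc.fused_ops.rms_norm import FusedRMSNorm

    def rms_norm_ref(x, w, eps=1e-6):
        # reference: x * rsqrt(mean(x^2) + eps) * w
        var = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(var + eps) * w

    x = torch.randn(2, 16, 64, device="vacc")
    m = FusedRMSNorm(64).to("vacc")
    torch.testing.assert_close(m(x), rms_norm_ref(x, m.weight), rtol=1e-3, atol=1e-3)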
32
torch_vacc/fused_ops/rope_emb.py
Normal file
@@ -0,0 +1,32 @@
import torch
import torch_vacc
from torch_vacc._vacc_libs import _torch_vacc


class FusedRopeEmbFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q: torch.Tensor, k: torch.Tensor, offset: int):
        qemb, kemb = _torch_vacc.rope_forward(q, k, offset)
        ctx.offset = offset

        return qemb, kemb

    @staticmethod
    def backward(ctx, q_out_grad: torch.Tensor, k_out_grad: torch.Tensor):
        grad_input, grad_rope = _torch_vacc.rope_backward(
            q_out_grad, k_out_grad, ctx.offset
        )
        return grad_input, grad_rope, None


def rope_emb(q: torch.Tensor, k: torch.Tensor, offset: int):
    # return FusedRopeEmbFunction.apply(q, k, offset)
    return torch_vacc.vacc.custom_ops.RotaryPosEmbedding(q=q, k=k, offset=offset)


class RopeEmb(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, q: torch.Tensor, k: torch.Tensor, offset: int):
        return rope_emb(q, k, offset)
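
Note that rope_emb currently routes through the custom_ops kernel rather than FusedRopeEmbFunction, so gradients flow only if the custom op defines them. A usage sketch (illustrative; the 4-D layout is an assumption, check the kernel's expected shape):

    import torch
    from torch_vacc.fused_ops.rope_emb import rope_emb

    q = torch.randn(1, 32, 128, 128, device="vacc")  # assumed (batch, heads, seq, head_dim)
    k = torch.randn(1, 32, 128, 128, device="vacc")
    q_rot, k_rot = rope_emb(q, k, offset=0)  # rotary embedding applied from position 0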
62
torch_vacc/testing/__init__.py
Normal file
@@ -0,0 +1,62 @@
from contextlib import contextmanager
import os
import sys
import torch
from torch.testing import make_tensor
from functools import partial, wraps
import torch.testing._internal.common_device_type as cdt
from torch.testing._internal.common_device_type import (
    DeviceTypeTestBase,
    dtypes,
    instantiate_device_type_tests,
    onlyOn,
    onlyPRIVATEUSE1,
    ops,
)

if sys.version_info > (3, 8):
    from torch.testing._internal.common_distributed import (
        MultiProcessTestCase,
        init_multigpu_helper,
        skip_if_lt_x_gpu,
        get_timeout,
        # skip_if_rocm,
        with_dist_debug_levels,
    )
else:
    from torch.testing._internal.common_distributed import (
        MultiProcessTestCase,
        init_multigpu_helper,
        skip_if_lt_x_gpu,
        get_timeout,
        skip_if_rocm,
        with_dist_debug_levels,
    )

from torch.testing._internal.common_utils import (
    TestCase,
    load_tests,
    parametrize,
    run_tests,
    subtest,
    retry_on_connect_failures,
    instantiate_parametrized_tests,
)

onlyVacc = onlyPRIVATEUSE1


class VaccTestBase(DeviceTypeTestBase):
    device_type = "vacc"


if VaccTestBase not in cdt.device_type_test_bases:
    cdt.device_type_test_bases.append(VaccTestBase)


@contextmanager
def freeze_rng_state():
    rng_state = torch.get_rng_state()
    yield
    torch.set_rng_state(rng_state)
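
With VaccTestBase registered, PyTorch's device-type test machinery can stamp out vacc variants of generic tests. A minimal sketch (illustrative, not part of the commit):

    import torch
    from torch_vacc.testing import TestCase, instantiate_device_type_tests, run_tests

    class TestAdd(TestCase):
        def test_add(self, device):
            a = torch.ones(3, device=device)  # device == "vacc" via VaccTestBase
            self.assertEqual((a + a).sum().item(), 6.0)

    instantiate_device_type_tests(TestAdd, globals(), only_for="vacc")

    if __name__ == "__main__":
        run_tests()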
BIN
torch_vacc/testing/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
torch_vacc/testing/__pycache__/summarize_report.cpython-312.pyc
Normal file
Binary file not shown.
103
torch_vacc/testing/summarize_report.py
Normal file
@@ -0,0 +1,103 @@
"""
Tool to summarize unit test XML reports. It summarizes:
* the number of tests, plus failure/error/skipped counts
* the 10 slowest tests

Usage:
    python -m torch_vacc.testing.summarize_report --report report.xml
"""

import argparse
from dataclasses import dataclass
from xml.etree import ElementTree as ET
import sys

from torch_vacc import set_global_log_level


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--report", type=str)
    return parser.parse_args()


def summarize_testsuites(suites):
    # Each value starts out as the callable used to convert that attribute's
    # string value, and is replaced below by the converted sum over all suites.
    summary = {
        "errors": int,
        "failures": int,
        "skipped": int,
        "skips": int,
        "tests": int,
        "time": float,
    }

    attribs = [s.attrib for s in suites]
    for key in summary:
        summary[key] = sum(summary[key](a[key]) for a in attribs if key in a)
    assert not (summary["skipped"] and summary["skips"])
    if summary["skips"]:
        summary["skipped"] = summary["skips"]
    return summary
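
# Worked micro-example (illustrative): with two <testsuite> elements carrying
# tests="3" time="1.5" and tests="2" time="0.5", the loop above evaluates
#   summary["tests"] = int("3") + int("2") = 5
#   summary["time"]  = float("1.5") + float("0.5") = 2.0
# Reports emit either "skipped" or "skips" depending on the producer, which is
# why the two counters are folded together afterwards.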

def format_summary(summary):
    template = "Ran {tests} tests in {time:.3f}s (errors={errors}, failures={failures}, skipped={skipped})"
    msg = template.format(**summary)
    if summary["errors"] > 0 or summary["failures"] > 0:
        msg = "FAILED. " + msg
    return msg


@dataclass
class TestCaseInfo:
    test_class_name: str
    test_name: str
    time: float
    timestamp: str
    success: bool

    def __lt__(self, other):
        return self.time < other.time


def sort_cases_by_time(suites):
    test_cases = [
        TestCaseInfo(
            s.attrib["classname"],
            s.attrib["name"],
            float(s.attrib["time"]),
            s.attrib["timestamp"],
            s.attrib.get("failure") is None,
        )
        for s in suites
    ]
    test_cases.sort(reverse=True)
    return test_cases


def read_report(fpath):
    with open(fpath) as report:
        try:
            report = ET.parse(report)
        except ET.ParseError:
            print(f"{sys.argv[0]}: Cannot parse file {fpath}", file=sys.stderr)
            return
    root = report.getroot()
    suites = [root] if root.tag == "testsuite" else root.findall("testsuite")
    summary = summarize_testsuites(suites)
    summary_msg = format_summary(summary)
    print(summary_msg)

    for suite in suites:
        cases = sort_cases_by_time(suite.findall("testcase"))
        for case in cases[:10]:
            print(case)


def main():
    set_global_log_level("ERROR")
    args = parse_args()
    read_report(args.report)


if __name__ == "__main__":
    main()
184
torch_vacc/vacc/__init__.py
Normal file
@@ -0,0 +1,184 @@
from __future__ import annotations
from typing import Tuple

import torch

from ._device import (
    current_device,
    device,
    device_count,
    get_device_capability,
    get_device_name,
    get_device_properties,
    is_available,
    is_bf16_supported,
    set_device,
    synchronize,
)
from .amp import (
    get_amp_supported_dtype,
    get_autocast_dtype,
    is_autocast_enabled,
    set_autocast_dtype,
    set_autocast_enabled,
)
from .lazy_initialize import _is_in_bad_fork, _lazy_call, _lazy_init
from .memory import (  # caching_allocator_alloc,; caching_allocator_delete,
    empty_cache,
    get_allocator_backend,
    max_memory_allocated,
    max_memory_cached,
    max_memory_reserved,
    mem_get_info,
    memory_allocated,
    memory_cached,
    memory_reserved,
    memory_snapshot,
    memory_stats,
    memory_stats_as_nested_dict,
    memory_summary,
    reset_accumulated_memory_stats,
    reset_max_memory_allocated,
    reset_max_memory_cached,
    reset_peak_memory_stats,
    set_per_process_memory_fraction,
)
from .streams import Event, Stream, current_stream, default_stream, set_stream, stream


def init():
    r"""Initialize PyTorch's VACC state. You may need to call
    this explicitly if you are interacting with PyTorch via
    its C API, as Python bindings for VACC functionality will not
    be available until this initialization takes place. Ordinary users
    should not need this, as all of PyTorch's VACC methods
    automatically initialize VACC state on demand.

    Does nothing if the VACC state is already initialized.
    """
    _lazy_init()


# default_generators is empty until _lazy_init() is called
default_generators: Tuple[torch._C.Generator] = ()  # type: ignore[assignment]

from .custom_ops import *
from .custom_qwen3_ops import *
from .random import *  # noqa: F403

__all__ = [
    "device",
    "is_available",
    "is_bf16_supported",
    "current_device",
    "set_device",
    "device_count",
    "get_device_properties",
    "get_device_name",
    "get_device_capability",
    "synchronize",
    "amp",
    "get_amp_supported_dtype",
    "is_autocast_enabled",
    "set_autocast_enabled",
    "get_autocast_dtype",
    "set_autocast_dtype",
    "_is_in_bad_fork",
    "_lazy_call",
    "get_rng_state",
    "get_rng_state_all",
    "set_rng_state",
    "set_rng_state_all",
    "manual_seed",
    "manual_seed_all",
    "seed",
    "seed_all",
    "initial_seed",
    "set_stream",
    "current_stream",
    "default_generators",
    "default_stream",
    "stream",
    "Stream",
    "Event",
    "mem_get_info",
    "set_per_process_memory_fraction",
    "empty_cache",
    "memory_stats",
    "memory_stats_as_nested_dict",
    "reset_accumulated_memory_stats",
    "reset_peak_memory_stats",
    "reset_max_memory_allocated",
    "reset_max_memory_cached",
    "memory_allocated",
    "max_memory_allocated",
    "memory_reserved",
    "max_memory_reserved",
    "memory_cached",
    "max_memory_cached",
    "memory_snapshot",
    "memory_summary",
    "get_allocator_backend",
    "rms_norm",
    "RotaryPosEmbedding",
    "scaled_dot_product_attention",
    "scaled_dot_product_attention_cp_forward",
    "scaled_dot_product_attention_cp_backward",
    "swiglu",
    "paged_attention",
    "reshape_and_cache_attention",
    "concat_and_cache_attention",
    "w8a8_block_fp8_matmul",
    "moe_expert_token_group_reassign",
    "fused_mlp_mm_fp8",
    "fused_mlp_fp8",
    "fused_moe_preprocess",
    "fused_residual_rmsnorm",
    "parallel_embedding",
    "all_reduce",
    "all_gather",
    "broadcast",
    "fused_mlp_moe_with_rmsnorm",
    "fuse_moe_decode_v2_allreduce",
    "topk_topp",
    "fused_mla",
    "fused_mla_allreduce",
    "fused_mlp_with_rmsnorm",
    "fused_mlp_allreduce",
    "ds3_sampler",
    "sampler_v1",
    "rejection_sampler",
    "rejection_sampler_update_hidden_states",
    "rejection_sampler_v1",
    "fused_matmul_allgather",
    "fused_mla_v2",
    "fused_mla_allreduce_v2",
    "mla_matmul_scale",
    "mla_matmul",
    "fused_mla_prefill_stage0",
    "fused_mla_prefill_stage1",
    "fused_mla_prefill_stage0_allreduce",
    "fuse_moe_prefill_stage0",
    "fuse_mla_mlp_v2_allreduce_decode",
    "fuse_mla_moe_v2_allreduce_decode",
    "fuse_mla_mlp_v2_allreduce_decode_layers",
    "fuse_mla_moe_v2_allreduce_decode_layers",
    "fuse_mla_mlp_v2_allreduce_decode_layers_v2",
    "fuse_mla_moe_v2_allreduce_decode_layers_v2",
    "fuse_mlp_qwen_int4",
    "fuse_mlp_qwen_int4_reduce",
    "w4a8_block_int4_matmul",
    "fuse_atten_qwen3",
    "fuse_atten_qwen2",
    "qwen3_fuse_attention_moe_decode",
    "fuse_mtp_stage0",
    "fuse_mtp_allreduce",
    "roll_out",
    "fused_experts_int4_prefill",
    "fuse_bge_embedding_stage1",
    "l2_norm",
    "fuse_mlp_vision",
    "patch_merger_vision",
    "fuse_atten_vit",
    "apply_penalties",
]
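
Taken together, torch.vacc mirrors the torch.cuda surface. A short sketch of the device/stream/memory API re-exported above (illustrative, not part of the commit; assumes at least one device):

    import torch
    import torch_vacc

    print(torch.vacc.device_count(), torch.vacc.get_device_name(0))
    with torch.vacc.device(0):
        s = torch.vacc.Stream()
        with torch.vacc.stream(s):
            x = torch.ones(1024, device="vacc")
        torch.vacc.synchronize()
    print(torch.vacc.memory_allocated())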
BIN
torch_vacc/vacc/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
torch_vacc/vacc/__pycache__/_device.cpython-312.pyc
Normal file
Binary file not shown.
BIN
torch_vacc/vacc/__pycache__/custom_ops.cpython-312.pyc
Normal file
Binary file not shown.
BIN
torch_vacc/vacc/__pycache__/custom_ops_cpu.cpython-312.pyc
Normal file
Binary file not shown.
BIN
torch_vacc/vacc/__pycache__/custom_qwen3_ops.cpython-312.pyc
Normal file
Binary file not shown.
BIN
torch_vacc/vacc/__pycache__/lazy_initialize.cpython-312.pyc
Normal file
Binary file not shown.
BIN
torch_vacc/vacc/__pycache__/memory.cpython-312.pyc
Normal file
Binary file not shown.
BIN
torch_vacc/vacc/__pycache__/random.cpython-312.pyc
Normal file
Binary file not shown.
BIN
torch_vacc/vacc/__pycache__/streams.cpython-312.pyc
Normal file
Binary file not shown.
106
torch_vacc/vacc/_device.py
Normal file
@@ -0,0 +1,106 @@
# Device information
# replacing `torch.cuda.func` with `torch_vacc.vacc.func`.
# see https://pytorch.org/docs/stable/cuda.html

from typing import Any
import warnings

import torch
import torch_vacc
from torch._utils import _get_device_index
from torch_vacc._vacc_libs import _torch_vacc

from .lazy_initialize import _lazy_init

if hasattr(_torch_vacc, "_exchange_device"):
    _exchange_device = _torch_vacc._exchange_device
else:

    def _exchange_device(device: int) -> int:
        if device < 0:
            return -1
        prev_device = current_device()
        if device != prev_device:
            set_device(device)
        return prev_device


class device(object):
    """Context-manager that changes the selected device.

    Args:
        device (torch.device or int): device index to select. It's a no-op if
            this argument is a negative integer or ``None``.
    """

    def __init__(self, device: Any):
        self.idx = _get_device_index(device, optional=True)
        self.prev_idx = -1

    def __enter__(self):
        self.prev_idx = _exchange_device(self.idx)

    def __exit__(self, *args):
        _exchange_device(self.prev_idx)
        return False
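
# Illustrative usage (not part of the original commit):
#
#   with torch_vacc.vacc.device(1):          # select device 1 inside the block
#       t = torch.empty(8, device="vacc")    # allocated on vacc:1
#   # the previous device is restored on exit via _exchange_device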

def is_available() -> bool:
    r"""Returns whether vacc is available."""
    return device_count() > 0


def is_bf16_supported() -> bool:
    r"""Returns a bool indicating whether the current vacc device supports dtype bfloat16."""
    return True


def current_device() -> int:
    r"""Returns the index of the currently selected vacc device."""
    _lazy_init()
    return _torch_vacc._current_device()


def set_device(device: torch.device):
    device_index = _get_device_index(device, optional=True)
    if device_index >= 0:
        _torch_vacc._set_device(device_index)


def get_device_capability(device=None):
    r"""Query the major and minor compute capability of a device. VACC does not
    have a corresponding concept, so this is not supported.
    """
    _infos = (
        "torch.vacc.get_device_capability isn't implemented! Unlike CUDA's "
        "(major, minor) capability, please do the version check in other ways."
    )
    raise AssertionError(_infos)


def get_device_name(device_name=None):
    device_id = _get_device_index(device_name, optional=True)
    if device_id < 0 or device_id >= device_count():
        raise AssertionError("Invalid device id")
    _lazy_init()
    device_prop = _torch_vacc._vacc_getDeviceProperties(device_id)
    return device_prop.name


def get_device_properties(device_name=None):
    device_id = _get_device_index(device_name, optional=True)
    if device_id < 0 or device_id >= device_count():
        raise AssertionError("Invalid device id")
    _lazy_init()
    return _torch_vacc._vacc_getDeviceProperties(device_id)


def device_count():
    r"""Returns the number of available vacc devices."""
    return _torch_vacc._device_count()


def synchronize(device=None) -> None:
    """Waits for all operations in all streams on a VACC device to complete."""
    _lazy_init()
    with torch_vacc.vacc.device(device):
        return _torch_vacc._device_synchronize()


# Memory management (https://pytorch.org/docs/stable/cuda.html#memory-management)
26
torch_vacc/vacc/amp/__init__.py
Normal file
@@ -0,0 +1,26 @@
from typing import List
import torch
from torch_vacc._vacc_libs import _torch_vacc

from .grad_scaler import OptState, GradScaler
from .autocast_mode import autocast, custom_fwd, custom_bwd


def get_amp_supported_dtype() -> List[torch.dtype]:
    return [torch.float16, torch.bfloat16]


def is_autocast_enabled() -> bool:
    return _torch_vacc.is_autocast_enabled()


def set_autocast_enabled(enable: bool):
    _torch_vacc.set_autocast_enabled(enable)


def get_autocast_dtype() -> torch.dtype:
    return _torch_vacc.get_autocast_dtype()


def set_autocast_dtype(dtype: torch.dtype):
    return _torch_vacc.set_autocast_dtype(dtype)
BIN
torch_vacc/vacc/amp/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
torch_vacc/vacc/amp/__pycache__/autocast_mode.cpython-312.pyc
Normal file
Binary file not shown.
BIN
torch_vacc/vacc/amp/__pycache__/common.cpython-312.pyc
Normal file
Binary file not shown.
BIN
torch_vacc/vacc/amp/__pycache__/grad_scaler.cpython-312.pyc
Normal file
Binary file not shown.
144
torch_vacc/vacc/amp/autocast_mode.py
Normal file
@@ -0,0 +1,144 @@
import collections
import functools
from typing import Any

import torch

try:
    import numpy as np

    HAS_NUMPY = True
except ModuleNotFoundError:
    HAS_NUMPY = False
    np = None  # type: ignore[assignment]

__all__ = ["autocast", "custom_fwd", "custom_bwd"]


class autocast(torch.amp.autocast_mode.autocast):
    r"""See :class:`torch.autocast`.

    ``torch.vacc.amp.autocast(args...)`` is equivalent to ``torch.autocast("vacc", args...)``
    """

    def __init__(
        self,
        enabled: bool = True,
        dtype: torch.dtype = torch.float16,
        cache_enabled: bool = True,
    ):
        if torch._jit_internal.is_scripting():
            self._enabled = enabled
            self.device = "vacc"
            self.fast_dtype = dtype
            return
        super().__init__(
            "vacc", enabled=enabled, dtype=dtype, cache_enabled=cache_enabled
        )

    def __enter__(self):
        if torch._jit_internal.is_scripting():
            return self
        return super().__enter__()

    # TODO: discuss a unified TorchScript-friendly API for autocast
    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any):  # type: ignore[override]
        if torch._jit_internal.is_scripting():
            return
        return super().__exit__(exc_type, exc_val, exc_tb)

    def __call__(self, func):
        if torch._jit_internal.is_scripting():
            return func
        return super().__call__(func)
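
# Illustrative usage (not part of the original commit):
#
#   import torch_vacc
#   with torch_vacc.vacc.amp.autocast(dtype=torch.float16):
#       out = model(inp)                     # eligible ops run in float16 on vacc
#   loss = loss_fn(out.float(), target)      # compute the loss outside autocast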

# Casts Tensors and containers of Tensors. Special-cases passthroughs for strings and np.ndarrays, which
# may be falsely detected as "Iterables."
def _cast(value, dtype):
    if isinstance(value, torch.Tensor):
        is_eligible = (
            value.is_floating_point()
            and value.is_vacc
            and (value.dtype is not torch.float64)
        )
        return value.to(dtype) if is_eligible else value
    elif isinstance(value, (str, bytes)):
        return value
    elif HAS_NUMPY and isinstance(value, np.ndarray):
        return value
    elif isinstance(value, collections.abc.Mapping):
        return {_cast(k, dtype): _cast(v, dtype) for k, v in value.items()}
    elif isinstance(value, collections.abc.Iterable):
        iterable = (_cast(v, dtype) for v in value)
        if isinstance(value, (list, tuple)):
            return type(value)(iterable)
        else:
            return iterable
    else:
        return value


# custom_fwd is a decorator that may or may not be used with arguments, following
# https://github.com/dabeaz/python-cookbook/tree/master/src/9/defining_a_decorator_that_takes_an_optional_argument.
# this works:
#     @custom_fwd
#     def forward(...):
# this also works:
#     @custom_fwd(cast_inputs=torch.float)
#     def forward(...):
def custom_fwd(fwd=None, *, cast_inputs=None):
    """
    Create a helper decorator for ``forward`` methods of custom autograd functions.

    Autograd functions are subclasses of :class:`torch.autograd.Function`.
    See the :ref:`example page<amp-custom-examples>` for more detail.

    Args:
        cast_inputs (:class:`torch.dtype` or None, optional, default=None): If not ``None``,
            when ``forward`` runs in an autocast-enabled region, casts incoming
            floating-point VACC Tensors to the target dtype (non-floating-point Tensors are not affected),
            then executes ``forward`` with autocast disabled.
            If ``None``, ``forward``'s internal ops execute with the current autocast state.

    .. note::
        If the decorated ``forward`` is called outside an autocast-enabled region,
        :func:`custom_fwd<custom_fwd>` is a no-op and ``cast_inputs`` has no effect.
    """
    if fwd is None:
        return functools.partial(custom_fwd, cast_inputs=cast_inputs)

    @functools.wraps(fwd)
    def decorate_fwd(*args, **kwargs):
        args[0]._dtype = torch.get_autocast_gpu_dtype()
        if cast_inputs is None:
            args[0]._fwd_used_autocast = torch.is_autocast_enabled()
            return fwd(*args, **kwargs)
        else:
            autocast_context = torch.is_autocast_enabled()
            args[0]._fwd_used_autocast = False
            if autocast_context:
                with autocast(enabled=False):
                    return fwd(*_cast(args, cast_inputs), **_cast(kwargs, cast_inputs))
            else:
                return fwd(*args, **kwargs)

    return decorate_fwd


# Autograd ensures incoming gradients are the same type as forward outputs. Allowing a separate
# cast_inputs argument on custom_bwd is unnecessary and could cause errors if it doesn't match
# cast_inputs supplied to custom_fwd.
def custom_bwd(bwd):
    """Create a helper decorator for backward methods of custom autograd functions.

    Autograd functions are subclasses of :class:`torch.autograd.Function`.
    Ensures that ``backward`` executes with the same autocast state as ``forward``.
    See the :ref:`example page<amp-custom-examples>` for more detail.
    """

    @functools.wraps(bwd)
    def decorate_bwd(*args, **kwargs):
        with autocast(enabled=args[0]._fwd_used_autocast, dtype=args[0]._dtype):
            return bwd(*args, **kwargs)

    return decorate_bwd
7
torch_vacc/vacc/amp/common.py
Normal file
@@ -0,0 +1,7 @@
import torch

__all__ = ["amp_definitely_not_available"]


def amp_definitely_not_available():
    return not torch.vacc.is_available()
667
torch_vacc/vacc/amp/grad_scaler.py
Normal file
@@ -0,0 +1,667 @@
|
|||||||
|
import inspect
|
||||||
|
import warnings
|
||||||
|
from collections import abc, defaultdict
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any, cast, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
from .common import amp_definitely_not_available
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["OptState", "GradScaler"]
|
||||||
|
|
||||||
|
|
||||||
|
class _MultiDeviceReplicator:
|
||||||
|
"""
|
||||||
|
Lazily serves copies of a tensor to requested devices. Copies are cached per-device.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, master_tensor: torch.Tensor) -> None:
|
||||||
|
assert (
|
||||||
|
master_tensor.is_cuda
|
||||||
|
or master_tensor.device.type == "xla"
|
||||||
|
or master_tensor.device.type == "vacc"
|
||||||
|
)
|
||||||
|
self.master = master_tensor
|
||||||
|
self._per_device_tensors: Dict[torch.device, torch.Tensor] = {}
|
||||||
|
|
||||||
|
def get(self, device) -> torch.Tensor:
|
||||||
|
retval = self._per_device_tensors.get(device, None)
|
||||||
|
if retval is None:
|
||||||
|
retval = self.master.to(device=device, non_blocking=True, copy=True)
|
||||||
|
self._per_device_tensors[device] = retval
|
||||||
|
return retval
|
||||||
|
|
||||||
|
|
||||||
|
# Defines default_factory for GradScaler's _per_optimizer_states defaultdict,
|
||||||
|
# as well as associated "enum" values. Prefers defining these at top level because
|
||||||
|
# - Lambdas can't be pickled, so we don't want to supply a lambda as the factory.
|
||||||
|
# - Defining READY, UNSCALED, STEPPED and _refresh_per_optimizer_state within GradScaler
|
||||||
|
# causes a circular reference, which we'd rather avoid.
|
||||||
|
class OptState(Enum):
|
||||||
|
READY = 0
|
||||||
|
UNSCALED = 1
|
||||||
|
STEPPED = 2
|
||||||
|
|
||||||
|
|
||||||
|
def _refresh_per_optimizer_state():
|
||||||
|
return {"stage": OptState.READY, "found_inf_per_device": {}}
|
||||||
|
|
||||||
|
|
||||||
|
class GradScaler:
|
||||||
|
_scale: Optional[torch.Tensor]
|
||||||
|
_grows_tracker: Optional[torch.Tensor]
|
||||||
|
_per_optimizer_states: Dict[int, Dict[str, Any]]
|
||||||
|
"""
|
||||||
|
An instance ``scaler`` of :class:`GradScaler` helps perform the steps of gradient scaling
|
||||||
|
conveniently.
|
||||||
|
|
||||||
|
* ``scaler.scale(loss)`` multiplies a given loss by ``scaler``'s current scale factor.
|
||||||
|
* ``scaler.step(optimizer)`` safely unscales gradients and calls ``optimizer.step()``.
|
||||||
|
* ``scaler.update()`` updates ``scaler``'s scale factor.
|
||||||
|
|
||||||
|
Example::
|
||||||
|
|
||||||
|
# Creates a GradScaler once at the beginning of training.
|
||||||
|
scaler = GradScaler()
|
||||||
|
|
||||||
|
for epoch in epochs:
|
||||||
|
for input, target in data:
|
||||||
|
optimizer.zero_grad()
|
||||||
|
output = model(input)
|
||||||
|
loss = loss_fn(output, target)
|
||||||
|
|
||||||
|
# Scales loss. Calls backward() on scaled loss to create scaled gradients.
|
||||||
|
scaler.scale(loss).backward()
|
||||||
|
|
||||||
|
# scaler.step() first unscales gradients of the optimizer's params.
|
||||||
|
# If gradients don't contain infs/NaNs, optimizer.step() is then called,
|
||||||
|
# otherwise, optimizer.step() is skipped.
|
||||||
|
scaler.step(optimizer)
|
||||||
|
|
||||||
|
# Updates the scale for next iteration.
|
||||||
|
scaler.update()
|
||||||
|
|
||||||
|
See the :ref:`Automatic Mixed Precision examples<amp-examples>` for usage
|
||||||
|
(along with autocasting) in more complex cases like gradient clipping, gradient accumulation, gradient penalty,
|
||||||
|
and multiple losses/optimizers.
|
||||||
|
|
||||||
|
``scaler`` dynamically estimates the scale factor each iteration. To minimize gradient underflow,
|
||||||
|
a large scale factor should be used. However, ``float16`` values can "overflow" (become inf or NaN) if
|
||||||
|
the scale factor is too large. Therefore, the optimal scale factor is the largest factor that can be used
|
||||||
|
without incurring inf or NaN gradient values.
|
||||||
|
``scaler`` approximates the optimal scale factor over time by checking the gradients for infs and NaNs during every
|
||||||
|
``scaler.step(optimizer)`` (or optional separate ``scaler.unscale_(optimizer)``, see :meth:`unscale_`).
|
||||||
|
|
||||||
|
* If infs/NaNs are found, ``scaler.step(optimizer)`` skips the underlying ``optimizer.step()`` (so the params
|
||||||
|
themselves remain uncorrupted) and ``update()`` multiplies the scale by ``backoff_factor``.
|
||||||
|
|
||||||
|
* If no infs/NaNs are found, ``scaler.step(optimizer)`` runs the underlying ``optimizer.step()`` as usual.
|
||||||
|
If ``growth_interval`` unskipped iterations occur consecutively, ``update()`` multiplies the scale by
|
||||||
|
``growth_factor``.
|
||||||
|
|
||||||
|
The scale factor often causes infs/NaNs to appear in gradients for the first few iterations as its
|
||||||
|
value calibrates. ``scaler.step`` will skip the underlying ``optimizer.step()`` for these
|
||||||
|
iterations. After that, step skipping should occur rarely (once every few hundred or thousand iterations).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
init_scale (float, optional, default=2.**16): Initial scale factor.
|
||||||
|
growth_factor (float, optional, default=2.0): Factor by which the scale is multiplied during
|
||||||
|
:meth:`update` if no inf/NaN gradients occur for ``growth_interval`` consecutive iterations.
|
||||||
|
backoff_factor (float, optional, default=0.5): Factor by which the scale is multiplied during
|
||||||
|
:meth:`update` if inf/NaN gradients occur in an iteration.
|
||||||
|
growth_interval (int, optional, default=2000): Number of consecutive iterations without inf/NaN gradients
|
||||||
|
that must occur for the scale to be multiplied by ``growth_factor``.
|
||||||
|
enabled (bool, optional): If ``False``, disables gradient scaling. :meth:`step` simply
|
||||||
|
invokes the underlying ``optimizer.step()``, and other methods become no-ops.
|
||||||
|
Default: ``True``
|
||||||
|
"""
|
||||||
|
|
||||||
|
    def __init__(
        self,
        init_scale=2.0**16,
        growth_factor=2.0,
        backoff_factor=0.5,
        growth_interval=2000,
        enabled=True,
    ):
        if enabled and amp_definitely_not_available():
            warnings.warn(
                "torch.vacc.amp.GradScaler is enabled, but VACC device is not available. Disabling."
            )
            self._enabled = False
        else:
            self._enabled = enabled

        if self._enabled:
            assert growth_factor > 1.0, "The growth factor must be > 1.0."
            assert backoff_factor < 1.0, "The backoff factor must be < 1.0."

            self._init_scale = init_scale
            # self._scale will be lazily initialized during the first call to scale()
            self._scale = None
            self._growth_factor = growth_factor
            self._backoff_factor = backoff_factor
            self._growth_interval = growth_interval
            self._init_growth_tracker = 0
            # self._growth_tracker will be lazily initialized during the first call to scale()
            self._growth_tracker = None
            self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)

    def _check_scale_growth_tracker(
        self, funcname
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        fix = "This may indicate your script did not use scaler.scale(loss or outputs) earlier in the iteration."
        assert self._scale is not None, (
            f"Attempted {funcname} but _scale is None. " + fix
        )
        assert self._growth_tracker is not None, (
            f"Attempted {funcname} but _growth_tracker is None. " + fix
        )
        return (self._scale, self._growth_tracker)

    def _lazy_init_scale_growth_tracker(self, dev):
        assert self._growth_tracker is None, "_growth_tracker initialized before _scale"
        self._scale = torch.full((), self._init_scale, dtype=torch.float32, device=dev)
        self._growth_tracker = torch.full(
            (), self._init_growth_tracker, dtype=torch.int32, device=dev
        )

    def scale(self, outputs):
        """
        Multiplies ('scales') a tensor or list of tensors by the scale factor.

        Returns scaled outputs. If this instance of :class:`GradScaler` is not enabled, outputs are returned
        unmodified.

        Args:
            outputs (Tensor or iterable of Tensors): Outputs to scale.
        """
        if not self._enabled:
            return outputs

        # Short-circuit for the common case.
        if isinstance(outputs, torch.Tensor):
            assert (
                outputs.is_cuda
                or outputs.device.type == "xla"
                or outputs.device.type == "vacc"
            )
            if self._scale is None:
                self._lazy_init_scale_growth_tracker(outputs.device)
            assert self._scale is not None
            return outputs * self._scale.to(device=outputs.device, non_blocking=True)

        # Invoke the more complex machinery only if we're treating multiple outputs.
        stash: List[_MultiDeviceReplicator] = []  # holds a reference that can be overwritten by apply_scale

        def apply_scale(val):
            if isinstance(val, torch.Tensor):
                assert (
                    val.is_cuda or val.device.type == "xla" or val.device.type == "vacc"
                )
                if len(stash) == 0:
                    if self._scale is None:
                        self._lazy_init_scale_growth_tracker(val.device)
                    assert self._scale is not None
                    stash.append(_MultiDeviceReplicator(self._scale))
                return val * stash[0].get(val.device)
            elif isinstance(val, abc.Iterable):
                iterable = map(apply_scale, val)
                if isinstance(val, (list, tuple)):
                    return type(val)(iterable)
                else:
                    return iterable
            else:
                raise ValueError("outputs must be a Tensor or an iterable of Tensors")

        return apply_scale(outputs)

    def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16):
        per_device_inv_scale = _MultiDeviceReplicator(inv_scale)
        per_device_found_inf = _MultiDeviceReplicator(found_inf)

        # To set up _amp_foreach_non_finite_check_and_unscale_, split grads by device and dtype.
        # There could be hundreds of grads, so we'd like to iterate through them just once.
        # However, we don't know their devices or dtypes in advance.

        # https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict
        # Google says mypy struggles with defaultdicts type annotations.
        per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list))  # type: ignore[var-annotated]
        with torch.no_grad():
            for group in optimizer.param_groups:
                for param in group["params"]:
                    if param.grad is None:
                        continue
                    if (not allow_fp16) and param.grad.dtype == torch.float16:
                        raise ValueError("Attempting to unscale FP16 gradients.")
                    if param.grad.is_sparse:
                        # is_coalesced() == False means the sparse grad has values with duplicate indices.
                        # coalesce() deduplicates indices and adds all values that have the same index.
                        # For scaled fp16 values, there's a good chance coalescing will cause overflow,
                        # so we should check the coalesced _values().
                        if param.grad.dtype is torch.float16:
                            param.grad = param.grad.coalesce()
                        to_unscale = param.grad._values()
                    else:
                        to_unscale = param.grad

                    # TODO: is there a way to split by device and dtype without appending in the inner loop?
                    per_device_and_dtype_grads[to_unscale.device][
                        to_unscale.dtype
                    ].append(to_unscale)

            for device, per_dtype_grads in per_device_and_dtype_grads.items():
                for grads in per_dtype_grads.values():
                    torch._amp_foreach_non_finite_check_and_unscale_(
                        grads,
                        per_device_found_inf.get(device),
                        per_device_inv_scale.get(device),
                    )

        return per_device_found_inf._per_device_tensors

    def unscale_(self, optimizer):
        """
        Divides ("unscales") the optimizer's gradient tensors by the scale factor.

        :meth:`unscale_` is optional, serving cases where you need to
        :ref:`modify or inspect gradients<working-with-unscaled-gradients>`
        between the backward pass(es) and :meth:`step`.
        If :meth:`unscale_` is not called explicitly, gradients will be unscaled automatically during :meth:`step`.

        Simple example, using :meth:`unscale_` to enable clipping of unscaled gradients::

            ...
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
            scaler.step(optimizer)
            scaler.update()

        Args:
            optimizer (torch.optim.Optimizer): Optimizer that owns the gradients to be unscaled.

        .. note::
            :meth:`unscale_` does not incur a CPU-GPU sync.

        .. warning::
            :meth:`unscale_` should only be called once per optimizer per :meth:`step` call,
            and only after all gradients for that optimizer's assigned parameters have been accumulated.
            Calling :meth:`unscale_` twice for a given optimizer between each :meth:`step` triggers a RuntimeError.

        .. warning::
            :meth:`unscale_` may unscale sparse gradients out of place, replacing the ``.grad`` attribute.
        """
        if not self._enabled:
            return

        self._check_scale_growth_tracker("unscale_")

        optimizer_state = self._per_optimizer_states[id(optimizer)]

        if optimizer_state["stage"] is OptState.UNSCALED:
            raise RuntimeError(
                "unscale_() has already been called on this optimizer since the last update()."
            )
        elif optimizer_state["stage"] is OptState.STEPPED:
            raise RuntimeError("unscale_() is being called after step().")

        # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64.
        assert self._scale is not None
        inv_scale = self._scale.double().reciprocal().float()
        found_inf = torch.full((), 0.0, dtype=torch.float32, device=self._scale.device)

        optimizer_state["found_inf_per_device"] = self._unscale_grads_(
            optimizer, inv_scale, found_inf, False
        )
        optimizer_state["stage"] = OptState.UNSCALED

    def _maybe_opt_step(self, optimizer, optimizer_state, *args, **kwargs):
        retval = None
        if not sum(v.item() for v in optimizer_state["found_inf_per_device"].values()):
            retval = optimizer.step(*args, **kwargs)
        return retval

    def step(self, optimizer, *args, **kwargs):
        """
        :meth:`step` carries out the following two operations:

        1. Internally invokes ``unscale_(optimizer)`` (unless :meth:`unscale_` was explicitly called for ``optimizer``
           earlier in the iteration). As part of the :meth:`unscale_`, gradients are checked for infs/NaNs.
        2. If no inf/NaN gradients are found, invokes ``optimizer.step()`` using the unscaled
           gradients. Otherwise, ``optimizer.step()`` is skipped to avoid corrupting the params.

        ``*args`` and ``**kwargs`` are forwarded to ``optimizer.step()``.

        Returns the return value of ``optimizer.step(*args, **kwargs)``.

        Args:
            optimizer (torch.optim.Optimizer): Optimizer that applies the gradients.
            args: Any arguments.
            kwargs: Any keyword arguments.

        .. warning::
            Closure use is not currently supported.
        """
        if not self._enabled:
            return optimizer.step(*args, **kwargs)

        if "closure" in kwargs:
            raise RuntimeError(
                "Closure use is not currently supported if GradScaler is enabled."
            )

        self._check_scale_growth_tracker("step")

        optimizer_state = self._per_optimizer_states[id(optimizer)]

        if optimizer_state["stage"] is OptState.STEPPED:
            raise RuntimeError(
                "step() has already been called since the last update()."
            )

        retval = None

        if (
            hasattr(optimizer, "_step_supports_amp_scaling")
            and optimizer._step_supports_amp_scaling
        ):
            # This optimizer has customized scale-handling logic, so we can call optimizer.step() directly.
            # The contract with custom optimizers is that their step() should accept an additional,
            # optional grad_scaler kwarg. We append self to the kwargs so the custom optimizer has full information:
            # it can query its own state, invoke unscale_ on itself, etc.
            # The contract above is being deprecated to avoid introducing a `grad_scaler: GradScaler` argument
            # to `Optimizer.step`. The new behavior instead attaches two Tensor attributes, `grad_scale`
            # and `found_inf`, to the passed optimizer so that the optimizer can use them
            # to skip the parameter updates or unscale gradients before updating parameters in
            # the fused kernel, e.g. `FusedAdamMathFunctor`.
            # In this behavior, `GradScaler._check_inf_per_device` is called if the state is `OptState.READY`,
            # while the method is expected to be called on the user's side, i.e. by their optimizers.
            kwargs_ = kwargs
            has_grad_scaler_kwarg = (
                "grad_scaler" in inspect.signature(optimizer.step).parameters
            )
            if has_grad_scaler_kwarg:
                warnings.warn(
                    "GradScaler is going to stop passing itself as a keyword argument to the passed "
                    "optimizer. In the near future GradScaler will register `grad_scale: Tensor` and "
                    "`found_inf: Tensor` on the passed optimizer and let the optimizer use them directly.",
                    FutureWarning,
                )
                kwargs_.update({"grad_scaler": self})
            else:
                if optimizer_state["stage"] is OptState.READY:
                    self._check_inf_per_device(optimizer)
                scaler = self._get_scale_async()
                found_inf = cast(
                    torch.Tensor,
                    sum(
                        [
                            t.to(scaler.device, non_blocking=True)
                            for t in optimizer_state["found_inf_per_device"].values()
                        ]
                    ),
                )
                optimizer.grad_scale = (
                    None if optimizer_state["stage"] == OptState.UNSCALED else scaler
                )
                optimizer.found_inf = found_inf
            retval = optimizer.step(*args, **kwargs_)
            optimizer_state["stage"] = OptState.STEPPED
            if not has_grad_scaler_kwarg:
                del optimizer.grad_scale
                del optimizer.found_inf
            return retval

        if optimizer_state["stage"] is OptState.READY:
            self.unscale_(optimizer)

        assert (
            len(optimizer_state["found_inf_per_device"]) > 0
        ), "No inf checks were recorded for this optimizer."

        retval = self._maybe_opt_step(optimizer, optimizer_state, *args, **kwargs)

        optimizer_state["stage"] = OptState.STEPPED

        return retval

    def update(self, new_scale=None):
        """
        Updates the scale factor.

        If any optimizer steps were skipped, the scale is multiplied by ``backoff_factor``
        to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively,
        the scale is multiplied by ``growth_factor`` to increase it.

        Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not
        used directly; it's used to fill GradScaler's internal scale tensor. So if
        ``new_scale`` was a tensor, later in-place changes to that tensor will not further
        affect the scale GradScaler uses internally.)

        Args:
            new_scale (float or :class:`torch.vacc.FloatTensor`, optional, default=None): New scale factor.

        .. warning::
            :meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has
            been invoked for all optimizers used this iteration.

        .. warning::
            For performance reasons, we do not check the scale factor value to avoid synchronizations,
            so the scale factor is not guaranteed to be above 1. If the scale falls below 1 and/or
            you are seeing NaNs in your gradients or loss, something is likely wrong. For example,
            bf16-pretrained models are often incompatible with AMP/fp16 due to differing dynamic ranges.
        """
        if not self._enabled:
            return

        _scale, _growth_tracker = self._check_scale_growth_tracker("update")

        if new_scale is not None:
            # Accept a new user-defined scale.
            if isinstance(new_scale, float):
                self._scale.fill_(new_scale)  # type: ignore[union-attr]
            else:
                reason = "new_scale should be a float or a 1-element torch.vacc.FloatTensor with requires_grad=False."
                # assert isinstance(new_scale, torch.vacc.FloatTensor), reason  # type: ignore[attr-defined]
                assert (
                    isinstance(new_scale, torch.Tensor)
                    and new_scale.dtype == torch.float32
                ), reason
                assert new_scale.numel() == 1, reason
                assert new_scale.requires_grad is False, reason
                self._scale.copy_(new_scale)  # type: ignore[union-attr]
        else:
            # Consume shared inf/nan data collected from optimizers to update the scale.
            # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous.
            found_infs = [
                found_inf.to(device=_scale.device, non_blocking=True)
                for state in self._per_optimizer_states.values()
                for found_inf in state["found_inf_per_device"].values()
            ]

            assert len(found_infs) > 0, "No inf checks were recorded prior to update."

            found_inf_combined = found_infs[0]
            if len(found_infs) > 1:
                for i in range(1, len(found_infs)):
                    found_inf_combined += found_infs[i]

            torch._amp_update_scale_(
                _scale,
                _growth_tracker,
                found_inf_combined,
                self._growth_factor,
                self._backoff_factor,
                self._growth_interval,
            )

        # To prepare for next iteration, clear the data collected from optimizers this iteration.
        self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)

    def _get_scale_async(self):
        return self._scale

    def get_scale(self):
        """
        Returns a Python float containing the current scale, or 1.0 if scaling is disabled.

        .. warning::
            :meth:`get_scale` incurs a CPU-GPU sync.
        """
        if self._enabled:
            return (
                self._init_scale
                if self._scale is None
                else self._get_scale_async().item()
            )
        else:
            return 1.0

    def get_growth_factor(self):
        r"""
        Returns a Python float containing the scale growth factor.
        """
        return self._growth_factor

    def set_growth_factor(self, new_factor):
        r"""
        Args:
            new_factor (float): Value to use as the new scale growth factor.
        """
        self._growth_factor = new_factor

    def get_backoff_factor(self):
        r"""
        Returns a Python float containing the scale backoff factor.
        """
        return self._backoff_factor

    def set_backoff_factor(self, new_factor):
        r"""
        Args:
            new_factor (float): Value to use as the new scale backoff factor.
        """
        self._backoff_factor = new_factor

    def get_growth_interval(self):
        r"""
        Returns a Python int containing the growth interval.
        """
        return self._growth_interval

    def set_growth_interval(self, new_interval):
        r"""
        Args:
            new_interval (int): Value to use as the new growth interval.
        """
        self._growth_interval = new_interval

    def _get_growth_tracker(self):
        if self._enabled:
            return (
                self._init_growth_tracker
                if self._growth_tracker is None
                else self._growth_tracker.item()
            )
        else:
            return 0

    def is_enabled(self):
        r"""
        Returns a bool indicating whether this instance is enabled.
        """
        return self._enabled

    def state_dict(self):
        r"""
        Returns the state of the scaler as a :class:`dict`. It contains five entries:

        * ``"scale"`` - a Python float containing the current scale
        * ``"growth_factor"`` - a Python float containing the current growth factor
        * ``"backoff_factor"`` - a Python float containing the current backoff factor
        * ``"growth_interval"`` - a Python int containing the current growth interval
        * ``"_growth_tracker"`` - a Python int containing the number of recent consecutive unskipped steps.

        If this instance is not enabled, returns an empty dict.

        .. note::
            If you wish to checkpoint the scaler's state after a particular iteration, :meth:`state_dict`
            should be called after :meth:`update`.
        """
        return (
            {
                "scale": self.get_scale(),
                "growth_factor": self._growth_factor,
                "backoff_factor": self._backoff_factor,
                "growth_interval": self._growth_interval,
                "_growth_tracker": self._get_growth_tracker(),
            }
            if self._enabled
            else {}
        )

    def load_state_dict(self, state_dict):
        r"""
        Loads the scaler state. If this instance is disabled, :meth:`load_state_dict` is a no-op.

        Args:
            state_dict(dict): scaler state. Should be an object returned from a call to :meth:`state_dict`.
        """
        if not self._enabled:
            return

        if len(state_dict) == 0:
            raise RuntimeError(
                "The source state dict is empty, possibly because it was saved "
                "from a disabled instance of GradScaler."
            )

        self._init_scale = state_dict["scale"]
        if self._scale is not None:
            self._scale.fill_(state_dict["scale"])
        self._growth_factor = state_dict["growth_factor"]
        self._backoff_factor = state_dict["backoff_factor"]
        self._growth_interval = state_dict["growth_interval"]
        self._init_growth_tracker = state_dict["_growth_tracker"]
        if self._growth_tracker is not None:
            self._growth_tracker.fill_(state_dict["_growth_tracker"])

    def __getstate__(self):
        state = self.__dict__.copy()
        if self._enabled:
            assert len(self._per_optimizer_states) == 0, (
                "A GradScaler instance may only be pickled at the beginning "
                "of an iteration, or at the end after scaler.update()."
            )
            # Pickling _scale and _growth_tracker Tensors directly triggers
            # "warnings.warn("pickle support for Storage will be removed in 1.5...",
            # so instead, we set the unpickled instance up to reinitialize them lazily.
            state["_init_scale"] = self.get_scale()
            state["_init_growth_tracker"] = self._get_growth_tracker()
            state["_scale"] = None
            state["_growth_tracker"] = None
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)

    def _check_inf_per_device(self, optimizer):
        _scale, _ = self._check_scale_growth_tracker("_check_inf_per_device")

        dummy_inv_scale = torch.full((), 1.0, dtype=torch.float32, device=_scale.device)
        found_inf = torch.full((), 0.0, dtype=torch.float32, device=_scale.device)

        self._per_optimizer_states[id(optimizer)][
            "found_inf_per_device"
        ] = self._unscale_grads_(optimizer, dummy_inv_scale, found_inf, True)

        return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"]

    def _found_inf_per_device(self, optimizer):
        return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"]
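
# Illustrative sketch (not part of the original module): checkpointing the scaler
# alongside model and optimizer state. `model`, `optimizer`, and the path are
# placeholders; `state_dict()`/`load_state_dict()` are the methods defined above.
# Per the docstrings, save after `update()` so the captured scale reflects the
# finished iteration.
def _example_checkpoint_scaler(model, optimizer, scaler, path="ckpt.pt"):
    # Save: the scaler state is a plain dict of floats/ints, safe to torch.save().
    torch.save(
        {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "scaler": scaler.state_dict(),
        },
        path,
    )

    # Restore: load_state_dict() refreshes both the lazily-created scale tensor
    # (if it already exists) and the init values used on the first scale() call.
    ckpt = torch.load(path)
    model.load_state_dict(ckpt["model"])
    optimizer.load_state_dict(ckpt["optimizer"])
    scaler.load_state_dict(ckpt["scaler"])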

2819  torch_vacc/vacc/custom_ops.py  Normal file
File diff suppressed because it is too large
306  torch_vacc/vacc/custom_ops_cpu.py  Normal file
@@ -0,0 +1,306 @@
from typing import Tuple, Union, Optional, List

import torch
from torch.nn import functional as F


def split_last_two_dims_into_blocks(x, h, w):
    # Split the trailing (H, W) dims of `x` into non-overlapping (h, w) blocks,
    # producing shape leading_dims + (H // h, W // w, h, w).
    leading_dims = x.shape[:-2]
    H, W = x.shape[-2:]
    assert (
        H % h == 0 and W % w == 0
    ), "The last two dimensions must be divisible by block size."
    x_reshaped = x.view(-1, 1, H, W)

    unfolded = F.unfold(x_reshaped, kernel_size=(h, w), stride=(h, w))
    unfolded = unfolded.view(-1, 1, h, w, H // h, W // w)
    unfolded = unfolded.permute(0, 1, 4, 5, 2, 3)
    final_shape = leading_dims + (H // h, W // w, h, w)
    # reshape (not view): the permute above makes the tensor non-contiguous.
    result = unfolded.reshape(final_shape)

    return result


def merge_blocks_to_original_layout(x, h, w):
    # Inverse of split_last_two_dims_into_blocks: fold the trailing
    # (H // h, W // w, h, w) block dims back into a contiguous (H, W) layout.
    leading_dims = x.shape[:-4]
    H_div_h, W_div_w, h, w = x.shape[-4:]
    H = H_div_h * h
    W = W_div_w * w

    x_reshaped = x.view(-1, 1, H_div_h, W_div_w, h, w)
    x_reshaped = x_reshaped.permute(0, 1, 4, 5, 2, 3)
    # reshape (not view): the permute above makes the tensor non-contiguous.
    x_reshaped = x_reshaped.reshape(-1, h * w, H_div_h * W_div_w)
    folded = F.fold(x_reshaped, output_size=(H, W), kernel_size=(h, w), stride=(h, w))

    final_shape = leading_dims + (H, W)
    result = folded.view(final_shape)

    return result

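
# A minimal round-trip sketch (not part of the original module) showing what the
# two helpers above do: a 4x4 matrix split into 2x2 blocks and merged back.
def _example_block_roundtrip():
    x = torch.arange(16, dtype=torch.float32).reshape(4, 4)
    blocks = split_last_two_dims_into_blocks(x, 2, 2)
    # blocks.shape == (2, 2, 2, 2): blocks[i][j] is the (i, j)-th 2x2 tile,
    # e.g. blocks[0][0] == [[0., 1.], [4., 5.]].
    restored = merge_blocks_to_original_layout(blocks, 2, 2)
    assert torch.equal(restored, x)
    return blocks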

def w8a8_block_fp8_matmul(
    input: torch.Tensor,
    weight: torch.Tensor,
    input_scale: Optional[torch.Tensor],
    weight_scale: Optional[torch.Tensor],
    block_size: List[int],
    is_linear_weight: bool = False,
    output_opt: Optional[torch.Tensor] = None,
    **kwargs
):
    # CPU reference path: dequantize the block-quantized weight with its
    # per-block scales, then run a plain matmul in the input dtype.
    b0, b1 = block_size
    dim0, dim1 = weight.shape
    dim0pad, dim1pad = 0, 0

    # Pad the weight so both dims are divisible by the block size.
    if dim0 % b0 != 0:
        dim0pad = b0 - dim0 % b0
    if dim1 % b1 != 0:
        dim1pad = b1 - dim1 % b1

    dim0_origin, dim1_origin = dim0, dim1
    dim0 += dim0pad
    dim1 += dim1pad

    bs0, bs1 = dim0 // b0, dim1 // b1
    weight_dequant = torch.nn.functional.pad(weight, (0, dim1pad, 0, dim0pad), value=0)
    # Regroup into (bs0, bs1, b0*b1) blocks and multiply each block by its scale.
    weight_dequant = weight_dequant.cpu().view(bs0, b0, bs1, b1).permute(
        0, 2, 1, 3
    ).reshape(bs0, bs1, -1).float().to(input.device) * weight_scale.unsqueeze(-1)
    weight_dequant = (
        weight_dequant.reshape(bs0, bs1, b0, b1)
        .permute(0, 2, 1, 3)
        .reshape(dim0, dim1)
        .to(input.dtype)
    )
    # Drop the padding before the matmul.
    weight_dequant = weight_dequant[:dim0_origin, :dim1_origin]
    output = torch.matmul(
        input, weight_dequant.T if is_linear_weight else weight_dequant
    )
    if output_opt is not None:
        output = output_opt.copy_(output)
    return output


def w8a8_block_fp8_linear(
    input: torch.Tensor,
    weight: torch.Tensor,
    input_scale: Optional[torch.Tensor],
    weight_scale: Optional[torch.Tensor],
    block_size: List[int],
    **kwargs
):
    assert input_scale is None, "w8a8_block_fp8_matmul only supports quantized weights for now"
    return w8a8_block_fp8_matmul(
        input, weight, None, weight_scale, block_size, is_linear_weight=True
    )

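
# A small numeric sketch (not part of the original module): a 4x4 "quantized"
# weight with 2x2 blocks and per-block scales, pushed through the dequantize-
# then-matmul reference path above. Values are illustrative; on device the
# weight would be FP8, but the CPU path only relies on the scale layout.
def _example_w8a8_block_linear():
    inp = torch.randn(3, 4)      # (tokens, in_features)
    weight_q = torch.ones(4, 4)  # (out_features, in_features)
    # One scale per 2x2 block: block (i, j) is multiplied by weight_scale[i, j].
    weight_scale = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
    out = w8a8_block_fp8_linear(inp, weight_q, None, weight_scale, block_size=[2, 2])
    assert out.shape == (3, 4)
    return out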

def fused_experts(
    hidden_states: torch.Tensor,
    w13_weight: torch.Tensor,
    w2_weight: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    use_fp8_w8a8: bool = True,
    w13_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a13_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    block_shape: Optional[List[int]] = None,
    decode_with_batch: bool = False,
) -> torch.Tensor:
    batch_seq_all, hidden_dims = hidden_states.shape
    intermediate_size = w2_weight.shape[-1]
    num_experts = w13_weight.shape[0]
    w13_weight = w13_weight.contiguous()
    w2_weight = w2_weight.contiguous()
    w13_scale = w13_scale.contiguous()
    w2_scale = w2_scale.contiguous()

    final_hidden_states = torch.zeros_like(hidden_states)

    w1_scale = w13_scale

    _, bs0_w13, bs1_w13 = w1_scale.shape
    _, bs0_w2, bs1_w2 = w2_scale.shape

    sel_experts = topk_ids.shape[1]
    if hidden_states.shape[0] == 1:
        # Single-token decode path: visit only the selected experts.
        for slot in range(sel_experts):
            expert_idx = topk_ids[0][slot]
            expert_w1 = w13_weight[expert_idx].contiguous()
            expert_w2 = w2_weight[expert_idx].contiguous()
            ws1 = w1_scale[expert_idx].unsqueeze(2).contiguous()
            ws2 = w2_scale[expert_idx].unsqueeze(2).contiguous()

            dim0, dim1 = expert_w1.shape
            b0, b1 = dim0 // bs0_w13, dim1 // bs1_w13
            # assert (bs0, bs1, 1) == ws1.shape, f"bs0, bs1, 1 is {bs0},{bs1}, 1, <==> {ws1.shape}"
            expert_w1 = (
                expert_w1
                .view(bs0_w13, b0, bs1_w13, b1)
                .permute(0, 2, 1, 3)
                .reshape(bs0_w13, bs1_w13, -1)
                .float()
                .to(hidden_states.device)
                * ws1
            )
            expert_w1 = (
                expert_w1.reshape(bs0_w13, bs1_w13, b0, b1)
                .permute(0, 2, 1, 3)
                .reshape(dim0, dim1)
                .to(hidden_states.dtype)
            )

            dim0, dim1 = expert_w2.shape
            b0, b1 = dim0 // bs0_w2, dim1 // bs1_w2
            # assert (bs0, bs1, 1) == ws2.shape
            expert_w2 = (
                expert_w2
                .view(bs0_w2, b0, bs1_w2, b1)
                .permute(0, 2, 1, 3)
                .reshape(bs0_w2, bs1_w2, -1)
                .float()
                .to(hidden_states.device)
                * ws2
            )
            expert_w2 = (
                expert_w2.reshape(bs0_w2, bs1_w2, b0, b1)
                .permute(0, 2, 1, 3)
                .reshape(dim0, dim1)
                .to(hidden_states.dtype)
            )

            expert_weights = topk_weights[0][slot].to(hidden_states.dtype)

            x = hidden_states
            x = F.linear(x, expert_w1)
            gate = F.silu(x[:, :intermediate_size])
            x = x[:, intermediate_size:] * gate
            x = F.linear(x, expert_w2)

            current_hidden_states = x * expert_weights
            current_hidden_states = current_hidden_states.to(x.dtype)
            final_hidden_states += current_hidden_states
    else:
        # Batch path: visit every expert and gather the tokens routed to it.
        for expert_idx in range(num_experts):
            # topk_ids [tokens, experts] => sample: [10, 8]
            # expert_mask [tokens, experts] => sample: [10, 8]
            expert_mask = topk_ids == expert_idx

            idx = torch.where(expert_mask)[0]
            if idx.numel() == 0:
                continue

            expert_w1 = w13_weight[expert_idx].contiguous()
            expert_w2 = w2_weight[expert_idx].contiguous()
            ws1 = w1_scale[expert_idx].unsqueeze(2).contiguous()
            ws2 = w2_scale[expert_idx].unsqueeze(2).contiguous()

            dim0, dim1 = expert_w1.shape
            b0, b1 = dim0 // bs0_w13, dim1 // bs1_w13
            # assert (bs0, bs1, 1) == ws1.shape, f"bs0, bs1, 1 is {bs0},{bs1}, 1, <==> {ws1.shape}"
            expert_w1 = (
                expert_w1
                .view(bs0_w13, b0, bs1_w13, b1)
                .permute(0, 2, 1, 3)
                .reshape(bs0_w13, bs1_w13, -1)
                .float()
                .to(hidden_states.device)
                * ws1
            )
            expert_w1 = (
                expert_w1.reshape(bs0_w13, bs1_w13, b0, b1)
                .permute(0, 2, 1, 3)
                .reshape(dim0, dim1)
                .to(hidden_states.dtype)
            )

            dim0, dim1 = expert_w2.shape
            b0, b1 = dim0 // bs0_w2, dim1 // bs1_w2
            # assert (bs0, bs1, 1) == ws2.shape
            expert_w2 = (
                expert_w2
                .view(bs0_w2, b0, bs1_w2, b1)
                .permute(0, 2, 1, 3)
                .reshape(bs0_w2, bs1_w2, -1)
                .float()
                .to(hidden_states.device)
                * ws2
            )
            expert_w2 = (
                expert_w2.reshape(bs0_w2, bs1_w2, b0, b1)
                .permute(0, 2, 1, 3)
                .reshape(dim0, dim1)
                .to(hidden_states.dtype)
            )

            # [seq, experts]
            expert_weights = (
                topk_weights.masked_select(expert_mask)
                .unsqueeze(1)
                .to(hidden_states.dtype)
            )

            x = hidden_states[idx]
            x = F.linear(x, expert_w1)
            gate = F.silu(x[:, :intermediate_size])
            x = x[:, intermediate_size:] * gate
            x = F.linear(x, expert_w2)

            current_hidden_states = x * expert_weights
            current_hidden_states = current_hidden_states.to(x.dtype)
            # final_hidden_states[idx] += current_hidden_states
            final_hidden_states.index_add_(0, idx, current_hidden_states)

    final_hidden_states = final_hidden_states.reshape(batch_seq_all, hidden_dims)
    return final_hidden_states

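
# A minimal sketch (not part of the original module) of the routing pattern the
# batch path above uses: a boolean expert mask selects the rows routed to one
# expert, and index_add_ scatters the expert output back, so a token routed to
# the same expert twice accumulates both contributions.
def _example_moe_routing():
    topk_ids = torch.tensor([[0, 1], [1, 1], [0, 2]])  # 3 tokens, top-2 experts
    outputs = torch.zeros(3, 4)
    for expert_idx in range(3):
        mask = topk_ids == expert_idx  # (tokens, k) boolean mask
        idx = torch.where(mask)[0]     # token rows routed to this expert
        if idx.numel() == 0:
            continue
        expert_out = torch.ones(idx.numel(), 4) * expert_idx  # stand-in compute
        outputs.index_add_(0, idx, expert_out)
    # Token 1 picked expert 1 twice, so it accumulated 1 + 1 = 2 in each column.
    assert torch.equal(outputs[1], torch.full((4,), 2.0))
    return outputs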

def fused_mlp_mm_fp8(
    hidden_states: torch.Tensor,
    w13_weight: torch.Tensor,
    w2_weight: torch.Tensor,
    use_fp8_w8a8: bool = True,
    w13_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a13_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    block_shape_w13: Optional[List[int]] = None,
    block_shape_w2: Optional[List[int]] = None,
):
    def fp8_to_fp16(inp, scale, block_size, trans_type):
        # Dequantize a block-quantized tensor: cast, split into blocks,
        # multiply each block by its scale, and merge back.
        inp_t = inp.to(trans_type)
        inp_t = split_last_two_dims_into_blocks(inp_t, block_size[0], block_size[1])
        assert scale.size(0) == inp_t.size(-4)
        assert scale.size(1) == inp_t.size(-3)
        inp_t = inp_t * scale.unsqueeze(-1).unsqueeze(-1)
        inp_t = merge_blocks_to_original_layout(inp_t, block_size[0], block_size[1])
        return inp_t.to(trans_type)

    w13_weight = w13_weight.contiguous()
    w2_weight = w2_weight.contiguous()
    w13_scale = w13_scale.contiguous()
    w2_scale = w2_scale.contiguous()
    w13_fp = fp8_to_fp16(w13_weight, w13_scale, block_shape_w13, hidden_states.dtype)
    w2_fp = fp8_to_fp16(w2_weight, w2_scale, block_shape_w2, hidden_states.dtype)
    # Gated MLP: x @ W13 -> split into (gate, up) -> SiLU(gate) * up -> @ W2.
    out = hidden_states @ w13_fp
    out = torch.chunk(out, 2, dim=-1)
    out = F.silu(out[0]) * out[1]
    out = out @ w2_fp

    return out

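
# A tiny runnable sketch (not part of the original module): a gated-MLP pass
# through fused_mlp_mm_fp8 with fp32 stand-ins for the FP8 weights. Shapes:
# hidden (tokens, d), w13 (d, 2*inter) and w2 (inter, d), each with one scale
# per 2x2 block of the weight.
def _example_fused_mlp():
    hidden = torch.randn(2, 4)
    w13 = torch.randn(4, 8)
    w2 = torch.randn(4, 4)
    w13_scale = torch.ones(2, 4)  # (4/2, 8/2) blocks
    w2_scale = torch.ones(2, 2)   # (4/2, 4/2) blocks
    out = fused_mlp_mm_fp8(
        hidden, w13, w2,
        w13_scale=w13_scale, w2_scale=w2_scale,
        block_shape_w13=[2, 2], block_shape_w2=[2, 2],
    )
    assert out.shape == (2, 4)
    return out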

def mla_matmul_scale(input: torch.Tensor, weight: torch.Tensor, scale: float):
    output = torch.matmul(input, weight)
    output = output * scale
    output = output.to(input.dtype)
    return output


def mla_matmul(input: torch.Tensor, weight: torch.Tensor):
    output = torch.matmul(input, weight)
    output = output.to(input.dtype)
    return output

146  torch_vacc/vacc/custom_qwen3_ops.py  Normal file
@@ -0,0 +1,146 @@
from typing import List, Optional, Tuple, Union

import torch
from torch import Generator
from torch_vacc._vacc_libs import _torch_vacc


def fuse_moe_prefill_stage0_qwen(
    hidden_states,
    rms_residual,
    rms_weight,
    gate_weight,
    rms_hidden_state_opt: Optional[torch.Tensor] = None,
    zero_moe_hidden_state_opt: Optional[torch.Tensor] = None,
    topk_ids_opt: Optional[torch.Tensor] = None,
    topk_weight_opt: Optional[torch.Tensor] = None,
):
    return _torch_vacc.fuse_moe_prefill_stage0_qwen(
        hidden_states,
        rms_residual,
        rms_weight,
        gate_weight,
        rms_hidden_state_opt,
        zero_moe_hidden_state_opt,
        topk_ids_opt,
        topk_weight_opt,
    )


def fuse_moe_decode_qwen(
    hidden_states,
    rms_residual,
    rms_weight,
    moe_weight_13,
    moe_weight_2,
    moe_weight_13_dequat,
    moe_weight_2_dequant,
    gate_weight,
    block_size_13,
    block_size_2,
    world_size: int,
    rank: int,
    group_id: int,
    dev_info: Optional[List[int]] = None,
    output: Optional[torch.Tensor] = None,
):
    # Guard against both the default None and an explicitly empty list;
    # the original bare `len(dev_info)` check would raise a TypeError on None.
    if dev_info is None or len(dev_info) == 0:
        dev_info = [i | (i << 16) for i in range(world_size)]
    return _torch_vacc.fuse_moe_decode_qwen(
        hidden_states,
        rms_residual,
        rms_weight,
        moe_weight_13,
        moe_weight_2,
        moe_weight_13_dequat,
        moe_weight_2_dequant,
        gate_weight,
        block_size_13,
        block_size_2,
        world_size,
        rank,
        group_id,
        dev_info,
        output,
    )

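
# A quick sketch (not part of the original module) of the default dev_info
# packing above: each 32-bit entry carries the same index in its low and high
# 16 bits (`i | (i << 16)`). What the runtime reads from each half is not
# documented here, so treat the field meanings as an assumption.
def _example_default_dev_info(world_size=4):
    dev_info = [i | (i << 16) for i in range(world_size)]
    lows = [v & 0xFFFF for v in dev_info]           # low 16 bits -> i
    highs = [(v >> 16) & 0xFFFF for v in dev_info]  # high 16 bits -> i
    assert lows == highs == list(range(world_size))
    return dev_info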

def rot_pos_emb_qwenvl(grid_thw: List[List[int]],
                       hidden_size: int,
                       head_num: int,
                       spatial_merge_size: int,
                       dtype: torch.dtype,
                       device: Union[int, str, torch.device] = "vacc"):
    # assert out_tensor.device.type == "vacc", f"please target vacc device, now is {out_tensor.device}"
    if isinstance(device, str):
        device = torch.device(device)
    elif isinstance(device, int):
        device = torch.device("vacc", device)

    # Flatten [[t, h, w], ...] into a single [t0, h0, w0, t1, h1, w1, ...] list.
    thws = []
    for i in grid_thw:
        thws.extend(i)
    return _torch_vacc.rot_pos_emb_qwenvl(thws,
                                          hidden_size,
                                          head_num,
                                          spatial_merge_size,
                                          dtype,
                                          device)


def fast_pos_embed_interpolate_qwenvl(weight: torch.Tensor,
                                      grid_thw: List[List[int]],
                                      num_grid_per_side: int,
                                      spatial_merge_size: int,
                                      hidden_dim: int):
    # Flatten [[t, h, w], ...] the same way as rot_pos_emb_qwenvl above.
    thws = []
    for i in grid_thw:
        thws.extend(i)
    return _torch_vacc.fast_pos_embed_interpolate_qwenvl(weight,
                                                         thws,
                                                         num_grid_per_side,
                                                         spatial_merge_size,
                                                         hidden_dim)


# The image preprocessing op is the same for qwen2_vl and qwen3_vl.
def qwen2vl_img_preprocess(
    image: "torch.Tensor",
    do_resize: bool,
    min_pixels: int,
    max_pixels: int,
    do_rescale: bool,
    rescale_factor: float,
    do_normalize: bool,
    resized_height: int,
    resized_width: int,
    interpolation: int,  # Optional["F.InterpolationMode"]
    patch_size: int,
    temporal_patch_size: int,
    merge_size: int,
    image_mean0: float,
    image_mean1: float,
    image_mean2: float,
    image_std0: float,
    image_std1: float,
    image_std2: float,
    # batch_size: int = 1,
    # grid_t: int = 1,
    # channel: int = 3,
    # output: Optional[torch.Tensor] = None
):
    assert image.device.type == "vacc", f"please target vacc device, now is {image.device}"
    return _torch_vacc.qwen2vl_img_preprocess(
        image,
        do_resize,
        min_pixels,
        max_pixels,
        do_rescale,
        rescale_factor,
        do_normalize,
        resized_height,
        resized_width,
        interpolation,
        patch_size,
        temporal_patch_size,
        merge_size,
        image_mean0, image_mean1, image_mean2,
        image_std0, image_std1, image_std2
    )

107  torch_vacc/vacc/lazy_initialize.py  Normal file
@@ -0,0 +1,107 @@
import threading
import traceback
from typing import List

from .._vacc_libs import _torch_vacc

_initialized = False
_tls = threading.local()
_initialization_lock = threading.Lock()
_queued_calls = []

_is_in_bad_fork = getattr(_torch_vacc, "_vacc_in_bad_fork", lambda: False)


def is_initialized():
    r"""Returns whether PyTorch's VACC state has been initialized."""
    return _initialized and not _is_in_bad_fork()


class _LazySeedTracker:
    # Since seeding is memory-less, only track the latest seed.
    # Note: `manual_seed_all` followed by `manual_seed` overwrites
    # the seed on the current device. We track the order of the **latest**
    # calls between these two APIs.
    def __init__(self):
        self.manual_seed_all_cb = None
        self.manual_seed_cb = None
        self.call_order = []

    def queue_seed_all(self, cb, traceback):
        self.manual_seed_all_cb = (cb, traceback)
        # update seed_all to be latest
        self.call_order = [self.manual_seed_cb, self.manual_seed_all_cb]

    def queue_seed(self, cb, traceback):
        self.manual_seed_cb = (cb, traceback)
        # update seed to be latest
        self.call_order = [self.manual_seed_all_cb, self.manual_seed_cb]

    def get_calls(self) -> List:
        return self.call_order


_lazy_seed_tracker = _LazySeedTracker()


def _lazy_call(callable, **kwargs):
    if is_initialized():
        callable()
    else:
        # TODO(torch_deploy): this accesses linecache, which attempts to read the
        # file system to get traceback info. Patch linecache or do something
        # else here if this ends up being important.
        global _lazy_seed_tracker
        if kwargs.get("seed_all", False):
            _lazy_seed_tracker.queue_seed_all(callable, traceback.format_stack())
        elif kwargs.get("seed", False):
            _lazy_seed_tracker.queue_seed(callable, traceback.format_stack())
        else:
            # Don't store the actual traceback to avoid memory cycle
            _queued_calls.append((callable, traceback.format_stack()))


class DeferredVaccCallError(Exception):
    pass


def _lazy_init():
    """Initialize VACC device state."""

    global _initialized, _queued_calls
    if _initialized or hasattr(_tls, "is_initializing"):
        return
    with _initialization_lock:
        if _initialized:
            return

        # It is important to prevent other threads from entering _lazy_init
        # immediately, while we are still guaranteed to have the GIL, because some
        # of the C calls we make below will release the GIL
        if _is_in_bad_fork():
            raise RuntimeError(
                "Cannot re-initialize VACC in forked subprocess. To use VACC with "
                "multiprocessing, you must use the 'spawn' start method"
            )

        _torch_vacc._vacc_init()

        _tls.is_initializing = True

        for calls in _lazy_seed_tracker.get_calls():
            if calls:
                _queued_calls.append(calls)

        try:
            for queued_call, orig_traceback in _queued_calls:
                try:
                    queued_call()
                except Exception as e:
                    msg = (
                        f"VACC call failed lazily at initialization with error: {str(e)}\n\n"
                        f"VACC call was originally invoked at:\n\n{''.join(orig_traceback)}"
                    )
                    raise DeferredVaccCallError(msg) from e
        finally:
            delattr(_tls, "is_initializing")
        _initialized = True
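
# A minimal sketch (not part of the original module) of the deferral contract
# above: callables queued via _lazy_call before the device is initialized run
# exactly once inside _lazy_init, in queue order. Assumes the VACC runtime has
# not been initialized yet and is actually available when _lazy_init fires.
def _example_lazy_call_ordering():
    fired = []
    _lazy_call(lambda: fired.append("a"))
    _lazy_call(lambda: fired.append("b"))
    # Nothing runs yet; both callables (with their capture-site tracebacks)
    # sit in _queued_calls until something triggers _lazy_init().
    assert fired == []
    _lazy_init()  # requires a working VACC runtime; raises without one
    assert fired == ["a", "b"]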

535  torch_vacc/vacc/memory.py  Normal file
@@ -0,0 +1,535 @@
import collections
import contextlib
import warnings

from typing import Tuple

import torch
from torch._utils import _get_device_index

import torch_vacc

from torch_vacc._vacc_libs import _torch_vacc
from .lazy_initialize import is_initialized, _lazy_init

__all__ = [
    "mem_get_info",
    # "caching_allocator_alloc",
    # "caching_allocator_delete",
    "set_per_process_memory_fraction",
    "empty_cache",
    "memory_stats",
    "memory_stats_as_nested_dict",
    "reset_accumulated_memory_stats",
    "reset_peak_memory_stats",
    "reset_max_memory_allocated",
    "reset_max_memory_cached",
    "memory_allocated",
    "max_memory_allocated",
    "memory_reserved",
    "max_memory_reserved",
    "memory_cached",
    "max_memory_cached",
    "memory_snapshot",
    "memory_summary",
    "get_allocator_backend",
]


@contextlib.contextmanager
def _free_mutex():
    _torch_vacc._vacc_lock_mutex()
    try:
        yield
    finally:
        _torch_vacc._vacc_unlock_mutex()


# def caching_allocator_alloc(size, device=None, stream=None):
#     r"""Performs a memory allocation using the VACC memory allocator.

#     Memory is allocated for a given device and a stream; this
#     function is intended to be used for interoperability with other
#     frameworks. Allocated memory is released through
#     :func:`~torch_vacc.vacc.caching_allocator_delete`.

#     Arguments:
#         size (int): number of bytes to be allocated.
#         device (torch.device or int, optional): selected device. If it is
#             ``None`` the default VACC device is used.
#         stream (torch_vacc.vacc.Stream or int, optional): selected stream. If ``None``,
#             the default stream for the selected device is used.
#     """
#     if device is None:
#         device = torch_vacc.vacc.current_device()
#     device = _get_device_index(device)
#     if stream is None:
#         stream = torch_vacc.vacc.current_stream(device)
#     if isinstance(stream, torch_vacc.vacc.streams.Stream):
#         stream = stream.vacc_stream
#     if not isinstance(stream, int):
#         raise TypeError(
#             "Invalid type for stream argument, must be "
#             "`torch_vacc.vacc.Stream` or `int` representing a pointer "
#             "to an existing stream"
#         )
#     with torch_vacc.vacc.device(device):
#         return _torch_vacc._vacc_vaccCachingAllocator_raw_alloc(size, stream)


# def caching_allocator_delete(mem_ptr):
#     r"""Deletes memory allocated using the VACC memory allocator.

#     Memory allocated with :func:`~torch_vacc.vacc.caching_allocator_alloc`
#     is freed here. The associated device and stream are tracked inside
#     the allocator.

#     Arguments:
#         mem_ptr (int): memory address to be freed by the allocator.
#     """
#     _torch_vacc._vacc_vaccCachingAllocator_raw_delete(mem_ptr)


def set_per_process_memory_fraction(fraction, device=None) -> None:
    r"""Set memory fraction for a process.
    The fraction is used to limit the caching allocator's allocations on a VACC device.
    The allowed value equals the total visible memory multiplied by the fraction.
    Trying to allocate more than the allowed value in a process will raise an out of
    memory error in the allocator.
    Arguments:
        fraction(float): Range: 0~1. Allowed memory equals total_memory * fraction.
        device (torch.device or int, optional): selected device. If it is
            ``None`` the default VACC device is used.
    .. note::
        In general, the total available free memory is less than the total capacity.
    """
    _lazy_init()
    if device is None:
        device = torch_vacc.vacc.current_device()
    device = _get_device_index(device)
    if not isinstance(fraction, float):
        raise TypeError("Invalid type for fraction argument, must be `float`")
    if fraction < 0 or fraction > 1:
        raise ValueError(
            "Invalid fraction value: {}. Allowed range: 0~1".format(fraction)
        )

    _torch_vacc._vacc_setMemoryFraction(fraction, device)

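
# A short usage sketch (not part of the original module): cap this process at
# half of device 0's memory before building any large tensors. 0.5 is an
# arbitrary illustrative value.
#
#   import torch_vacc
#   torch_vacc.vacc.set_per_process_memory_fraction(0.5, 0)
#   # Allocations beyond total_memory * 0.5 now raise an out-of-memory error.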

def empty_cache():
    r"""Releases all unoccupied cached memory currently held by the caching
    allocator so that it can be used by other VACC applications and becomes visible in
    `nvidia-smi`.

    .. note::
        :func:`~torch_vacc.vacc.empty_cache` doesn't increase the amount of VACC
        memory available for PyTorch. However, it may help reduce fragmentation
        of VACC memory in certain cases.
    """
    if is_initialized():
        _torch_vacc._vacc_emptyCache()


def memory_stats(device=None):
    """Returns a dictionary of VACC memory allocator statistics for a
    given device.
    The return value of this function is a dictionary of statistics, each of
    which is a non-negative integer.
    Core statistics:
    - ``"allocated.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
      number of allocation requests received by the memory allocator.
    - ``"allocated_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
      amount of allocated memory.
    - ``"segment.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
      number of reserved segments from ``vaccMalloc()``.
    - ``"reserved_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
      amount of reserved memory.
    - ``"active.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
      number of active memory blocks.
    - ``"active_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
      amount of active memory.
    - ``"inactive_split.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
      number of inactive, non-releasable memory blocks.
    - ``"inactive_split_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
      amount of inactive, non-releasable memory.
    For these core statistics, values are broken down as follows.
    Pool type:
    - ``all``: combined statistics across all memory pools.
    - ``large_pool``: statistics for the large allocation pool
      (as of October 2019, for size >= 1MB allocations).
    - ``small_pool``: statistics for the small allocation pool
      (as of October 2019, for size < 1MB allocations).
    Metric type:
    - ``current``: current value of this metric.
    - ``peak``: maximum value of this metric.
    - ``allocated``: historical total increase in this metric.
    - ``freed``: historical total decrease in this metric.
    In addition to the core statistics, we also provide some simple event
    counters:
    - ``"num_alloc_retries"``: number of failed ``vaccMalloc`` calls that
      result in a cache flush and retry.
    - ``"num_ooms"``: number of out-of-memory errors thrown.
    The caching allocator can be configured via ENV to not split blocks larger than a
    defined size (see the Memory Management section of the CUDA semantics documentation).
    This helps avoid memory fragmentation but may have a performance
    penalty. Additional outputs to assist with tuning and evaluating impact:
    - ``"max_split_size"``: blocks above this size will not be split.
    - ``"oversize_allocations.{current,peak,allocated,freed}"``:
      number of over-size allocation requests received by the memory allocator.
    - ``"oversize_segments.{current,peak,allocated,freed}"``:
      number of over-size reserved segments from ``vaccMalloc()``.
    Arguments:
        device (torch.device or int, optional): selected device. Returns
            statistics for the current device, given by :func:`~torch_vacc.vacc.current_device`,
            if :attr:`device` is ``None`` (default).
    """
    result = []

    def _recurse_add_to_result(prefix, obj):
        # Flatten the nested stats dict into dotted "a.b.c" keys.
        if isinstance(obj, dict):
            if len(prefix) > 0:
                prefix += "."
            for k, v in obj.items():
                _recurse_add_to_result(prefix + k, v)
        else:
            result.append((prefix, obj))

    stats = memory_stats_as_nested_dict(device=device)
    _recurse_add_to_result("", stats)
    result.sort()

    return collections.OrderedDict(result)

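
# A small sketch (not part of the original module) of reading the flattened
# stats above; key names follow the docstring's "<stat>.<pool>.<metric>" scheme.
#
#   stats = torch_vacc.vacc.memory_stats()          # flattened OrderedDict
#   print(stats["allocated_bytes.all.current"])     # bytes currently allocated
#   print(stats["allocated_bytes.all.peak"])        # high-water mark
#   nested = torch_vacc.vacc.memory_stats_as_nested_dict()
#   assert nested["allocated_bytes"]["all"]["peak"] == stats["allocated_bytes.all.peak"]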

def memory_stats_as_nested_dict(device=None):
    r"""Returns the result of :func:`~torch_vacc.vacc.memory_stats` as a nested dictionary."""
    device = _get_device_index(device, optional=True)
    return _torch_vacc._vacc_memoryStats(device)


def reset_accumulated_memory_stats(device=None):
    r"""Resets the "accumulated" (historical) stats tracked by the VACC memory allocator.

    See :func:`~torch_vacc.vacc.memory_stats` for details. Accumulated stats correspond to
    the `"allocated"` and `"freed"` keys in each individual stat dict, as well as
    `"num_alloc_retries"` and `"num_ooms"`.

    Arguments:
        device (torch.device or int, optional): selected device. Returns
            statistic for the current device, given by :func:`~torch_vacc.vacc.current_device`,
            if :attr:`device` is ``None`` (default).
    """
    device = _get_device_index(device, optional=True)
    return _torch_vacc._vacc_resetAccumulatedMemoryStats(device)


def reset_peak_memory_stats(device=None):
    r"""Resets the "peak" stats tracked by the VACC memory allocator.

    See :func:`~torch_vacc.vacc.memory_stats` for details. Peak stats correspond to the
    `"peak"` key in each individual stat dict.

    Arguments:
        device (torch.device or int, optional): selected device. Returns
            statistic for the current device, given by :func:`~torch_vacc.vacc.current_device`,
            if :attr:`device` is ``None`` (default).
    """
    device = _get_device_index(device, optional=True)
    return _torch_vacc._vacc_resetPeakMemoryStats(device)


def reset_max_memory_allocated(device=None):
    r"""Resets the starting point in tracking maximum VACC memory occupied by
    tensors for a given device.

    See :func:`~torch_vacc.vacc.max_memory_allocated` for details.

    Arguments:
        device (torch.device or int, optional): selected device. Returns
            statistic for the current device, given by :func:`~torch_vacc.vacc.current_device`,
            if :attr:`device` is ``None`` (default).

    .. warning::
        This function now calls :func:`~torch_vacc.vacc.reset_peak_memory_stats`, which resets
        /all/ peak memory stats.
    """
    # warnings.warn(
    #     "torch_vacc.vacc.reset_max_memory_allocated now calls torch_vacc.vacc.reset_peak_memory_stats, "
    #     "which resets /all/ peak memory stats.",
    #     DeprecationWarning,
    # )
    return reset_peak_memory_stats(device=device)


def reset_max_memory_cached(device=None):
    r"""Resets the starting point in tracking maximum VACC memory managed by the
    caching allocator for a given device.

    See :func:`~torch_vacc.vacc.max_memory_cached` for details.

    Arguments:
        device (torch.device or int, optional): selected device. Returns
            statistic for the current device, given by :func:`~torch_vacc.vacc.current_device`,
            if :attr:`device` is ``None`` (default).

    .. warning::
        This function now calls :func:`~torch_vacc.vacc.reset_peak_memory_stats`, which resets
        /all/ peak memory stats.
    """
    # warnings.warn(
    #     "torch_vacc.vacc.reset_max_memory_cached now calls torch_vacc.vacc.reset_peak_memory_stats, "
    #     "which resets /all/ peak memory stats.",
    #     DeprecationWarning,
    # )
    return reset_peak_memory_stats(device=device)

def memory_allocated(device=None):
|
||||||
|
r"""Returns the current VACC memory occupied by tensors in bytes for a given
|
||||||
|
device.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
device (torch.device or int, optional): selected device. Returns
|
||||||
|
statistic for the current device, given by :func:`~torch_vacc.vacc.current_device`,
|
||||||
|
if :attr:`device` is ``None`` (default).
|
||||||
|
"""
|
||||||
|
return memory_stats(device=device)["allocated_bytes.all.current"]


def max_memory_allocated(device=None):
    r"""Returns the maximum VACC memory occupied by tensors in bytes for a given
    device.

    By default, this returns the peak allocated memory since the beginning of
    this program. :func:`~torch_vacc.vacc.reset_peak_memory_stats` can be used to
    reset the starting point in tracking this metric. For example, these two
    functions can measure the peak allocated memory usage of each iteration in a
    training loop.

    Arguments:
        device (torch.device or int, optional): selected device. Returns
            statistic for the current device, given by :func:`~torch_vacc.vacc.current_device`,
            if :attr:`device` is ``None`` (default).
    """
    return memory_stats(device=device)["allocated_bytes.all.peak"]
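

# Illustrative sketch (not part of the original commit): the per-iteration
# peak measurement described in the docstring above. `model` and `batch` are
# assumed to be supplied by the caller and to live on a VACC device.
def _example_iteration_peak(model, batch):
    reset_peak_memory_stats()
    model(batch)                   # one forward pass (or a full train step)
    return max_memory_allocated()  # peak bytes allocated during this step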


def memory_reserved(device=None):
    r"""Returns the current VACC memory managed by the caching allocator in bytes
    for a given device.

    Arguments:
        device (torch.device or int, optional): selected device. Returns
            statistic for the current device, given by :func:`~torch_vacc.vacc.current_device`,
            if :attr:`device` is ``None`` (default).
    """
    return memory_stats(device=device)["reserved_bytes.all.current"]


def max_memory_reserved(device=None):
    r"""Returns the maximum VACC memory managed by the caching allocator in bytes
    for a given device.

    By default, this returns the peak cached memory since the beginning of this
    program. :func:`~torch_vacc.vacc.reset_peak_memory_stats` can be used to reset
    the starting point in tracking this metric. For example, these two functions
    can measure the peak cached memory amount of each iteration in a training
    loop.

    Arguments:
        device (torch.device or int, optional): selected device. Returns
            statistic for the current device, given by :func:`~torch_vacc.vacc.current_device`,
            if :attr:`device` is ``None`` (default).
    """
    return memory_stats(device=device)["reserved_bytes.all.peak"]


def memory_cached(device=None):
    r"""Deprecated; see :func:`~torch_vacc.vacc.memory_reserved`."""
    # warnings.warn(
    #     "torch_vacc.vacc.memory_cached has been renamed to torch_vacc.vacc.memory_reserved",
    #     DeprecationWarning,
    # )
    return memory_reserved(device=device)


def max_memory_cached(device=None):
    r"""Deprecated; see :func:`~torch_vacc.vacc.max_memory_reserved`."""
    # warnings.warn(
    #     "torch_vacc.vacc.max_memory_cached has been renamed to torch_vacc.vacc.max_memory_reserved",
    #     DeprecationWarning,
    # )
    return max_memory_reserved(device=device)


def memory_snapshot():
    r"""Returns a snapshot of the VACC memory allocator state across all devices.

    Interpreting the output of this function requires familiarity with the
    memory allocator internals.
    """
    return _torch_vacc._vacc_memorySnapshot()


def _format_size(sz, pref_sz):
    # Pick a unit based on `pref_sz` (the value that drives the scale) and
    # scale `sz` to it; switches to the next unit once pref_sz reaches 768 KB.
    prefixes = ["B ", "KB", "MB", "GB", "TB", "PB"]
    prefix = prefixes[0]
    for new_prefix in prefixes[1:]:
        if pref_sz < 768 * 1024:
            break
        prefix = new_prefix
        sz //= 1024
        pref_sz /= 1024
    return "{:7d} {}".format(sz, prefix)


def _format_count(cnt, pref_cnt):
    # Same idea as _format_size, but with decimal K/M prefixes for counts.
    prefixes = [" ", "K", "M"]
    prefix = prefixes[0]
    for new_prefix in prefixes[1:]:
        if pref_cnt < 750 * 1000:
            break
        prefix = new_prefix
        cnt //= 1000
        pref_cnt /= 1000
    return "{:7d} {} ".format(cnt, prefix)
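

# For example (illustrative note, not in the original commit): 123456789 bytes
# crosses the 768 KB switch-over once, so it is rendered in KB; _format_count
# behaves the same way with decimal K/M prefixes.
#
#     _format_size(123456789, 123456789)  ->  ' 120563 KB'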


def create_metrics_to_display():
    metrics_to_display = [
        ("allocated_bytes", "Allocated memory", _format_size),
        ("active_bytes", "Active memory", _format_size),
        ("reserved_bytes", "VACC reserved memory", _format_size),
        ("inactive_split_bytes", "Non-releasable memory", _format_size),
        ("allocation", "Allocations", _format_count),
        ("active", "Active allocs", _format_count),
        ("segment", "VACC reserved segments", _format_count),
        ("inactive_split", "Non-releasable allocs", _format_count),
    ]

    lines = []
    lines.append("=" * 75)
    lines.append(" {_:16} PyTorch VACC memory summary, device ID {device:<18d} ")
    lines.append("-" * 75)
    lines.append(
        " {_:9} VACC OOMs: {num_ooms:<13d} | {_:6} vaccMalloc retries: {num_alloc_retries:<9d} "
    )
    lines.append("=" * 75)
    lines.append(
        " Metric                | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  "
    )
    return metrics_to_display, lines


def memory_summary(device=None, abbreviated=False):
    r"""Returns a human-readable printout of the current memory allocator
    statistics for a given device.

    This can be useful to display periodically during training, or when
    handling out-of-memory exceptions.

    Arguments:
        device (torch.device or int, optional): selected device. Returns
            printout for the current device, given by :func:`~torch_vacc.vacc.current_device`,
            if :attr:`device` is ``None`` (default).
        abbreviated (bool, optional): whether to return an abbreviated summary
            (default: False).
    """
    device = _get_device_index(device, optional=True)
    stats = memory_stats(device=device)
    metrics_to_display, lines = create_metrics_to_display()

    for metric_key, metric_name, formatter in metrics_to_display:
        lines.append("-" * 75)
        submetrics = [("all", metric_name)]
        if not abbreviated:
            submetrics.append(("large_pool", " from large pool"))
            submetrics.append(("small_pool", " from small pool"))

        current_prefval, peak_prefval, allocated_prefval, freed_prefval = (
            None,
            None,
            None,
            None,
        )

        for submetric_key, submetric_name in submetrics:
            prefix = metric_key + "." + submetric_key + "."

            current = stats[prefix + "current"]
            peak = stats[prefix + "peak"]
            allocated = stats[prefix + "allocated"]
            freed = stats[prefix + "freed"]

            if current_prefval is None:
                current_prefval = current
                peak_prefval = peak
                allocated_prefval = allocated
                freed_prefval = freed

            lines.append(
                " {:<21} | {} | {} | {} | {} ".format(
                    submetric_name,
                    formatter(current, current_prefval),
                    formatter(peak, peak_prefval),
                    formatter(allocated, allocated_prefval),
                    formatter(freed, freed_prefval),
                ),
            )

    metrics_to_display = [
        ("oversize_allocations", "Oversize allocations", _format_count),
        ("oversize_segments", "Oversize VACC segments", _format_count),
    ]

    for metric_key, metric_name, formatter in metrics_to_display:
        lines.append("-" * 75)

        prefix = metric_key + "."

        current = stats[prefix + "current"]
        peak = stats[prefix + "peak"]
        allocated = stats[prefix + "allocated"]
        freed = stats[prefix + "freed"]

        lines.append(
            " {:<21} | {} | {} | {} | {} ".format(
                metric_name,
                formatter(current, current),
                formatter(peak, peak),
                formatter(allocated, allocated),
                formatter(freed, freed),
            ),
        )

    lines.append("=" * 75)

    fmt_dict = {"_": "", "device": device}
    for k, v in stats.items():
        fmt_dict[k.replace(".", "-")] = v
    return "|" + "|\n|".join(lines).format(**fmt_dict) + "|\n"
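

# Illustrative sketch (not part of the original commit): dumping the summary
# when an out-of-memory error is raised, as suggested by the docstring above.
# The exception type is an assumption; match whatever your build raises.
def _example_oom_report(step_fn):
    try:
        step_fn()
    except RuntimeError:  # assumed to cover VACC OOM errors
        print(memory_summary(abbreviated=True))
        raise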


def mem_get_info(device=None) -> Tuple[int, int]:
    r"""Returns the global free and total VACC memory for a given
    device using vaccrtMemGetInfo.

    Args:
        device (torch.device or int, optional): selected device. Returns
            statistic for the current device, given by :func:`~torch_vacc.vacc.current_device`,
            if :attr:`device` is ``None`` (default).
    """
    _lazy_init()
    if device is None:
        device = torch_vacc.vacc.current_device()
    device = _get_device_index(device)
    return _torch_vacc._vacc_getDeviceMemories(device)
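

# Illustrative sketch (not part of the original commit): per the docstring the
# first element is free bytes and the second total bytes, so utilization
# falls out directly.
def _example_memory_utilization(device=None):
    free, total = mem_get_info(device)
    return 1.0 - free / total  # fraction of device memory currently in use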


def get_allocator_backend() -> str:
    r"""Returns a string describing the active allocator backend as set by
    ``PYTORCH_VACC_ALLOC_CONF``. Currently available backends are
    ``native`` (PyTorch's native caching allocator).
    """
    return _torch_vacc._vacc_getAllocatorBackend()
179
torch_vacc/vacc/random.py
Normal file
@@ -0,0 +1,179 @@
from typing import Union, List, Iterable

import torch
from torch import Tensor

from . import _lazy_call, _lazy_init, current_device, device_count


__all__ = [
    "get_rng_state",
    "get_rng_state_all",
    "set_rng_state",
    "set_rng_state_all",
    "manual_seed",
    "manual_seed_all",
    "seed",
    "seed_all",
    "initial_seed",
]

# Random Number Generator related functions (https://pytorch.org/docs/stable/cuda.html#random-number-generator)


def get_rng_state(device: Union[int, str, torch.device] = "vacc") -> Tensor:
    r"""Returns the random number generator state of the specified GPU as a ByteTensor.

    Args:
        device (torch.device or int, optional): The device to return the RNG state of.
            Default: ``'vacc'`` (i.e., ``torch.device('vacc')``, the current VACC device).

    .. warning::
        This function eagerly initializes VACC.
    """
    _lazy_init()
    if isinstance(device, str):
        device = torch.device(device)
    elif isinstance(device, int):
        device = torch.device("vacc", device)
    idx = device.index
    if idx is None:
        idx = current_device()
    default_generator = torch.vacc.default_generators[idx]
    return default_generator.get_state()


def get_rng_state_all() -> List[Tensor]:
    r"""Returns a list of ByteTensor representing the random number states of all devices."""

    results = []
    for i in range(device_count()):
        results.append(get_rng_state(i))
    return results


def set_rng_state(
    new_state: Tensor, device: Union[int, str, torch.device] = "vacc"
) -> None:
    r"""Sets the random number generator state of the specified GPU.

    Args:
        new_state (torch.ByteTensor): The desired state
        device (torch.device or int, optional): The device to set the RNG state.
            Default: ``'vacc'`` (i.e., ``torch.device('vacc')``, the current VACC device).
    """
    with torch._C._DisableFuncTorch():
        new_state_copy = new_state.clone(memory_format=torch.contiguous_format)
    if isinstance(device, str):
        device = torch.device(device)
    elif isinstance(device, int):
        device = torch.device("vacc", device)

    def cb():
        idx = device.index
        if idx is None:
            idx = current_device()
        default_generator = torch.vacc.default_generators[idx]
        default_generator.set_state(new_state_copy)

    _lazy_call(cb)


def set_rng_state_all(new_states: Iterable[Tensor]) -> None:
    r"""Sets the random number generator state of all devices.

    Args:
        new_states (Iterable of torch.ByteTensor): The desired state for each device"""
    for i, state in enumerate(new_states):
        set_rng_state(state, i)
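

# Illustrative sketch (not part of the original commit): round-tripping RNG
# state through a checkpoint dict, the usual way these two helpers are paired.
def _example_rng_checkpoint():
    ckpt = {"vacc_rng_state_all": get_rng_state_all()}  # save
    set_rng_state_all(ckpt["vacc_rng_state_all"])       # restore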


def manual_seed(seed: int) -> None:
    r"""Sets the seed for generating random numbers for the current GPU.
    It's safe to call this function if VACC is not available; in that
    case, it is silently ignored.

    Args:
        seed (int): The desired seed.

    .. warning::
        If you are working with a multi-GPU model, this function is insufficient
        to get determinism. To seed all GPUs, use :func:`manual_seed_all`.
    """
    seed = int(seed)

    def cb():
        idx = current_device()
        default_generator = torch.vacc.default_generators[idx]
        default_generator.manual_seed(seed)

    _lazy_call(cb, seed=True)


def manual_seed_all(seed: int) -> None:
    r"""Sets the seed for generating random numbers on all GPUs.
    It's safe to call this function if VACC is not available; in that
    case, it is silently ignored.

    Args:
        seed (int): The desired seed.
    """
    seed = int(seed)

    def cb():
        for i in range(device_count()):
            default_generator = torch.vacc.default_generators[i]
            default_generator.manual_seed(seed)

    _lazy_call(cb, seed_all=True)


def seed() -> None:
    r"""Sets the seed for generating random numbers to a random number for the current GPU.
    It's safe to call this function if VACC is not available; in that
    case, it is silently ignored.

    .. warning::
        If you are working with a multi-GPU model, this function will only initialize
        the seed on one GPU. To initialize all GPUs, use :func:`seed_all`.
    """

    def cb():
        idx = current_device()
        default_generator = torch.vacc.default_generators[idx]
        default_generator.seed()

    _lazy_call(cb)


def seed_all() -> None:
    r"""Sets the seed for generating random numbers to a random number on all GPUs.
    It's safe to call this function if VACC is not available; in that
    case, it is silently ignored.
    """

    def cb():
        # Seed device 0 randomly, then reuse that seed on every other device so
        # all devices end up sharing one randomly chosen seed.
        random_seed = 0
        seeded = False
        for i in range(device_count()):
            default_generator = torch.vacc.default_generators[i]
            if not seeded:
                default_generator.seed()
                random_seed = default_generator.initial_seed()
                seeded = True
            else:
                default_generator.manual_seed(random_seed)

    _lazy_call(cb)


def initial_seed() -> int:
    r"""Returns the current random seed of the current GPU.

    .. warning::
        This function eagerly initializes VACC.
    """
    _lazy_init()
    idx = current_device()
    default_generator = torch.vacc.default_generators[idx]
    return default_generator.initial_seed()
327
torch_vacc/vacc/streams.py
Normal file
@@ -0,0 +1,327 @@
import ctypes
from typing import Any, Optional

import torch
from packaging import version
from torch._utils import _get_device_index

try:
    from torch._streambase import _StreamBase, _EventBase
except ImportError:
    # torch <= 2.1
    _StreamBase = _EventBase = object

import torch_vacc

from torch_vacc._vacc_libs import _torch_vacc
from ._device import device
from .lazy_initialize import _lazy_init


# remove torch version arch-suffix (i.e. +cpu)
torch_version = torch.__version__.split('+')[0]


class _StreamCommon:
    """Wrapper around a VACC stream.

    A VACC stream is a linear sequence of execution that belongs to a specific
    device, independent from other streams.

    Args:
        device(torch.device or int, optional): a device on which to allocate
            the stream. If :attr:`device` is ``None`` (default) or a negative
            integer, this will use the current device.
        priority(int, optional): priority of the stream. Can be either
            -1 (high priority) or 0 (low priority). By default, streams have
            priority 0.
    """

    def __new__(cls, device=None, priority=0, **kwargs):
        if device is None or ("stream_id" in kwargs and "device_index" in kwargs):
            return super(Stream, cls).__new__(cls, priority=priority, **kwargs)
        else:
            with torch_vacc.vacc.device(device):
                return super(Stream, cls).__new__(cls, priority=priority, **kwargs)

    def wait_event(self, event):
        event.wait(self)

    def record_event(self, event=None):
        """Records an event.

        Args:
            event (torch_vacc.Event, optional): event to record. If not given, a new one
                will be allocated.

        Returns:
            Recorded event.
        """
        if event is None:
            event = Event()
        event.record(self)
        return event

    def wait_stream(self, stream):
        """Synchronizes with another stream.

        All future work submitted to this stream will wait until all kernels
        submitted to a given stream at the time of call complete.

        Args:
            stream (Stream): a stream to synchronize.
        """
        self.wait_event(stream.record_event())

    def query(self):
        return super().query()

    def synchronize(self):
        super().synchronize()

    @property
    def _as_parameter_(self):
        return ctypes.c_void_p(self.vacc_stream)

    def __eq__(self, o):
        if isinstance(o, Stream):
            return super().__eq__(o)
        return False

    def __hash__(self):
        return hash((self.vacc_stream, self.device))

    def __repr__(self):
        return f"torch_vacc.vacc.Stream device={self.device} vacc_stream={self.vacc_stream:#x}"


if version.parse(torch_version) <= version.parse("2.1"):
    # torch <= 2.1
    class Stream(_torch_vacc._VACCStreamBase, _StreamCommon):
        pass
elif version.parse(torch_version) < version.parse("2.6"):
    # torch < 2.6
    class Stream(_torch_vacc._VACCStreamBase, _StreamBase, _StreamCommon):
        pass
else:
    # torch >= 2.6
    class Stream(_torch_vacc._VACCStreamBase, _StreamCommon):
        pass
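

# Illustrative sketch (not part of the original commit): wait_stream is just
# record_event + wait_event, so the same ordering can be expressed manually.
def _example_manual_ordering(producer: "Stream", consumer: "Stream"):
    ev = producer.record_event()  # marks the current tail of `producer`
    consumer.wait_event(ev)       # work queued on `consumer` now waits for it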


class _EventCommon:
    """Wrapper around a VACC event.

    VACC events are synchronization markers that can be used to monitor the
    device's progress, to accurately measure timing, and to synchronize VACC
    streams.

    The underlying VACC events are lazily initialized when the event is first
    recorded or exported to another process. After creation, only streams on the
    same device may record the event. However, streams on any device can wait on
    the event.

    Args:
        enable_timing (bool, optional): indicates if the event should measure time
            (default: ``False``)
        blocking (bool, optional): if ``True``, :meth:`wait` will be blocking (default: ``False``)
    """

    def __new__(cls, enable_timing=False, blocking=False):
        return super(Event, cls).__new__(
            cls,
            calc_time=enable_timing,
            blocking=blocking,
        )

    def record(self, stream=None):
        """Records the event in a given stream.

        Uses ``torch_vacc.vacc.current_stream()`` if no stream is specified. The
        stream's device must match the event's device."""
        if stream is None:
            stream = torch_vacc.vacc.current_stream()
        super().record(stream)

    def wait(self, stream=None):
        """Makes all future work submitted to the given stream wait for this
        event.

        Uses ``torch_vacc.vacc.current_stream()`` if no stream is specified.

        .. note:: This is a wrapper around ``vaccrtStreamWaitEvent()``
        """
        if stream is None:
            stream = torch_vacc.vacc.current_stream()
        super().wait(stream)

    def query(self):
        """Checks if all work currently captured by event has completed.

        Returns:
            A boolean indicating if all work currently captured by event has
            completed.
        """
        return super().query()

    def elapsed_time(self, end_event):
        """Returns the time elapsed in milliseconds after the event was
        recorded and before the end_event was recorded.
        """
        return super().elapsed_time(end_event)

    def synchronize(self):
        r"""Waits for the event to complete.

        Waits until the completion of all work currently captured in this event.
        This prevents the CPU thread from proceeding until the event completes.

        .. note:: This is a wrapper around ``vaccEventSynchronize()``.
        """
        super().synchronize()

    @property
    def _as_parameter_(self):
        return ctypes.c_void_p(self.vacc_event)

    def __repr__(self):
        if self.vacc_event:
            return f"<torch_vacc.vacc.Event {self._as_parameter_.value:#x}>"
        else:
            return "<torch_vacc.vacc.Event uninitialized>"


if version.parse(torch_version) <= version.parse("2.1"):
    # torch <= 2.1
    class Event(_torch_vacc._VACCEventBase, _EventCommon):
        pass
elif version.parse(torch_version) < version.parse("2.6"):
    # torch < 2.6
    class Event(_torch_vacc._VACCEventBase, _EventBase, _EventCommon):
        pass
else:
    # torch >= 2.6
    class Event(_torch_vacc._VACCEventBase, _EventCommon):
        pass
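

# Illustrative sketch (not part of the original commit): timing a region of
# device work with a pair of timing-enabled events, mirroring the CUDA idiom.
def _example_time_region(work_fn):
    start, end = Event(enable_timing=True), Event(enable_timing=True)
    start.record()
    work_fn()           # kernels queued on the current stream
    end.record()
    end.synchronize()   # make sure `end` has actually completed
    return start.elapsed_time(end)  # milliseconds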


class StreamContext:
    r"""Context-manager that selects a given stream.

    All VACC kernels queued within its context will be enqueued on a selected
    stream.

    Args:
        stream (stream): selected stream. This manager is a no-op if it's
            ``None``.
    .. note:: Streams are per-device.
    """
    cur_stream: Optional["torch_vacc.vacc.Stream"]

    def __init__(self, stream: Optional["torch_vacc.vacc.Stream"]):
        self.stream = stream
        self.idx = _get_device_index(None, True)
        if not torch.jit.is_scripting():
            if self.idx is None:
                self.idx = -1

        self.src_prev_stream = (
            None
            if not torch.jit.is_scripting()
            else torch_vacc.vacc.default_stream(None)
        )
        self.dst_prev_stream = (
            None
            if not torch.jit.is_scripting()
            else torch_vacc.vacc.default_stream(None)
        )

    def __enter__(self):
        # Local cur_stream variable for type refinement
        cur_stream = self.stream
        # Return if stream is None or VACC device not available
        if cur_stream is None or self.idx == -1:
            return
        self.src_prev_stream = torch_vacc.vacc.current_stream(None)

        # If the stream is not on the current device, then
        # set the current stream on the device
        if self.src_prev_stream.device != cur_stream.device:
            with device(cur_stream.device):
                self.dst_prev_stream = torch_vacc.vacc.current_stream(cur_stream.device)
        torch_vacc.vacc.set_stream(cur_stream)

    def __exit__(self, type: Any, value: Any, traceback: Any):
        # Local cur_stream variable for type refinement
        cur_stream = self.stream
        # If stream is None or no VACC device available, return
        if cur_stream is None or self.idx == -1:
            return

        # Reset the stream on the original device
        # and destination device
        if self.src_prev_stream.device != cur_stream.device:  # type: ignore[union-attr]
            torch_vacc.vacc.set_stream(self.dst_prev_stream)  # type: ignore[arg-type]
        torch_vacc.vacc.set_stream(self.src_prev_stream)  # type: ignore[arg-type]


def stream(stream: Optional["torch_vacc.vacc.Stream"]) -> StreamContext:
    r"""Wrapper around the Context-manager StreamContext that
    selects a given stream.

    Arguments:
        stream (Stream): selected stream. This manager is a no-op if it's
            ``None``.
    """
    return StreamContext(stream)
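

# Illustrative sketch (not part of the original commit): queueing work on a
# side stream via the context manager, then re-joining the default stream.
def _example_side_stream(work_fn):
    side = Stream()
    side.wait_stream(torch_vacc.vacc.current_stream())  # order after pending work
    with stream(side):
        work_fn()  # kernels issued here land on `side`
    torch_vacc.vacc.current_stream().wait_stream(side)  # re-join before reuse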


def set_stream(stream: Stream):
    r"""Sets the current stream. This is a wrapper API to set the stream.
    Usage of this function is discouraged in favor of the ``stream``
    context manager.

    Args:
        stream (Stream): selected stream. This function is a no-op
            if this argument is ``None``.
    """
    if stream is None:
        return
    _torch_vacc._vacc_setStream(
        stream_id=stream.stream_id,
        device_index=stream.device_index,
        device_type=stream.device_type,
    )


def current_stream(device=None) -> Stream:
    r"""Returns the currently selected :class:`Stream` for a given device.

    Args:
        device (torch.device or int, optional): selected device. Returns
            the currently selected :class:`Stream` for the current device, given
            by :func:`~torch_vacc.vacc.current_device`, if :attr:`device` is ``None``
            (default).
    """
    _lazy_init()
    streamdata = _torch_vacc._vacc_getCurrentStream(
        _get_device_index(device, optional=True)
    )
    return Stream(
        stream_id=streamdata[0], device_index=streamdata[1], device_type=streamdata[2]
    )


def default_stream(device=None) -> Stream:
    r"""Returns the default :class:`Stream` for a given device.

    Args:
        device (torch.device or int, optional): selected device. Returns
            the default :class:`Stream` for the current device, given by
            :func:`~torch_vacc.vacc.current_device`, if :attr:`device` is ``None``
            (default).
    """
    _lazy_init()
    streamdata = _torch_vacc._vacc_getDefaultStream(
        _get_device_index(device, optional=True)
    )

    return Stream(
        stream_id=streamdata[0], device_index=streamdata[1], device_type=streamdata[2]
    )
2
torch_vacc/version.py
Normal file
@@ -0,0 +1,2 @@
__all__ = ['__version__']
__version__ = '1.3.3.777'
269
torch_vacc/vslog.cfg
Normal file
@@ -0,0 +1,269 @@
hot_update: true

- channel: 0
  sync: sync
  priority: error
  category: 0
  category_extend: 0
  -device: 0
    disable: false
    out_type: file
    priority: error
    category: 0
    category_extend: 0
    path: "./log/"
    file: "$PNAME-$YEAR_$MON_$DAY_$HOUR_$MIN_$SEC_$PID"
    rollback: 5
    limit_size: 50 m #only support M byte
  -device: 1
    disable: false
    out_type: screen
    category: 0
    category_extend: 0
- channel: 1
  sync: sync
  priority: error
  category: 0
  category_extend: 0
  -device: 0
    disable: false
    out_type: file
    priority: error
    category: 0
    category_extend: 0
    path: "./log/"
    file: "vacm-$YEAR_$MON_$DAY_$HOUR_$MIN_$SEC_$PID"
    rollback: 5
    limit_size: 50 m #only support M byte
  -device: 1
    disable: false
    out_type: screen
    category: 0
    category_extend: 0
- channel: 2
  sync: sync
  priority: error
  category: 0
  category_extend: 0
  -device: 0
    disable: false
    out_type: file
    priority: error
    category: 0
    category_extend: 0
    path: "./log/"
    file: "vace-$YEAR_$MON_$DAY_$HOUR_$MIN_$SEC_$PID"
    rollback: 5
    limit_size: 50 m #only support M byte
  -device: 1
    disable: false
    out_type: screen
    category: 0
    category_extend: 0
- channel: 3
  sync: sync
  priority: error
  category: 0
  category_extend: 0
  -device: 0
    disable: false
    out_type: file
    priority: error
    category: 0
    category_extend: 0
    path: "./log/"
    file: "vacl-$YEAR_$MON_$DAY_$HOUR_$MIN_$SEC_$PID"
    rollback: 5
    limit_size: 50 m #only support M byte
  -device: 1
    disable: false
    out_type: screen
    category: 0
    category_extend: 0
- channel: 4
  sync: sync
  priority: error
  category: 0
  category_extend: 0
  -device: 0
    disable: false
    out_type: file
    priority: error
    category: 0
    category_extend: 0
    path: "./log/"
    file: "vame-$YEAR_$MON_$DAY_$HOUR_$MIN_$SEC_$PID"
    rollback: 5
    limit_size: 50 m #only support M byte
  -device: 1
    disable: false
    out_type: screen
    category: 0
    category_extend: 0
- channel: 5
  sync: sync
  priority: error
  category: 0
  category_extend: 0
  -device: 0
    disable: false
    out_type: file
    priority: error
    category: 0
    category_extend: 0
    path: "./log/"
    file: "vaml-$YEAR_$MON_$DAY_$HOUR_$MIN_$SEC_$PID"
    rollback: 5
    limit_size: 50 m #only support M byte
  -device: 1
    disable: false
    out_type: screen
    category: 0
    category_extend: 0
- channel: 6
  sync: sync
  priority: error
  category: 0
  category_extend: 0
  append_cr: true
  -device: 0
    disable: false
    out_type: file
    priority: error
    category: 0
    category_extend: 0
    path: "./log/"
    file: "rt-$YEAR_$MON_$DAY_$HOUR_$MIN_$SEC_$PID"
    rollback: 5
    limit_size: 50 m #only support M byte
  -device: 1
    disable: true
    out_type: screen
    category: 0
    category_extend: 0
- channel: 7
  sync: sync
  priority: error
  category: 0
  category_extend: 0
  -device: 0
    disable: false
    out_type: file
    priority: error
    category: 0
    category_extend: 0
    path: "./log/"
    file: "nn-$YEAR_$MON_$DAY_$HOUR_$MIN_$SEC_$PID"
    rollback: 5
    limit_size: 50 m #only support M byte
  -device: 1
    disable: true
    out_type: screen
    category: 0
    category_extend: 0
- channel: 8
  sync: sync
  priority: error
  category: 0
  category_extend: 0
  -device: 0
    disable: false
    out_type: file
    priority: error
    category: 0
    category_extend: 0
    path: "./log/"
    file: "tm-$YEAR_$MON_$DAY_$HOUR_$MIN_$SEC_$PID"
    rollback: 5
    limit_size: 50 m #only support M byte
  -device: 1
    disable: false
    out_type: screen
    category: 0
    category_extend: 0
- channel: 9
  sync: sync
  priority: error
  category: 0
  category_extend: 0
  append_cr: true
  no_prefix: true
  -device: 0
    disable: false
    out_type: file
    priority: error
    category: 0
    category_extend: 0
    path: "./log/"
    file: "md-$YEAR_$MON_$DAY_$HOUR_$MIN_$SEC_$PID"
    rollback: 5
    limit_size: 50 m #only support M byte
  -device: 1
    disable: true
    out_type: screen
    category: 0
    category_extend: 0
- channel: 10
  sync: sync
  priority: error
  category: 0
  category_extend: 0
  append_cr: false
  no_prefix: true
  -device: 0
    disable: false
    out_type: file
    priority: error
    category: 0
    category_extend: 0
    path: "./log/"
    file: "rs-$YEAR_$MON_$DAY_$HOUR_$MIN_$SEC_$PID"
    rollback: 5
    limit_size: 50 m #only support M byte
  -device: 1
    disable: false
    out_type: screen
    category: 0
    category_extend: 0
- channel: 11
  sync: sync
  priority: error
  category: 0
  category_extend: 0
  append_cr: false
  no_prefix: true
  -device: 0
    disable: false
    out_type: file
    priority: error
    category: 0
    category_extend: 0
    path: "./log/"
    file: "vaapi-$YEAR_$MON_$DAY_$HOUR_$MIN_$SEC_$PID"
    rollback: 5
    limit_size: 50 m #only support M byte
  -device: 1
    disable: false
    out_type: screen
    category: 0
    category_extend: 0
- channel: 12
  sync: sync
  priority: error
  category: 0
  category_extend: 0
  -device: 0
    disable: false
    out_type: file
    priority: error
    category: 0
    category_extend: 0
    path: "./log/"
    file: "vccl-$YEAR_$MON_$DAY_$HOUR_$MIN_$SEC_$PID"
    rollback: 5
    limit_size: 50 m #only support M byte
  -device: 1
    disable: true
    out_type: screen
    category: 0
    category_extend: 0
31
vacc_tools/__init__.py
Normal file
@@ -0,0 +1,31 @@
from functools import partial
from datetime import datetime
from typing import Union, Tuple

import torch
import torch.distributed

_module_time = {}


def print_module_time(
    model: torch.nn.Module, module: Union[Tuple[torch.nn.Module], torch.nn.Module]
):
    def now_as_us():
        return int(datetime.now().timestamp() * 1e6)  # in us

    def _pre_forward(suffix, m, inputs):
        name = f"{type(m).__name__}.{suffix}"
        _module_time[name] = now_as_us()

    def _post_forward(suffix, m, inputs, outputs):
        name = f"{type(m).__name__}.{suffix}"
        start_time = _module_time.pop(name)
        print(f"{name}: {now_as_us() - start_time} us")

    for name, m in model.named_modules():
        if isinstance(m, module):
            m.register_forward_pre_hook(partial(_pre_forward, "forward"))
            m.register_forward_hook(partial(_post_forward, "forward"))
            m.register_full_backward_pre_hook(partial(_pre_forward, "backward"))
            m.register_full_backward_hook(partial(_post_forward, "backward"))
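

# Illustrative sketch (not part of the original commit): instrumenting every
# Linear layer of a model; each forward/backward prints its wall-clock time.
def _example_print_linear_times():
    model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 1))
    print_module_time(model, torch.nn.Linear)
    out = model(torch.randn(4, 8))
    out.sum().backward()  # triggers the backward hooks as well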
BIN
vacc_tools/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vacc_tools/__pycache__/generate_trace.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vacc_tools/__pycache__/memory_analyzer.cpython-312.pyc
Normal file
Binary file not shown.
Binary file not shown.
BIN
vacc_tools/__pycache__/trace_logger.cpython-312.pyc
Normal file
Binary file not shown.
214
vacc_tools/generate_trace.py
Normal file
@@ -0,0 +1,214 @@
"""Generating tracing json files from log files.

Usage:
    python -m vacc_tools.generate_trace --log-dir <directory of log files> --out-file-prefix <prefix of output file>
"""

import argparse
import json
import os
import re
import numpy as np
import tabulate
from glob import glob
from collections import defaultdict
from multiprocessing import Pool


def run_stats_on_traces(timelines):
    op_cat_list = ["ODSP", "DLC", "VCCL", "CPU", "CPU_OP"]
    op_stats = {op: {} for op in op_cat_list}
    for line in timelines:
        if '"E"' not in line:  # optim 3, skip everything if not `"E"`
            continue

        # optim 2: using `[:-2]` instead of replace()
        line = line[:-2]  # remove ',\n'
        try:
            values = json.loads(line)
        except json.decoder.JSONDecodeError:
            # some logs may not end properly; just skip them
            continue

        if values["ph"] == "E" and values["cat"] in op_cat_list:
            cat = values["cat"]
            if values["name"] not in op_stats[cat]:
                op_stats[cat][values["name"]] = []
            if "dur" in values["args"]:
                # optim 1: using `[:-2]` instead of replace()
                op_stats[cat][values["name"]].append(
                    int(values["args"]["dur"][:-2])  # strip `us`
                )
            elif "values(us)" in values["args"]:
                op_stats[cat][values["name"]].append(values["args"]["values(us)"])
    op_tables = {}
    for cat, stats in op_stats.items():
        # optim 4: using list comprehension instead of for loop
        table = []
        for name, dur in stats.items():
            dur = np.array(dur)
            t = [
                name,
                np.min(dur),
                np.max(dur),
                np.sum(dur),
                np.mean(dur),
                np.percentile(dur, 90),
                len(dur),
            ]
            table.append(t)

        table = sorted(table, key=lambda x: x[-1], reverse=True)
        op_tables[cat] = tabulate.tabulate(
            table,
            headers=["op", "min", "max", "sum", "avg", "p90", "count"],
            tablefmt="plain",
        )

        if cat in ["VCCL", "ODSP", "DLC"]:
            op_tables["VACC-ALL"] = op_tables.get("VACC-ALL", []) + [
                t + [cat] for t in table
            ]

    total = sum([x[3] for x in op_tables["VACC-ALL"]])
    op_tables["VACC-ALL"] = [t + [t[3] / total * 100] for t in op_tables["VACC-ALL"]]

    op_tables["VACC-ALL"] = tabulate.tabulate(
        sorted(op_tables["VACC-ALL"], key=lambda x: x[-1], reverse=True),
        headers=["op", "min", "max", "sum", "avg", "p90", "count", "cat", "percent(%)"],
        tablefmt="plain",
    )

    return op_tables
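

# Illustrative sketch (not part of the original commit): the function consumes
# Chrome-trace JSON lines that still carry their trailing ",\n", i.e. the
# payload written after the LOG_TRACE: markers by the runtime.
def _example_stats():
    timelines = [
        '{"name": "matmul", "cat": "DLC", "ph": "E", "args": {"dur": "42us"}},\n',
        '{"name": "matmul", "cat": "DLC", "ph": "E", "args": {"dur": "58us"}},\n',
    ]
    print(run_stats_on_traces(timelines)["DLC"])  # min/max/sum/avg/p90/count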


def get_rank_info(files):
    # using pattern rank-<rank> in file name to get rank
    for fpath in files:
        rank = re.findall(r"rank-(\d+)", fpath)
        if rank:
            return int(rank[0])
    return 0


def extract_traces(arg):
    files, target_file_path, group_name, trace_token = arg

    entries = [
        (0, "scheduler"),
        (1, "megatron"),
        (2, "deepspeed"),
        (3, "nn.Module"),
        (10, "vacc-odsp"),
        (11, "vacc-dlc"),
        (12, "vacc-vccl"),
        (13, "vacc-cpu"),
        (14, "vacc-fallback"),
        (15, "vacc-ddr"),
        (20, "lib-vccl"),
    ]

    with open(target_file_path, "w", encoding="utf-8") as trace_file:
        trace_file.write("[")
        for tid, thread_name in entries:
            line = f'{{"cat":"__metadata","pid":{group_name},"tid":{tid},"ts":0,"ph":"M","name":"thread_name","args":{{"name":"{thread_name}"}}}},\n'
            trace_file.write(line)

        timelines = []
        for fpath in files:
            with open(fpath, "r", encoding="utf-8") as file:
                # timelines += [line.split(trace_token)[1] for line in file if trace_token in line]
                for line in file:
                    if trace_token in line:
                        # found the trace token: keep the content that follows it
                        timelines.append(line.split(trace_token)[1])
                        try:
                            json.loads(timelines[-1][:-2])  # remove ',\n'
                        except json.decoder.JSONDecodeError:
                            # some logs may not end properly; just skip them.
                            # chrome://tracing stops reading the remaining lines
                            # if it hits an error, so bad lines must be removed.
                            timelines.pop()

        for line in timelines[:-1]:
            trace_file.write(line)
        # fixing JSON format error by removing last comma in a list
        trace_file.write(timelines[-1].replace(",\n", "\n"))
        trace_file.write("]")

    op_stats = run_stats_on_traces(timelines)
    with open(
        target_file_path.replace(".json", ".txt"), "w", encoding="utf-8"
    ) as op_stats_file:
        for cat, tables in op_stats.items():
            op_stats_file.write(f"{cat}".center(80, "-") + "\n")
            op_stats_file.write(tables + "\n\n")


def merge_schedule(out_file_prefix):
    scheduler_data = []
    for file in glob(f"{out_file_prefix}*.json"):
        if file.endswith("schedule.json"):
            continue
        assert "rank" in file
        rank = file.split("rank_")[-1].split("_")[0]
        pid = None
        with open(file, "r", encoding="utf-8") as f:
            for line in f:
                # set every schedule entry's pid to 0 and its tid to the rank id
                if '"tid":0,' in line and "__metadata" not in line:
                    if pid is None:
                        pid = line.split('"pid":')[1].split(",")[0]

                    line = line.replace(f'"pid":{pid}', '"pid":0')
                    line = line.replace('"tid":0,', f'"tid":{rank},')
                    scheduler_data.append(line)

    out_file = f"{out_file_prefix}schedule.json"
    with open(out_file, "w", encoding="utf-8") as f:
        f.write("[\n")
        f.writelines(scheduler_data[:-1])
        f.write(scheduler_data[-1].replace(",\n", "\n"))
        f.write("]\n")


def scan_and_generate_trace(args, trace_token):
    grouped_files = defaultdict(list)
    for root, dirs, files in os.walk(args.log_dir):
        for filename in files:
            fpath = os.path.join(root, filename)
            file_size = os.path.getsize(fpath)
            if file_size != 0:
                group_name = filename.rsplit("_", 1)[1].split(".")[0]
                grouped_files[group_name].append(fpath)
    pool_args = []
    for group_name, files in grouped_files.items():
        rank = get_rank_info(files)
        out_file = f"{args.out_file_prefix}rank_{rank}_{group_name}.json"
        pool_args.append((files, out_file, group_name, trace_token))

    with Pool(len(grouped_files)) as p:
        p.map(extract_traces, pool_args)

    if args.merge_schedule:
        merge_schedule(args.out_file_prefix)


if __name__ == "__main__":
    TRACE_TOKEN = "LOG_TRACE:"

    current_file_path = os.path.abspath(__file__)
    parent_directory = os.path.dirname(os.path.dirname(current_file_path))
    find_directory = os.path.join(parent_directory, "log")

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--log-dir", default=find_directory, type=str, help="directory of log files"
    )
    parser.add_argument("--out-file-prefix", default="timeline_", type=str)
    parser.add_argument("--merge-schedule", action="store_true")

    args = parser.parse_args()

    scan_and_generate_trace(args, TRACE_TOKEN)
    print("Scan and trace generation done!")
151
vacc_tools/memory_analyzer.py
Normal file
@@ -0,0 +1,151 @@
from contextlib import contextmanager
from dataclasses import fields
from typing import Dict, Tuple, List, Optional
import torch

NUM_BYTES_IN_MB = 1024**2
NUM_BYTES_IN_GB = 1024**3


class MemoryAnalyzer:
    def __init__(
        self, model: torch.nn.Module, optimizer: Optional[torch.optim.Optimizer] = None
    ):
        """This memory usage analyzer will be mostly accurate only if you initialize
        it at the beginning and insert `get_memory_usage_in_gb` at the end of your
        forward pass.

        NOTE: It will have a negative impact if not properly used, as it stores
        activations of every nn.Module's forward function and relies on the user to
        reset it every time the forward pass ends.

        Limitations:
        1. does not work with customized operators
        2. does not work with functional operators
        3. it approximates activation as nn.Module.forward's output (if it's
        inside the graph and requires gradients), so it may not be exactly accurate.
        """
        self.model = model
        self.optimizer = optimizer

        self.activ_addrs = set()
        self.activ_memory = 0

    @staticmethod
    def _is_activation(x):
        return torch.is_tensor(x) and x.requires_grad and x.device != "cpu"

    def _get_weight_grads_addrs(self):
        weights = set([p.untyped_storage().data_ptr() for p in self.model.parameters()])
        grads = set(
            [
                p.grad.untyped_storage().data_ptr()
                for p in self.model.parameters()
                if p.grad is not None
            ]
        )
        return weights.union(grads)

    def pack_hook(self):
        def _pack_hook(x):
            if self._is_activation(x):
                weight_grads = self._get_weight_grads_addrs()
                # NOTE: storage is more accurate than using x.nelement() * x.element_size()
                data_ptr = x.untyped_storage().data_ptr()
                if data_ptr not in weight_grads and data_ptr not in self.activ_addrs:
                    self.activ_addrs.add(data_ptr)
                    self.activ_memory += x.untyped_storage().size()
            return x

        return _pack_hook

    def unpack_hook(self):
        def _unpack_hook(x):
            if self._is_activation(x):
                weight_grads = self._get_weight_grads_addrs()
                data_ptr = x.untyped_storage().data_ptr()
                if data_ptr not in weight_grads and data_ptr in self.activ_addrs:
                    self.activ_addrs.remove(data_ptr)
                    self.activ_memory -= x.untyped_storage().size()
            return x

        return _unpack_hook

    @contextmanager
    def record_activation(self):
        with torch.autograd.graph.saved_tensors_hooks(
            self.pack_hook(), self.unpack_hook()
        ):
            yield

    @staticmethod
    def get_weight_memory(model: torch.nn.Module):
        weights = [
            p.nelement() * p.element_size()
            for p in model.parameters()
            if p.device != "cpu"
        ]
        return sum(weights)

    @staticmethod
    def get_gradient_memory(model: torch.nn.Module):
        grads = [
            p.grad.nelement() * p.grad.element_size()
            for p in model.parameters()
            if p.grad is not None and p.grad.device != "cpu"
        ]
        return sum(grads)

    def _sum_activation_memory(self):
        return self.activ_memory

    def get_optimizer_state_memory(self):
        if isinstance(self.optimizer, torch.optim.AdamW):
            params = sum(
                [
                    p.nelement() * p.element_size()
                    for pg in self.optimizer.param_groups
                    for p in pg["params"]
                    if torch.is_tensor(p) and p.device != "cpu"
                ]
            )
            for state in self.optimizer.state.values():
                params += sum(
                    [
                        v.nelement() * v.element_size()
                        for k, v in state.items()
                        if torch.is_tensor(v) and v.device != "cpu"
                    ]
                )
            return params
        return 0

    def _get_memory_usage(self) -> Tuple[int, int, int, int]:
        return (
            self.get_weight_memory(self.model),
            self.get_gradient_memory(self.model),
            self._sum_activation_memory(),
            self.get_optimizer_state_memory(),
        )

    def get_memory_usage_in_gb(self) -> str:
        w, g, a, opt = self._get_memory_usage()

        return (
            f"Total: {(w + g + a + opt) / NUM_BYTES_IN_GB:.3f} GB, "
            f"weight: {w / NUM_BYTES_IN_GB:.3f} GB, "
            f"gradient: {g / NUM_BYTES_IN_GB:.3f} GB, "
            f"activation: {a / NUM_BYTES_IN_GB:.3f} GB, "
            f"optimizer states: {opt / NUM_BYTES_IN_GB:.3f} GB"
        )

    def get_memory_usage_in_mb(self) -> str:
        w, g, a, opt = self._get_memory_usage()

        return (
            f"Total: {(w + g + a + opt) / NUM_BYTES_IN_MB:.2f} MB, "
            f"weight: {w / NUM_BYTES_IN_MB:.2f} MB, "
            f"gradient: {g / NUM_BYTES_IN_MB:.2f} MB, "
            f"activation: {a / NUM_BYTES_IN_MB:.2f} MB, "
            f"optimizer states: {opt / NUM_BYTES_IN_MB:.2f} MB"
        )
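

# Illustrative sketch (not part of the original commit): wrap each forward pass
# in record_activation and read the breakdown afterwards. Assumes the model and
# inputs live on the accelerator (CPU tensors are ignored by _is_activation).
def _example_memory_report(model, optimizer, batch):
    analyzer = MemoryAnalyzer(model, optimizer)
    with analyzer.record_activation():
        loss = model(batch).sum()
    print(analyzer.get_memory_usage_in_mb())
    loss.backward()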
65
vacc_tools/parse_vacc_log_for_tracing.py
Normal file
@@ -0,0 +1,65 @@
import argparse
|
||||||
|
import os
|
||||||
|
from collections import defaultdict
|
||||||
|
from multiprocessing import Pool
|
||||||
|
|
||||||
|
|
||||||
|
log_tag = "LOG_TRACE:"
|
||||||
|
tid_names = [
|
||||||
|
(0, "module"),
|
||||||
|
(1, "megatron"),
|
||||||
|
(2, "deepspeed"),
|
||||||
|
(10, "vacc-odsp"),
|
||||||
|
(11, "vacc-dlc"),
|
||||||
|
(12, "vacc-vccl"),
|
||||||
|
(13, "vacc-cpu"),
|
||||||
|
(14, "vacc-cpu_fallback"),
|
||||||
|
(15, "vacc-ddr"),
|
||||||
|
(20, "lib-vccl"),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def parse_files_of_process(args):
|
||||||
|
pid, in_files = args
|
||||||
|
out_file = "trace_" + pid + ".json"
|
||||||
|
with open(out_file, "w", encoding="utf-8") as new_file:
|
||||||
|
metadata_lines = [
|
||||||
|
f'{{"name": "thread_name","ph": "M","pid": {pid},"tid": {tid},"args": {{"name": "{name}"}}}},'
|
||||||
|
for tid, name in tid_names
|
||||||
|
]
|
||||||
|
new_file.write("[\n")
|
||||||
|
new_file.write("\n".join(metadata_lines))
|
||||||
|
new_file.write("\n")
|
||||||
|
for file_path in in_files:
|
||||||
|
with open(file_path, "r", encoding="utf-8") as file:
|
||||||
|
for line in file:
|
||||||
|
if log_tag in line:
|
||||||
|
new_line = line.split(log_tag, 1)[1].strip()
|
||||||
|
new_file.write(new_line + "\n")
|
||||||
|
new_file.write("]")
|
||||||
|
|
||||||
|
|
||||||
|
def parse_directory(directory):
|
||||||
|
pro_files = defaultdict(list)
|
||||||
|
for dirpath, dirnames, filenames in os.walk(directory):
|
||||||
|
for filename in filenames:
|
||||||
|
file_path = os.path.join(dirpath, filename)
|
||||||
|
if filename.startswith("vacc") and os.path.getsize(file_path) != 0:
|
||||||
|
pid = filename.rsplit("_", 1)[1].split(".")[0]
|
||||||
|
pro_files[pid].append(file_path)
|
||||||
|
|
||||||
|
args = []
|
||||||
|
for pid, in_files in pro_files.items():
|
||||||
|
args.append((pid, in_files))
|
||||||
|
|
||||||
|
with Pool() as p:
|
||||||
|
p.map(parse_files_of_process, args)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="parse vacc log files and generate trace files"
|
||||||
|
)
|
||||||
|
parser.add_argument("directory", type=str, help="log directory to parse")
|
||||||
|
args = parser.parse_args()
|
||||||
|
parse_directory(args.directory)
|
||||||
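A minimal sketch of driving the parser programmatically instead of via the CLI; the file name under `log/` is illustrative and assumes the `trace_logger` naming scheme below:

```python
# Assumes non-empty log files such as log/vacc-odsp-rank-0_12345.txt containing
# LOG_TRACE:-prefixed JSON lines; the output is one trace_<pid>.json per process,
# loadable in chrome://tracing or Perfetto.
from parse_vacc_log_for_tracing import parse_directory

parse_directory("log")
```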
329
vacc_tools/trace_logger.py
Normal file
@@ -0,0 +1,329 @@
"""
This module provides mechanisms for tracing the execution of torch modules and
functions, and for writing the trace to a JSON file.

Set the environment variable `LOG_TRAIN_SCHEDULE=1` to enable tracing.
Otherwise, no tracing is applied.

Inside your module, create your tracer functions by using `get_trace_api`.
You will get four functions:
* `@trace_time(name)`: decorator to trace the execution of a function.
```python
@trace_time("my_func")
def my_func(x):
    ...
```
* `@trace_autograd_function()`: decorator to trace the execution of forward
and backward of a user-defined `torch.autograd.Function` operator.
```python
@trace_autograd_function()
class MyAutogradFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ...
    @staticmethod
    def backward(ctx, grad_output):
        ...
```
* `register_module_trace()`: function to register tracing on a model
(`nn.Module`). It applies traces recursively by enumerating all `nn.Module`s
and registering tracers on their forward and backward functions. Applying it
only to the top-level `nn.Module` is recommended.
```python
model = Model()
register_module_trace(model)
```
* `register_optimizer_trace()`: function to trace an optimizer's `step` (and
its `reduce_gradients`, if present).
"""

import os
import json
from contextlib import contextmanager
from dataclasses import dataclass, asdict
from datetime import datetime
from functools import partial

import torch
import torch.distributed


MODULE_TID = {"megatron": 1, "deepspeed": 2, "nn.Module": 3, "ram": 100}

# pylint: disable=missing-docstring


@dataclass
class TraceEntry:
    name: str
    cat: str
    pid: int
    tid: int
    ts: int
    ph: str
    args: dict = None

    def to_json_str(self):
        d = asdict(self)
        if self.args is None:
            d.pop("args")
        return json.dumps(d, separators=(",", ": "))


class LogFiles:
    def __init__(self) -> None:
        self.loggers = {}

    def get(self, file_prefix, rank, pid):
        os.makedirs("log", exist_ok=True)
        fpath = f"log/{file_prefix}-rank-{rank}_{pid}.txt"
        if fpath not in self.loggers:
            self.loggers[fpath] = open(fpath, "w")
        return self.loggers[fpath]

    def close(self):
        for f in self.loggers.values():
            f.close()

    def __del__(self):
        self.close()


def trace_logger_enabled() -> bool:
    return (
        "LOG_TRAIN_SCHEDULE" in os.environ and os.environ["LOG_TRAIN_SCHEDULE"] == "1"
    )


class TraceLogger:
    _log_files = LogFiles()

    def __init__(self, category, tid=None, file_prefix=None) -> None:
        self.enabled = trace_logger_enabled()

        if self.enabled:
            self.pid = os.getpid()
            self.logger = None
            self.cat = category
            self._traces = {}
            self.global_rank = 0
            if tid is None:
                self.tid = MODULE_TID.get(category, 1000)
            else:
                self.tid = tid
            self.file_prefix = file_prefix if file_prefix is not None else self.cat

        self.registered_modules = []

    def _create_logger(self) -> None:
        # delay creating the logger file until the first log call,
        # since torch.distributed may not be ready yet
        if torch.distributed.is_initialized():
            self.global_rank = torch.distributed.get_rank()
        self.logger = TraceLogger._log_files.get(
            self.file_prefix, self.global_rank, self.pid
        )

    def begin_trace(self, name, memory=False) -> None:
        if not self.enabled:
            return

        if self.logger is None:
            self._create_logger()
        assert self.logger is not None

        name = f"{name}"  # convert to str to ensure it is JSON serializable

        start_time = int(datetime.now().timestamp() * 1e6)  # in us
        trace = TraceEntry(name, self.cat, self.pid, self.tid, start_time, "B")

        mem_trace = self._get_memory(start_time) if memory else None

        if name not in self._traces:
            self._traces[name] = [(trace, mem_trace)]
        else:  # in case the call to the function is nested
            self._traces[name].append((trace, mem_trace))

    def end_trace(self, name, flush=False, memory=False) -> None:
        if not self.enabled:
            return

        name = f"{name}"  # convert to str to ensure it is JSON serializable

        assert self.logger is not None, "begin_trace should be called before end_trace"
        assert name in self._traces, "begin_trace should be called before end_trace"

        start_trace, start_mem = self._traces[name].pop()
        if start_mem is not None:
            self.logger.write(f"LOG_TRACE:{start_mem.to_json_str()},\n")
        self.logger.write(f"LOG_TRACE:{start_trace.to_json_str()},\n")

        end_time = int(datetime.now().timestamp() * 1e6)  # in us
        args = {"value(us)": end_time - start_trace.ts}
        trace = TraceEntry(name, self.cat, self.pid, self.tid, end_time, "E", args)
        self.logger.write(f"LOG_TRACE:{trace.to_json_str()},\n")

        if memory:
            mem_trace = self._get_memory(end_time)
            self.logger.write(f"LOG_TRACE:{mem_trace.to_json_str()},\n")

        if flush:
            self.flush()

    def flush(self) -> None:
        if self.logger is not None:
            self.logger.flush()

    def _get_memory(self, timestamp):
        args = {"value": torch.vacc.memory_allocated(self.global_rank)}
        mem_trace = TraceEntry(
            "memory", "memory", self.pid, MODULE_TID["ram"], timestamp, "C", args
        )
        return mem_trace


@contextmanager
def _trace_time(name, logger_inst, memory=False, flush=False):
    if not logger_inst.enabled:
        yield
        return

    logger_inst.begin_trace(name, memory=memory)
    yield
    logger_inst.end_trace(name, flush=flush, memory=memory)


SKIPED_MODULES = []


def _register_module_trace(
    module: torch.nn.Module, logger_inst, flush: bool = True, forward_only=False
):
    if not logger_inst.enabled:
        return

    if not isinstance(module, torch.nn.Module):
        return

    def _register(m):
        module_name = f"{type(m).__name__}"
        if module_name == "WrapName":
            module_name = f"{type(m.forward_func.__self__).__name__}"

        if module_name in SKIPED_MODULES:
            return

        forward_name = module_name + ".forward"

        m.register_forward_pre_hook(
            lambda m, inp: logger_inst.begin_trace(forward_name, memory=True)
        )
        m.register_forward_hook(
            lambda m, inp, out: logger_inst.end_trace(forward_name, memory=True)
        )

        if not forward_only:
            backward_name = module_name + ".backward"

            m.register_full_backward_pre_hook(
                lambda m, grad_out: logger_inst.begin_trace(backward_name, memory=True)
            )
            m.register_full_backward_hook(
                lambda m, grad_in, grad_out: logger_inst.end_trace(
                    backward_name, memory=True, flush=flush
                )
            )

    for m in module.modules():
        if m in logger_inst.registered_modules:
            print(
                f"module `{m}` already registered, skip applying trace on same module multiple times."
            )
            continue
        _register(m)
        # remember the module so repeated registration is skipped
        logger_inst.registered_modules.append(m)


def _trace_autograd_function(logger_inst):
    def decorator(cls):
        if not issubclass(cls, torch.autograd.Function):
            return cls

        def _apply(name, method):
            def wrapper(*args, **kwargs):
                with _trace_time(name, logger_inst=logger_inst, memory=True):
                    result = method(*args, **kwargs)
                return result

            return wrapper

        for attr in ["forward", "backward"]:
            setattr(cls, attr, _apply(cls.__name__ + "." + attr, getattr(cls, attr)))
        return cls

    return decorator


def _register_optimizer_trace(
    optimizer: torch.optim.Optimizer, logger_inst, flush: bool = True
):
    if not logger_inst.enabled:
        return

    trace_name = f"{type(optimizer).__name__}.step"

    if isinstance(optimizer, torch.optim.Optimizer):
        optimizer.register_step_pre_hook(
            lambda m, *args, **kwargs: logger_inst.begin_trace(trace_name, memory=True)
        )
        optimizer.register_step_post_hook(
            lambda m, *args, **kwargs: logger_inst.end_trace(
                trace_name, memory=True, flush=flush
            )
        )
    elif hasattr(optimizer, "step") and callable(optimizer.step):
        # a customized optimizer does not have step hooks

        original_step = optimizer.step

        def traced_step(*args, **kwargs):
            logger_inst.begin_trace(trace_name, memory=True)
            result = original_step(*args, **kwargs)
            logger_inst.end_trace(trace_name, memory=True, flush=flush)
            return result

        # Replace the step method with the new function
        optimizer.step = traced_step
    else:
        # unknown optimizer or wrong instance passed to this function.
        pass

    if hasattr(optimizer, "reduce_gradients") and callable(optimizer.reduce_gradients):
        trace_name = f"{type(optimizer).__name__}.reduce_gradients"
        original_reduce = optimizer.reduce_gradients

        def traced_reduce(*args, **kwargs):
            logger_inst.begin_trace(trace_name, memory=True)
            result = original_reduce(*args, **kwargs)
            logger_inst.end_trace(trace_name, memory=True, flush=flush)
            return result

        # Replace the reduce_gradients method with the new function
        optimizer.reduce_gradients = traced_reduce


def get_trace_api(name="nn.Module"):
    """generate module execution trace APIs for a given module name

    Args:
        name (str): module name

    Returns:
        tuple: (trace_time, register_module_trace, trace_autograd_function,
        register_optimizer_trace)
        Usage of these functions is described in the docstring of this module
    """
    _trace_logger = TraceLogger(name)

    return (
        partial(_trace_time, logger_inst=_trace_logger),
        partial(_register_module_trace, logger_inst=_trace_logger),
        partial(_trace_autograd_function, logger_inst=_trace_logger),
        partial(_register_optimizer_trace, logger_inst=_trace_logger),
    )
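To tie the pieces together, a minimal sketch of wiring the tracer into a training step; the model and optimizer here are stand-ins, and the memory counters used by the module hooks assume the torch_vacc backend is loaded:

```python
import os

os.environ["LOG_TRAIN_SCHEDULE"] = "1"  # must be set before tracers are created

import torch
from trace_logger import get_trace_api

trace_time, register_module_trace, trace_autograd_function, \
    register_optimizer_trace = get_trace_api("nn.Module")

model = torch.nn.Linear(4, 4)  # stand-in for a real model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

register_module_trace(model)         # hooks every submodule's forward/backward
register_optimizer_trace(optimizer)  # hooks optimizer.step via step hooks

with trace_time("train_step"):       # also usable as @trace_time("train_step")
    loss = model(torch.randn(2, 4)).sum()
    loss.backward()
    optimizer.step()
```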
BIN
vllm/_C.abi3.so
Normal file
Binary file not shown.
102
vllm/__init__.py
Normal file
@@ -0,0 +1,102 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""vLLM: a high-throughput and memory-efficient inference engine for LLMs"""

# version.py should be an independent library, and we always import the
# version library first. This assumption is critical for some customizations.
from .version import __version__, __version_tuple__  # isort:skip

import typing

# The environment variable overrides should be imported before any other
# modules to ensure that the environment variables are set before any
# other modules are imported.
import vllm.env_override  # noqa: F401

MODULE_ATTRS = {
    "bc_linter_skip": "._bc_linter:bc_linter_skip",
    "bc_linter_include": "._bc_linter:bc_linter_include",
    "AsyncEngineArgs": ".engine.arg_utils:AsyncEngineArgs",
    "EngineArgs": ".engine.arg_utils:EngineArgs",
    "AsyncLLMEngine": ".engine.async_llm_engine:AsyncLLMEngine",
    "LLMEngine": ".engine.llm_engine:LLMEngine",
    "LLM": ".entrypoints.llm:LLM",
    "initialize_ray_cluster": ".executor.ray_utils:initialize_ray_cluster",
    "PromptType": ".inputs:PromptType",
    "TextPrompt": ".inputs:TextPrompt",
    "TokensPrompt": ".inputs:TokensPrompt",
    "ModelRegistry": ".model_executor.models:ModelRegistry",
    "SamplingParams": ".sampling_params:SamplingParams",
    "PoolingParams": ".pooling_params:PoolingParams",
    "ClassificationOutput": ".outputs:ClassificationOutput",
    "ClassificationRequestOutput": ".outputs:ClassificationRequestOutput",
    "CompletionOutput": ".outputs:CompletionOutput",
    "EmbeddingOutput": ".outputs:EmbeddingOutput",
    "EmbeddingRequestOutput": ".outputs:EmbeddingRequestOutput",
    "PoolingOutput": ".outputs:PoolingOutput",
    "PoolingRequestOutput": ".outputs:PoolingRequestOutput",
    "RequestOutput": ".outputs:RequestOutput",
    "ScoringOutput": ".outputs:ScoringOutput",
    "ScoringRequestOutput": ".outputs:ScoringRequestOutput",
}

if typing.TYPE_CHECKING:
    from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
    from vllm.engine.async_llm_engine import AsyncLLMEngine
    from vllm.engine.llm_engine import LLMEngine
    from vllm.entrypoints.llm import LLM
    from vllm.executor.ray_utils import initialize_ray_cluster
    from vllm.inputs import PromptType, TextPrompt, TokensPrompt
    from vllm.model_executor.models import ModelRegistry
    from vllm.outputs import (ClassificationOutput,
                              ClassificationRequestOutput, CompletionOutput,
                              EmbeddingOutput, EmbeddingRequestOutput,
                              PoolingOutput, PoolingRequestOutput,
                              RequestOutput, ScoringOutput,
                              ScoringRequestOutput)
    from vllm.pooling_params import PoolingParams
    from vllm.sampling_params import SamplingParams

    from ._bc_linter import bc_linter_include, bc_linter_skip
else:

    def __getattr__(name: str) -> typing.Any:
        from importlib import import_module

        if name in MODULE_ATTRS:
            module_name, attr_name = MODULE_ATTRS[name].split(":")
            module = import_module(module_name, __package__)
            return getattr(module, attr_name)
        else:
            raise AttributeError(
                f'module {__package__} has no attribute {name}')


__all__ = [
    "__version__",
    "bc_linter_skip",
    "bc_linter_include",
    "__version_tuple__",
    "LLM",
    "ModelRegistry",
    "PromptType",
    "TextPrompt",
    "TokensPrompt",
    "SamplingParams",
    "RequestOutput",
    "CompletionOutput",
    "PoolingOutput",
    "PoolingRequestOutput",
    "EmbeddingOutput",
    "EmbeddingRequestOutput",
    "ClassificationOutput",
    "ClassificationRequestOutput",
    "ScoringOutput",
    "ScoringRequestOutput",
    "LLMEngine",
    "EngineArgs",
    "AsyncLLMEngine",
    "AsyncEngineArgs",
    "initialize_ray_cluster",
    "PoolingParams",
]
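For reference, a short sketch of what the `MODULE_ATTRS` indirection buys: heavy engine modules are imported only when the corresponding attribute is first touched.

```python
import vllm

print(vllm.__version__)  # resolved eagerly from .version

# First access triggers __getattr__, which runs
# import_module(".sampling_params", "vllm") and returns the class.
params = vllm.SamplingParams(temperature=0.0)
```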
BIN
vllm/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm/__pycache__/_bc_linter.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm/__pycache__/_custom_ops.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm/__pycache__/_version.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm/__pycache__/beam_search.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm/__pycache__/connections.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm/__pycache__/env_override.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm/__pycache__/envs.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm/__pycache__/logger.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm/__pycache__/logits_process.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm/__pycache__/logprobs.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm/__pycache__/outputs.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm/__pycache__/pooling_params.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm/__pycache__/sampling_params.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm/__pycache__/scalar_type.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm/__pycache__/sequence.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm/__pycache__/tasks.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm/__pycache__/test_utils.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm/__pycache__/tracing.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm/__pycache__/version.cpython-312.pyc
Normal file
Binary file not shown.
59
vllm/_bc_linter.py
Normal file
@@ -0,0 +1,59 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# vllm/_bc_linter.py
from __future__ import annotations

from typing import Any, Callable, TypeVar, overload

T = TypeVar("T")


@overload
def bc_linter_skip(obj: T) -> T:
    ...


@overload
def bc_linter_skip(*, reason: str | None = ...) -> Callable[[T], T]:
    ...


def bc_linter_skip(obj: Any = None, *, reason: str | None = None):
    """
    No-op decorator to mark symbols/files for BC-linter suppression.

    Usage:
        @bc_linter_skip
        def legacy_api(...): ...
    """

    def _wrap(x: T) -> T:
        return x

    return _wrap if obj is None else obj


@overload
def bc_linter_include(obj: T) -> T:
    ...


@overload
def bc_linter_include(*, reason: str | None = ...) -> Callable[[T], T]:
    ...


def bc_linter_include(obj: Any = None, *, reason: str | None = None):
    """
    Usage:
        @bc_linter_include
        def public_api(...): ...
    """

    def _wrap(x: T) -> T:
        return x

    return _wrap if obj is None else obj


__all__ = ["bc_linter_skip", "bc_linter_include"]
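Both call forms the overloads above permit, side by side; the decorated function names are illustrative:

```python
from vllm._bc_linter import bc_linter_skip


@bc_linter_skip  # bare form: obj is the decorated function itself
def legacy_api():
    ...


@bc_linter_skip(reason="scheduled for removal")  # keyword form returns _wrap
def older_api():
    ...
```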
2044
vllm/_custom_ops.py
Normal file
File diff suppressed because it is too large
393
vllm/_ipex_ops.py
Normal file
@@ -0,0 +1,393 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Optional, Union

import torch

from vllm.logger import init_logger
from vllm.platforms import current_platform

logger = init_logger(__name__)

try:
    import intel_extension_for_pytorch as ipex
except ImportError as e:
    logger.debug("Import error msg: %s", e.msg)


class ipex_ops:

    @staticmethod
    def _reshape_activation_tensor(
            x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        num = x.size(0)
        d = x.size(1) // 2
        x = x.reshape(num, 2, d)
        x1, x2 = torch.chunk(x, chunks=2, dim=1)
        x1 = x1.reshape(num, d)
        x2 = x2.reshape(num, d)
        return x1, x2

    @staticmethod
    def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
        ipex.llm.functional.silu_and_mul(x, out)

    @staticmethod
    def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
        ipex.llm.functional.gelu_and_mul(x, out)

    @staticmethod
    def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
        ipex.llm.functional.gelu_and_mul(x, out)

    @staticmethod
    def gelu_fast(x: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.gelu(x)

    @staticmethod
    def gelu_new(x: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.gelu(x)

    @staticmethod
    def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None:
        ipex.llm.functional.gelu_quick(x, out)

    @staticmethod
    def paged_attention_v1(
        out: torch.Tensor,
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        num_kv_heads: int,
        scale: float,
        block_tables: torch.Tensor,
        context_lens: torch.Tensor,
        block_size: int,
        max_context_len: int,
        alibi_slopes: Optional[torch.Tensor],
        kv_cache_dtype: str,
        k_scale: float,
        v_scale: float,
        tp_rank: int = 0,
        blocksparse_local_blocks: int = 0,
        blocksparse_vert_stride: int = 0,
        blocksparse_block_size: int = 64,
        blocksparse_head_sliding_step: int = 0,
    ) -> None:
        assert kv_cache_dtype == "auto"
        num_heads = out.size(1)
        num_queries_per_tokens = num_heads // num_kv_heads
        ipex.llm.modules.PagedAttention.single_query_kv_attention(
            out,
            query.contiguous(),
            key_cache.view_as(value_cache),
            value_cache,
            num_queries_per_tokens,
            scale,
            block_tables,
            context_lens,
            block_size,
            max_context_len,
            alibi_slopes,
        )

    @staticmethod
    def paged_attention_v2(
        out: torch.Tensor,
        exp_sum: torch.Tensor,
        max_logits: torch.Tensor,
        tmp_out: torch.Tensor,
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        num_kv_heads: int,
        scale: float,
        block_tables: torch.Tensor,
        context_lens: torch.Tensor,
        block_size: int,
        max_context_len: int,
        alibi_slopes: Optional[torch.Tensor],
        kv_cache_dtype: str,
        k_scale: float,
        v_scale: float,
        tp_rank: int = 0,
        blocksparse_local_blocks: int = 0,
        blocksparse_vert_stride: int = 0,
        blocksparse_block_size: int = 64,
        blocksparse_head_sliding_step: int = 0,
    ) -> None:
        assert kv_cache_dtype == "auto"
        num_heads = out.size(1)
        num_queries_per_tokens = num_heads // num_kv_heads
        ipex.llm.modules.PagedAttention.single_query_kv_attention(
            out,
            query.contiguous(),
            key_cache.view_as(value_cache),
            value_cache,
            num_queries_per_tokens,
            scale,
            block_tables,
            context_lens,
            block_size,
            max_context_len,
            alibi_slopes,
        )

    @staticmethod
    def rotary_embedding(
        positions: torch.Tensor,  # [batch_size, seq_len]
        query: torch.Tensor,  # [batch_size, seq_len, num_heads*head_size]
        key: torch.Tensor,  # [batch_size, seq_len, num_kv_heads*head_size]
        head_size: int,
        cos_sin_cache: torch.Tensor,  # [cos_sin_dim, rot_dim]
        is_neox: bool,
    ) -> None:
        rot_dim = cos_sin_cache.size(1)
        ipex.llm.functional.rotary_embedding_batched(positions, query, key,
                                                     head_size, cos_sin_cache,
                                                     is_neox, rot_dim)

    @staticmethod
    def rms_norm(input: torch.Tensor, weight: torch.Tensor,
                 epsilon: float) -> torch.Tensor:
        return ipex.llm.functional.rms_norm(input, weight, epsilon)

    @staticmethod
    def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
                           weight: torch.Tensor, epsilon: float) -> None:
        tmp = ipex.llm.functional.add_rms_norm(residual, input, weight, None,
                                               epsilon, True)
        input.copy_(tmp)

    @staticmethod
    def varlen_attention(
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        out: torch.Tensor,
        seqlen_q: torch.Tensor,
        seqlen_k: torch.Tensor,
        alibi_slopes: Optional[torch.Tensor],
        max_seqlen_q: int,
        max_seqlen_k: int,
        pdropout: float,
        softmax_scale: float,
        zero_tensors: bool,
        is_causal: bool,
        return_softmax: bool,
        gen_: torch.Generator,
        window_size_left: float,
        window_size_right: float,
        logits_soft_cap: float,
    ) -> None:
        if ipex.__version__.endswith("cpu"):
            if logits_soft_cap != 0.0:
                raise ValueError("IPEX CPU does not support logits_soft_cap")
            assert alibi_slopes is None
            assert window_size_left < 0 and window_size_right < 0
            ipex.llm.functional.varlen_attention(query.contiguous(),
                                                 key.contiguous(),
                                                 value.contiguous(), out,
                                                 seqlen_q.int(),
                                                 seqlen_k.int(), max_seqlen_q,
                                                 max_seqlen_k, pdropout,
                                                 softmax_scale, zero_tensors,
                                                 is_causal, return_softmax,
                                                 gen_)
        else:  # XPU build
            ipex.llm.functional.varlen_attention(
                query.contiguous(), key.contiguous(), value.contiguous(), out,
                seqlen_q.int(), seqlen_k.int(), alibi_slopes, max_seqlen_q,
                max_seqlen_k, pdropout, softmax_scale, zero_tensors, is_causal,
                return_softmax, gen_, window_size_left, window_size_right,
                logits_soft_cap)

    @staticmethod
    def reshape_and_cache(
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
        kv_cache_dtype: str,
        k_scale: float,
        v_scale: float,
    ) -> None:
        assert kv_cache_dtype == "auto"
        ipex.llm.modules.PagedAttention.reshape_and_cache(
            key, value, key_cache, value_cache, slot_mapping)

    @staticmethod
    def reshape_and_cache_flash(
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
        kv_cache_dtype: str,
        k_scale: Optional[torch.Tensor] = None,
        v_scale: Optional[torch.Tensor] = None,
        k_scale_float: float = 1.0,
        v_scale_float: float = 1.0,
    ) -> None:
        ipex.llm.modules.PagedAttention.reshape_and_cache_flash(
            key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype,
            k_scale_float, v_scale_float)

    @staticmethod
    def flash_attn_varlen_func(
        out: torch.Tensor,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        cu_seqlens_q: torch.Tensor,
        seqused_k: torch.Tensor,  # not supported in the ipex kernel
        max_seqlen_q: int,
        max_seqlen_k: int,
        softmax_scale: float,
        causal: bool,
        block_table: torch.Tensor,
        alibi_slopes: Optional[torch.Tensor],
        window_size: Optional[list[int]] = None,
        softcap: Optional[float] = 0.0,
        cu_seqlens_k: Optional[torch.Tensor] = None,
        # The following parameters are not used in the ipex kernel currently;
        # we keep the API compatible with CUDA's.
        scheduler_metadata=None,
        fa_version: int = 2,
        q_descale=None,
        k_descale=None,
        v_descale=None,
        num_splits=0,
        s_aux: Optional[torch.Tensor] = None,
    ):
        if cu_seqlens_k is None:
            # cu_seqlens_k is not used in the ipex kernel.
            cu_seqlens_k = torch.cumsum(seqused_k, dim=0)
            cu_seqlens_k = torch.cat([
                torch.tensor([0], device=seqused_k.device, dtype=torch.int32),
                cu_seqlens_k
            ]).to(torch.int32)

        real_window_size: tuple[int, int]
        if window_size is None:
            real_window_size = (-1, -1)
        else:
            assert len(window_size) == 2
            real_window_size = (window_size[0], window_size[1])
        return ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
            out,
            q.contiguous(),
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            softmax_scale,
            causal,
            block_table,
            alibi_slopes,
            softcap=softcap,
            window_size_left=real_window_size[0],
            window_size_right=real_window_size[1],
            k_scale=1.0,
            v_scale=1.0,
        )

    @staticmethod
    def get_scheduler_metadata(
        batch_size,
        max_seqlen_q,
        max_seqlen_k,
        num_heads_q,
        num_heads_kv,
        headdim,
        cache_seqlens: torch.Tensor,
        qkv_dtype=torch.bfloat16,
        headdim_v=None,
        cu_seqlens_q: Optional[torch.Tensor] = None,
        cu_seqlens_k_new: Optional[torch.Tensor] = None,
        cache_leftpad: Optional[torch.Tensor] = None,
        page_size: Optional[int] = None,
        max_seqlen_k_new=0,
        causal=False,
        window_size=(-1, -1),  # -1 means infinite context window
        has_softcap=False,
        num_splits=0,  # Can be tuned for speed
        pack_gqa=None,  # Can be tuned for speed
        sm_margin=0,  # Can be tuned if some SMs are used for communication
    ) -> None:
        logger.warning_once(
            "get_scheduler_metadata is not implemented for ipex_ops, "
            "returning None.")
        return None

    @staticmethod
    def copy_blocks(key_caches: list[torch.Tensor],
                    value_caches: list[torch.Tensor],
                    block_mapping: torch.Tensor) -> None:
        torch.xpu.copy_blocks(  # type: ignore
            key_caches,
            value_caches,
            block_mapping,
        )

    @staticmethod
    def swap_blocks(src: torch.Tensor, dst: torch.Tensor,
                    block_mapping: torch.Tensor) -> None:
        torch.xpu.swap_blocks(src, dst, block_mapping)  # type: ignore

    @staticmethod
    def scaled_fp8_quant(
        input: torch.Tensor,
        scale: Optional[torch.Tensor] = None,
        num_token_padding: Optional[int] = None,
        scale_ub: Optional[torch.Tensor] = None,
        use_per_token_if_dynamic: bool = False,
        output: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Quantize input tensor to FP8 and return quantized tensor and scale.

        This function is designed for both static and dynamic quantization:
        if you provide the scale, it will use static scaling, and if you omit
        it, the scale will be determined dynamically. Currently, the XPU
        platform only supports dynamic quantization. The function also allows
        optional padding of the output tensors for downstream kernels that
        will benefit from padding.

        Args:
            input: The input tensor to be quantized to FP8
            scale: Optional scaling factor for the FP8 quantization
            scale_ub: Optional upper bound for scaling factor in dynamic
                per token case
            num_token_padding: If specified, pad the first dimension
                of the output to at least this value.
            use_per_token_if_dynamic: Whether to do per_tensor or per_token
                in the dynamic quantization case.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
                scaling factor.
        """
        # This code assumes batch_dim and num_tokens are flattened
        assert (input.ndim == 2)
        shape: Union[tuple[int, int], torch.Size] = input.shape
        out_dtype: torch.dtype = current_platform.fp8_dtype()
        if num_token_padding:
            shape = (max(num_token_padding, input.shape[0]), shape[1])
        if output is None:
            output = torch.empty(shape, device=input.device, dtype=out_dtype)
        else:
            assert num_token_padding is None, \
                "padding not supported if output passed in"
            assert output.dtype == out_dtype
        assert scale is None, "only dynamic fp8 quantization supported on XPU"
        assert not use_per_token_if_dynamic, (
            "per token dynamic fp8 quantization not supported on XPU")
        scale = torch.zeros(1, device=input.device, dtype=torch.float32)
        torch.ops.torch_ipex.dynamic_scaled_fp8_quant(output, input, scale)

        return output, scale
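A small sketch of the dynamic FP8 path in `scaled_fp8_quant`; it assumes an XPU build of intel_extension_for_pytorch and an available `xpu` device, and the shapes are illustrative:

```python
import torch

from vllm._ipex_ops import ipex_ops

x = torch.randn(16, 4096, device="xpu")  # [num_tokens, hidden], already flattened
q, scale = ipex_ops.scaled_fp8_quant(x)  # dynamic per-tensor scale
approx = q.to(torch.float32) * scale     # rough dequantized round-trip
```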
34
vllm/_version.py
Normal file
@@ -0,0 +1,34 @@
# file generated by setuptools-scm
# don't change, don't track in version control

__all__ = [
    "__version__",
    "__version_tuple__",
    "version",
    "version_tuple",
    "__commit_id__",
    "commit_id",
]

TYPE_CHECKING = False
if TYPE_CHECKING:
    from typing import Tuple
    from typing import Union

    VERSION_TUPLE = Tuple[Union[int, str], ...]
    COMMIT_ID = Union[str, None]
else:
    VERSION_TUPLE = object
    COMMIT_ID = object

version: str
__version__: str
__version_tuple__: VERSION_TUPLE
version_tuple: VERSION_TUPLE
commit_id: COMMIT_ID
__commit_id__: COMMIT_ID

__version__ = version = '0.11.0'
__version_tuple__ = version_tuple = (0, 11, 0)

__commit_id__ = commit_id = None
0
vllm/assets/__init__.py
Normal file
BIN
vllm/assets/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
vllm/assets/__pycache__/base.cpython-312.pyc
Normal file
Binary file not shown.
45
vllm/assets/audio.py
Normal file
@@ -0,0 +1,45 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import dataclass
from pathlib import Path
from typing import Literal
from urllib.parse import urljoin

import numpy.typing as npt

from vllm.utils import PlaceholderModule

from .base import VLLM_S3_BUCKET_URL, get_vllm_public_assets

try:
    import librosa
except ImportError:
    librosa = PlaceholderModule("librosa")  # type: ignore[assignment]

ASSET_DIR = "multimodal_asset"

AudioAssetName = Literal["winning_call", "mary_had_lamb"]


@dataclass(frozen=True)
class AudioAsset:
    name: AudioAssetName

    @property
    def filename(self) -> str:
        return f"{self.name}.ogg"

    @property
    def audio_and_sample_rate(self) -> tuple[npt.NDArray, float]:
        audio_path = get_vllm_public_assets(filename=self.filename,
                                            s3_prefix=ASSET_DIR)
        return librosa.load(audio_path, sr=None)

    def get_local_path(self) -> Path:
        return get_vllm_public_assets(filename=self.filename,
                                      s3_prefix=ASSET_DIR)

    @property
    def url(self) -> str:
        return urljoin(VLLM_S3_BUCKET_URL, f"{ASSET_DIR}/{self.name}.ogg")
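Typical use of the asset class above; the first access downloads the clip into the vLLM assets cache:

```python
from vllm.assets.audio import AudioAsset

asset = AudioAsset("winning_call")
audio, sr = asset.audio_and_sample_rate  # decoded with librosa at native rate
print(asset.url, sr, audio.shape)
```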
41
vllm/assets/base.py
Normal file
@@ -0,0 +1,41 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from functools import lru_cache
from pathlib import Path
from typing import Optional

import vllm.envs as envs
from vllm.connections import global_http_connection

VLLM_S3_BUCKET_URL = "https://vllm-public-assets.s3.us-west-2.amazonaws.com"


def get_cache_dir() -> Path:
    """Get the path to the cache for storing downloaded assets."""
    path = Path(envs.VLLM_ASSETS_CACHE)
    path.mkdir(parents=True, exist_ok=True)

    return path


@lru_cache
def get_vllm_public_assets(filename: str,
                           s3_prefix: Optional[str] = None) -> Path:
    """
    Download an asset file from ``s3://vllm-public-assets``
    and return the path to the downloaded file.
    """
    asset_directory = get_cache_dir() / "vllm_public_assets"
    asset_directory.mkdir(parents=True, exist_ok=True)

    asset_path = asset_directory / filename
    if not asset_path.exists():
        if s3_prefix is not None:
            filename = s3_prefix + "/" + filename
        global_http_connection.download_file(
            f"{VLLM_S3_BUCKET_URL}/{filename}",
            asset_path,
            timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT)

    return asset_path
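The helper can also be driven directly; this sketch fetches the same audio clip used above by its raw S3 key:

```python
from vllm.assets.base import get_vllm_public_assets

# Cached under $VLLM_ASSETS_CACHE/vllm_public_assets/ on first call.
path = get_vllm_public_assets("winning_call.ogg", s3_prefix="multimodal_asset")
print(path)
```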
50
vllm/assets/image.py
Normal file
@@ -0,0 +1,50 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import dataclass
from pathlib import Path
from typing import Literal

import torch
from PIL import Image

from .base import get_vllm_public_assets

VLM_IMAGES_DIR = "vision_model_images"

ImageAssetName = Literal["stop_sign", "cherry_blossom", "hato",
                         "2560px-Gfp-wisconsin-madison-the-nature-boardwalk",
                         "Grayscale_8bits_palette_sample_image",
                         "1280px-Venn_diagram_rgb", "RGBA_comp", "237-400x300",
                         "231-200x300", "27-500x500", "17-150x600",
                         "handelsblatt-preview", "paper-11"]


@dataclass(frozen=True)
class ImageAsset:
    name: ImageAssetName

    def get_path(self, ext: str) -> Path:
        """
        Return the local (downloaded) path for the given image.
        """
        return get_vllm_public_assets(filename=f"{self.name}.{ext}",
                                      s3_prefix=VLM_IMAGES_DIR)

    @property
    def pil_image(self, ext="jpg") -> Image.Image:

        image_path = self.get_path(ext)
        return Image.open(image_path)

    @property
    def image_embeds(self) -> torch.Tensor:
        """
        Image embeddings, only used for testing purposes with llava 1.5.
        """
        image_path = self.get_path('pt')
        return torch.load(image_path, map_location="cpu", weights_only=True)

    def read_bytes(self, ext: str) -> bytes:
        p = Path(self.get_path(ext))
        return p.read_bytes()
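And the image counterpart, with the same download-and-cache behavior:

```python
from vllm.assets.image import ImageAsset

asset = ImageAsset("stop_sign")
img = asset.pil_image          # PIL.Image, fetched as stop_sign.jpg
raw = asset.read_bytes("jpg")  # same file as raw bytes
```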
145
vllm/assets/video.py
Normal file
@@ -0,0 +1,145 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import dataclass
from functools import lru_cache
from typing import Any, ClassVar, Literal, Optional

import cv2
import numpy as np
import numpy.typing as npt
from huggingface_hub import hf_hub_download
from PIL import Image

from vllm.utils import PlaceholderModule

from .base import get_cache_dir

try:
    import librosa
except ImportError:
    librosa = PlaceholderModule("librosa")  # type: ignore[assignment]


@lru_cache
def download_video_asset(filename: str) -> str:
    """
    Download and open a video from the huggingface
    repo: raushan-testing-hf/videos-test
    """
    video_directory = get_cache_dir() / "video-example-data"
    video_directory.mkdir(parents=True, exist_ok=True)

    video_path = video_directory / filename
    video_path_str = str(video_path)
    if not video_path.exists():
        video_path_str = hf_hub_download(
            repo_id="raushan-testing-hf/videos-test",
            filename=filename,
            repo_type="dataset",
            cache_dir=video_directory,
        )
    return video_path_str


def video_to_ndarrays(path: str, num_frames: int = -1) -> npt.NDArray:
    cap = cv2.VideoCapture(path)
    if not cap.isOpened():
        raise ValueError(f"Could not open video file {path}")

    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    frames = []

    num_frames = num_frames if num_frames > 0 else total_frames
    frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
    for idx in range(total_frames):
        ok = cap.grab()  # advance to the next frame without decoding
        if not ok:
            break
        if idx in frame_indices:  # only decode the frames we need
            ret, frame = cap.retrieve()
            if ret:
                # OpenCV uses BGR format, we need to convert it to RGB
                # for PIL and transformers compatibility
                frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

    frames = np.stack(frames)
    if len(frames) < num_frames:
        raise ValueError(f"Could not read enough frames from video file {path}"
                         f" (expected {num_frames} frames, got {len(frames)})")
    return frames


def video_to_pil_images_list(path: str,
                             num_frames: int = -1) -> list[Image.Image]:
    frames = video_to_ndarrays(path, num_frames)
    return [Image.fromarray(frame) for frame in frames]


def video_get_metadata(path: str, num_frames: int = -1) -> dict[str, Any]:
    cap = cv2.VideoCapture(path)
    if not cap.isOpened():
        raise ValueError(f"Could not open video file {path}")

    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    duration = total_frames / fps if fps > 0 else 0

    if num_frames == -1 or num_frames > total_frames:
        num_frames = total_frames

    metadata = {
        "total_num_frames": num_frames,
        "fps": fps,
        "duration": duration,
        "video_backend": "opencv",
        "frames_indices": list(range(num_frames)),
        # extra field used to control the HF processor's video
        # sampling behavior
        "do_sample_frames": num_frames == total_frames,
    }
    return metadata


VideoAssetName = Literal["baby_reading"]


@dataclass(frozen=True)
class VideoAsset:
    name: VideoAssetName
    num_frames: int = -1

    _NAME_TO_FILE: ClassVar[dict[VideoAssetName, str]] = {
        "baby_reading": "sample_demo_1.mp4",
    }

    @property
    def filename(self) -> str:
        return self._NAME_TO_FILE[self.name]

    @property
    def video_path(self) -> str:
        return download_video_asset(self.filename)

    @property
    def pil_images(self) -> list[Image.Image]:
        ret = video_to_pil_images_list(self.video_path, self.num_frames)
        return ret

    @property
    def np_ndarrays(self) -> npt.NDArray:
        ret = video_to_ndarrays(self.video_path, self.num_frames)
        return ret

    @property
    def metadata(self) -> dict[str, Any]:
        ret = video_get_metadata(self.video_path, self.num_frames)
        return ret

    def get_audio(self, sampling_rate: Optional[float] = None) -> npt.NDArray:
        """
        Read audio data from the video asset, used in Qwen2.5-Omni examples.

        See also: examples/offline_inference/qwen2_5_omni/only_thinker.py
        """
        return librosa.load(self.video_path, sr=sampling_rate)[0]
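A sketch of uniform frame sampling through the asset class; eight frames are picked evenly across the demo clip:

```python
from vllm.assets.video import VideoAsset

asset = VideoAsset("baby_reading", num_frames=8)
frames = asset.np_ndarrays  # (8, H, W, 3) RGB uint8 array
meta = asset.metadata       # fps, duration, sampled frame indices, ...
```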
15
vllm/attention/__init__.py
Normal file
@@ -0,0 +1,15 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from vllm.attention.backends.abstract import (AttentionBackend,
                                              AttentionMetadata, AttentionType)
from vllm.attention.layer import Attention
from vllm.attention.selector import get_attn_backend

__all__ = [
    "Attention",
    "AttentionBackend",
    "AttentionMetadata",
    "AttentionType",
    "get_attn_backend",
]
0
vllm/attention/backends/__init__.py
Normal file
204
vllm/attention/backends/abstract.py
Normal file
@@ -0,0 +1,204 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Generic, List, Optional, Protocol, Tuple, Type, TypeVar
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
|
||||||
|
|
||||||
|
|
||||||
|
class AttentionType:
|
||||||
|
"""
|
||||||
|
Attention type.
|
||||||
|
Use string to be compatible with `torch.compile`.
|
||||||
|
"""
|
||||||
|
DECODER = "decoder"
|
||||||
|
"""Decoder attention between previous layer Q/K/V."""
|
||||||
|
ENCODER = "encoder"
|
||||||
|
"""Encoder attention between previous layer Q/K/V for encoder-decoder."""
|
||||||
|
ENCODER_ONLY = "encoder_only"
|
||||||
|
"""Encoder attention between previous layer Q/K/V."""
|
||||||
|
ENCODER_DECODER = "encoder_decoder"
|
||||||
|
"""Attention between dec. Q and enc. K/V for encoder-decoder."""
|
||||||
|
|
||||||
|
|
||||||
|
class AttentionBackend(ABC):
|
||||||
|
"""Abstract class for attention backends."""
|
||||||
|
# For some attention backends, we allocate an output tensor before
|
||||||
|
# calling the custom op. When piecewise cudagraph is enabled, this
|
||||||
|
# makes sure the output tensor is allocated inside the cudagraph.
|
||||||
|
accept_output_buffer: bool = False
|
||||||
|
|
||||||
|
# Whether this backend supports receiving pre-quantized query input.
|
||||||
|
# If True, the attention layer will handle query quantization instead
|
||||||
|
# of the backend, allowing torch.compile to fuse quantization with
|
||||||
|
# previous operations.
|
||||||
|
# Needs to be worked through for all backends
|
||||||
|
# https://github.com/vllm-project/vllm/issues/25584
|
||||||
|
supports_quant_query_input: bool = False
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@abstractmethod
|
||||||
|
def get_name() -> str:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@abstractmethod
|
||||||
|
def get_impl_cls() -> Type["AttentionImpl"]:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@abstractmethod
|
||||||
|
def get_metadata_cls() -> Type["AttentionMetadata"]:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata":
|
||||||
|
return cls.get_metadata_cls()(*args, **kwargs)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@abstractmethod
|
||||||
|
def get_builder_cls(): # -> Type["AttentionMetadataBuilder"]:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@abstractmethod
|
||||||
|
def get_kv_cache_shape(
|
||||||
|
num_blocks: int,
|
||||||
|
block_size: int,
|
||||||
|
num_kv_heads: int,
|
||||||
|
head_size: int,
|
||||||
|
cache_dtype_str: str = "auto",
|
||||||
|
) -> Tuple[int, ...]:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_kv_cache_stride_order() -> Tuple[int, ...]:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def full_cls_name(cls) -> tuple[str, str]:
|
||||||
|
return (cls.__module__, cls.__qualname__)
|
||||||
|
|
||||||
|
|
||||||
|
class AttentionMetadata:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
T = TypeVar("T", bound=AttentionMetadata)
|
||||||
|
|
||||||
|
|
||||||
|
class AttentionLayer(Protocol):
|
||||||
|
|
||||||
|
_q_scale: torch.Tensor
|
||||||
|
_k_scale: torch.Tensor
|
||||||
|
_v_scale: torch.Tensor
|
||||||
|
_q_scale_float: float
|
||||||
|
_k_scale_float: float
|
||||||
|
_v_scale_float: float
|
||||||
|
_prob_scale: torch.Tensor
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
query: torch.Tensor,
|
||||||
|
key: torch.Tensor,
|
||||||
|
value: torch.Tensor,
|
||||||
|
kv_cache: torch.Tensor,
|
||||||
|
attn_metadata: AttentionMetadata,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class AttentionImpl(ABC, Generic[T]):
|
||||||
|
|
||||||
|
# Whether the attention impl can return the softmax lse for decode.
|
||||||
|
# Some features like decode context parallelism require the softmax lse.
|
||||||
|
can_return_lse_for_decode: bool = False
|
||||||
|
|
||||||
|
# some attention backends might not always want to return lse
|
||||||
|
# even if they can return lse (for efficiency reasons)
|
||||||
|
need_to_return_lse_for_decode: bool = False
|
||||||
|
|
||||||
|
dcp_world_size: int
|
||||||
|
dcp_rank: int
|
||||||
|
|
||||||
|
def __new__(cls, *args, **kwargs):
|
||||||
|
# use __new__ so that all subclasses will call this
|
||||||
|
self = super().__new__(cls)
|
||||||
|
try:
|
||||||
|
from vllm.distributed.parallel_state import get_dcp_group
|
||||||
|
self.dcp_world_size = get_dcp_group().world_size
|
||||||
|
self.dcp_rank = get_dcp_group().rank_in_group
|
||||||
|
except AssertionError:
|
||||||
|
# DCP might not be initialized in testing
|
||||||
|
self.dcp_world_size = 1
|
||||||
|
self.dcp_rank = 0
|
||||||
|
self.need_to_return_lse_for_decode = self.dcp_world_size > 1 \
|
||||||
|
and self.can_return_lse_for_decode
|
||||||
|
return self
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
num_heads: int,
|
||||||
|
head_size: int,
|
||||||
|
scale: float,
|
||||||
|
num_kv_heads: Optional[int] = None,
|
||||||
|
alibi_slopes: Optional[List[float]] = None,
|
||||||
|
sliding_window: Optional[int] = None,
|
||||||
|
kv_cache_dtype: str = "auto",
|
||||||
|
logits_soft_cap: Optional[float] = None,
|
||||||
|
attn_type: str = AttentionType.DECODER,
|
||||||
|
kv_sharing_target_layer_name: Optional[str] = None,
|
||||||
|
) -> None:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
layer: AttentionLayer,
|
||||||
|
query: torch.Tensor,
|
||||||
|
key: torch.Tensor,
|
||||||
|
value: torch.Tensor,
|
||||||
|
kv_cache: torch.Tensor,
|
||||||
|
attn_metadata: T,
|
||||||
|
output: Optional[torch.Tensor] = None,
|
||||||
|
output_scale: Optional[torch.Tensor] = None,
|
||||||
|
output_block_scale: Optional[torch.Tensor] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
raise NotImplementedError
|
||||||
|

    def fused_output_quant_supported(self, quant_key: QuantKey):
        """
        Whether this attention implementation supports fused output
        quantization. This is used by the AttnFusionPass to only fuse
        output quantization onto implementations that support it.

        :param quant_key: QuantKey object that describes the quantization op
        :return: is fusion supported for this type of quantization
        """
        return False


class MLAAttentionImpl(AttentionImpl[T], Generic[T]):

    @abstractmethod
    def forward(
        self,
        layer: AttentionLayer,
        hidden_states_or_cq: torch.Tensor,
        kv_c_normed: torch.Tensor,
        k_pe: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: T,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
        output_block_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        raise NotImplementedError


def is_quantized_kv_cache(kv_cache_dtype: str) -> bool:
    return kv_cache_dtype != "auto"
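For orientation (not part of this diff): a minimal sketch of a backend satisfying the AttentionImpl contract above, using plain PyTorch SDPA. The class name and body are illustrative assumptions; a real backend must also handle kv_cache and attn_metadata.

# Illustrative sketch only -- NOT part of the diff.
import torch
import torch.nn.functional as F


class NaiveSDPAImpl(AttentionImpl[AttentionMetadata]):

    def __init__(self, num_heads, head_size, scale, num_kv_heads=None,
                 alibi_slopes=None, sliding_window=None,
                 kv_cache_dtype="auto", logits_soft_cap=None,
                 attn_type=AttentionType.DECODER,
                 kv_sharing_target_layer_name=None) -> None:
        self.scale = scale

    def forward(self, layer, query, key, value, kv_cache, attn_metadata,
                output=None, output_scale=None,
                output_block_scale=None) -> torch.Tensor:
        # Inputs are (num_tokens, num_heads, head_size); SDPA expects
        # (batch, heads, seq, head_size), so add a batch dim and transpose.
        q, k, v = (x.unsqueeze(0).transpose(1, 2)
                   for x in (query, key, value))
        out = F.scaled_dot_product_attention(q, k, v, scale=self.scale)
        return out.transpose(1, 2).squeeze(0)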
33
vllm/attention/backends/utils.py
Normal file
@@ -0,0 +1,33 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention backend utils"""
from dataclasses import dataclass
from typing import Optional

from vllm.config import ModelConfig
from vllm.logger import init_logger

logger = init_logger(__name__)

PAD_SLOT_ID = -1


@dataclass
class MLADims:
    q_lora_rank: Optional[int]
    kv_lora_rank: int
    qk_nope_head_dim: int
    qk_rope_head_dim: int
    v_head_dim: int


def get_mla_dims(model_config: ModelConfig) -> MLADims:
    hf_text_config = model_config.hf_text_config

    return MLADims(
        q_lora_rank=getattr(hf_text_config, "q_lora_rank", None),
        kv_lora_rank=hf_text_config.kv_lora_rank,
        qk_nope_head_dim=hf_text_config.qk_nope_head_dim,
        qk_rope_head_dim=hf_text_config.qk_rope_head_dim,
        v_head_dim=hf_text_config.v_head_dim,
    )
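A worked example of how these dimensions compose (the numbers are DeepSeek-V2-style assumptions, not values taken from this diff):

# Illustration only -- NOT part of the diff.
dims = MLADims(q_lora_rank=1536, kv_lora_rank=512,
               qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128)

# Each query/key head concatenates a non-RoPE and a RoPE component:
qk_head_dim = dims.qk_nope_head_dim + dims.qk_rope_head_dim   # 192
# The compressed KV cache stores one latent vector plus the shared RoPE
# key per token, independent of the number of heads:
kv_entry_dim = dims.kv_lora_rank + dims.qk_rope_head_dim      # 576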
645
vllm/attention/layer.py
Normal file
@@ -0,0 +1,645 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer."""
from typing import List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

import vllm.envs as envs
from vllm.attention import AttentionType
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.selector import backend_name_to_enum, get_attn_backend
from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
from vllm.config import CacheConfig, get_current_vllm_config
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                          has_kv_transfer_group,
                                          is_v1_kv_transfer_group)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    GroupShape)
from vllm.model_executor.models.vision import get_vit_attn_backend
from vllm.platforms import _Backend, current_platform
from vllm.utils import GiB_bytes, direct_register_custom_op

logger = init_logger(__name__)
USE_XFORMERS_OPS = None
try:
    tag_cudagraph_unsafe = (torch._C.Tag.cudagraph_unsafe, )
except AttributeError:
    tag_cudagraph_unsafe = ()  # type: ignore[assignment]


def check_xformers_availability():
    global USE_XFORMERS_OPS
    if USE_XFORMERS_OPS is not None:
        return USE_XFORMERS_OPS

    if current_platform.is_cuda() and current_platform.has_device_capability(
            100):
        # Xformers FA is not compatible with B200
        USE_XFORMERS_OPS = False
    else:
        try:
            from importlib.util import find_spec

            find_spec("xformers.ops")
            USE_XFORMERS_OPS = True
        except ImportError:
            USE_XFORMERS_OPS = False

    # the warning only needs to be shown once
    if not USE_XFORMERS_OPS:
        logger.warning("Xformers is not available, falling back.")

    return USE_XFORMERS_OPS


def check_upstream_fa_availability(dtype: torch.dtype):
    if dtype in (torch.float16, torch.bfloat16) and current_platform.is_cuda(
    ) and current_platform.has_device_capability(80):
        from transformers.utils import is_flash_attn_2_available
        return is_flash_attn_2_available()
    return False


class Attention(nn.Module, AttentionLayerBase):
    """Attention layer.

    This class takes query, key, and value tensors as input. The input tensors
    can either contain prompt tokens or generation tokens.
    The class does the following:

    1. Store the input key and value tensors in the KV cache.
    2. Perform (multi-head/multi-query/grouped-query) attention.
    3. Return the output tensor.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        logits_soft_cap: Optional[float] = None,
        per_layer_sliding_window: Optional[int] = None,
        use_mla: bool = False,
        use_sparse: bool = False,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
        attn_backend: Optional[type[AttentionBackend]] = None,
        **extra_impl_args,
    ) -> None:
        """
        The KV cache is stored inside this class and is accessed via
        `self.kv_cache`.
        """
        super().__init__()
        if per_layer_sliding_window is not None:
            # per-layer sliding window
            sliding_window = per_layer_sliding_window
        elif cache_config is not None:
            # model-level sliding window
            sliding_window = cache_config.sliding_window
        else:
            sliding_window = None

        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size
            calculate_kv_scales = cache_config.calculate_kv_scales
        else:
            kv_cache_dtype = "auto"
            block_size = 16
            calculate_kv_scales = False
        if num_kv_heads is None:
            num_kv_heads = num_heads
        assert num_heads % num_kv_heads == 0, \
            f"num_heads ({num_heads}) is not " \
            f"divisible by num_kv_heads ({num_kv_heads})"

        # The default k/v_scale is set to 1.0. This is ignored
        # when kv-cache is not fp8, and should be used with
        # kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we
        # expect the pre-quantized k/v_scale to be loaded along
        # with the model weights.
        self.kv_cache_dtype = kv_cache_dtype
        self.calculate_kv_scales = calculate_kv_scales
        self._k_scale = torch.tensor(1.0, dtype=torch.float32)
        self._v_scale = torch.tensor(1.0, dtype=torch.float32)
        # FlashAttn doesn't support quantizing the kv-cache only
        # but requires q to be quantized as well.
        self._q_scale = torch.tensor(1.0, dtype=torch.float32)
        self._prob_scale = torch.tensor(1.0, dtype=torch.float32)

        # We also keep q/k/v_scale on host (cpu) memory for attention
        # backends that require the scales to be on host instead of on device.
        # e.g. Flashinfer
        self._q_scale_float = 1.0
        self._k_scale_float = 1.0
        self._v_scale_float = 1.0

        # The output scale on host memory. This should be the input scale of
        # the quant op after this attention layer.
        self._o_scale_float: Optional[float] = None

        self.use_mla = use_mla
        self.use_sparse = use_sparse
        self.num_heads = num_heads
        self.head_size = head_size
        self.num_kv_heads = num_kv_heads
        self.sliding_window = sliding_window
        self.has_sink = extra_impl_args.get("sinks") is not None

        quant_method = quant_config.get_quant_method(
            self, prefix=prefix) if quant_config else None
        if quant_method is not None and not isinstance(
                quant_method, UnquantizedLinearMethod):
            assert isinstance(quant_method, BaseKVCacheMethod)
            # TODO (mgoin): kv cache dtype should be specified in the FP8
            # checkpoint config and become the "auto" behavior
            if self.kv_cache_dtype == "fp8_e5m2":
                raise ValueError("fp8_e5m2 kv-cache is not supported with "
                                 "fp8 checkpoints.")
            # If quantization is enabled, we make "k_scale" and "v_scale"
            # parameters so that it can be loaded from the model checkpoint.
            # The k/v_scale will then be converted back to native float32
            # values after weight loading.
            self.quant_method = quant_method
            self.quant_method.create_weights(self)

        # During model initialization, the default dtype is set as the model
        # weight and activation dtype.
        dtype = torch.get_default_dtype()
        if attn_backend is None:
            self.attn_backend = get_attn_backend(head_size,
                                                 dtype,
                                                 kv_cache_dtype,
                                                 block_size,
                                                 use_mla=use_mla,
                                                 has_sink=self.has_sink,
                                                 use_sparse=use_sparse)
        else:
            self.attn_backend = attn_backend

        impl_cls = self.attn_backend.get_impl_cls()
        self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
                             alibi_slopes, sliding_window, kv_cache_dtype,
                             logits_soft_cap, attn_type,
                             kv_sharing_target_layer_name, **extra_impl_args)
        self.backend = backend_name_to_enum(self.attn_backend.get_name())
        self.dtype = dtype

        # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
        # torch.compile works by registering the attention as one giant
        # opaque custom op. For other platforms, we directly call them
        # and let torch.compile handle them.
        self.use_direct_call = not current_platform.opaque_attention_op()

        self.use_output = self.attn_backend.accept_output_buffer
        compilation_config = get_current_vllm_config().compilation_config
        if prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        compilation_config.static_forward_context[prefix] = self
        self.layer_name = prefix
        self.attn_type = attn_type

        if kv_sharing_target_layer_name is not None:
            validate_kv_sharing_target(
                prefix,
                kv_sharing_target_layer_name,
                compilation_config.static_forward_context,
            )
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

        # use a placeholder kv cache tensor during init, which will be replaced
        # by bind_kv_cache
        # this variable will not be accessed if use_direct_call is True
        self.kv_cache = [
            torch.tensor([]) for _ in range(get_current_vllm_config(
            ).parallel_config.pipeline_parallel_size)
        ]

        try:
            self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT,
                                        dtype=torch.float32)
            self.k_range = torch.tensor(envs.K_SCALE_CONSTANT,
                                        dtype=torch.float32)
            self.v_range = torch.tensor(envs.V_SCALE_CONSTANT,
                                        dtype=torch.float32)
        except torch.cuda.OutOfMemoryError as e:
            logger.error(
                "Failed to initialize attention q/k/v range constants: %s", e)
            if torch.cuda.is_available():
                logger.debug("CUDA device: %s", torch.cuda.current_device())
                logger.debug("Allocated: %.2f GiB",
                             torch.cuda.memory_allocated() / GiB_bytes)
                logger.debug("Reserved: %.2f GiB",
                             torch.cuda.memory_reserved() / GiB_bytes)
            raise RuntimeError(
                "Failed to initialize q/k/v range constants. "
                "This may be caused by insufficient memory to allocate "
                "kv cache.") from e

        # for attn backends supporting query quantization
        self.query_quant = None
        if self.kv_cache_dtype.startswith(
                "fp8") and self.attn_backend.supports_quant_query_input:
            self.query_quant = QuantFP8(static=True,
                                        group_shape=GroupShape.PER_TENSOR)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        # For some alternate attention backends like MLA the attention output
        # shape does not match the query shape, so we optionally let the model
        # definition specify the output tensor shape.
        output_shape: Optional[torch.Size] = None,
    ) -> torch.Tensor:
        """
        The KV cache is stored inside this class and is accessed via
        `self.kv_cache`.

        Attention metadata (`attn_metadata`) is set using a context manager in
        the model runner's `execute_model` method. It is accessed via forward
        context using
        `vllm.forward_context.get_forward_context().attn_metadata`.
        """
        if self.calculate_kv_scales:
            attn_metadata = get_forward_context().attn_metadata
            if attn_metadata.enable_kv_scales_calculation:
                self.calc_kv_scales(query, key, value)

        output_dtype = query.dtype
        if self.query_quant is not None:
            # quantizing with a simple torch operation enables
            # torch.compile to fuse this into previous ops
            # which reduces overheads during decoding.
            # Otherwise queries are quantized using custom ops
            # which causes decoding overheads
            assert self.kv_cache_dtype in {"fp8", "fp8_e4m3"}
            query, _ = self.query_quant(query, self._q_scale)

        if self.use_output:
            output_shape = (output_shape
                            if output_shape is not None else query.shape)
            output = torch.zeros(output_shape,
                                 dtype=output_dtype,
                                 device=query.device)
            hidden_size = output_shape[-1]
            # We skip reshaping query, key and value tensors for the MLA
            # backend since these tensors have different semantics and are
            # processed differently.
            if not self.use_mla:
                # Reshape the query, key, and value tensors.
                # NOTE(woosuk): We do this outside the custom op to minimize the
                # CPU overheads from the non-CUDA-graph regions.
                query = query.view(-1, self.num_heads, self.head_size)
                output = output.view(-1, self.num_heads, self.head_size)
                if key is not None:
                    key = key.view(-1, self.num_kv_heads, self.head_size)
                if value is not None:
                    value = value.view(-1, self.num_kv_heads, self.head_size)
            if self.use_direct_call:
                forward_context: ForwardContext = get_forward_context()
                attn_metadata = forward_context.attn_metadata
                if isinstance(attn_metadata, dict):
                    attn_metadata = attn_metadata[self.layer_name]
                self_kv_cache = self.kv_cache[forward_context.virtual_engine]
                self.impl.forward(self,
                                  query,
                                  key,
                                  value,
                                  self_kv_cache,
                                  attn_metadata,
                                  output=output)
            else:
                torch.ops.vllm.unified_attention_with_output(
                    query, key, value, output, self.layer_name)
            return output.view(-1, hidden_size)
        else:
            if self.use_direct_call:
                forward_context = get_forward_context()
                attn_metadata = forward_context.attn_metadata
                if isinstance(attn_metadata, dict):
                    attn_metadata = attn_metadata[self.layer_name]
                self_kv_cache = self.kv_cache[forward_context.virtual_engine]
                return self.impl.forward(self, query, key, value,
                                         self_kv_cache, attn_metadata)
            else:
                return torch.ops.vllm.unified_attention(
                    query, key, value, self.layer_name)

    def calc_kv_scales(self, query, key, value):
        self._q_scale.copy_(torch.abs(query).max() / self.q_range)
        self._k_scale.copy_(torch.abs(key).max() / self.k_range)
        self._v_scale.copy_(torch.abs(value).max() / self.v_range)
        self._q_scale_float = self._q_scale.item()
        self._k_scale_float = self._k_scale.item()
        self._v_scale_float = self._v_scale.item()
        # We only calculate the scales once
        self.calculate_kv_scales = False

    def extra_repr(self) -> str:
        s = f"head_size={self.impl.head_size}"  # type: ignore
        s += f", num_heads={self.impl.num_heads}"  # type: ignore
        s += f", num_kv_heads={self.impl.num_kv_heads}"  # type: ignore
        s += f", scale={self.impl.scale}"  # type: ignore
        s += f", backend={self.impl.__class__.__name__}"
        return s

    def process_weights_after_loading(self, act_dtype: torch.dtype):
        if hasattr(self.impl, "process_weights_after_loading"):
            self.impl.process_weights_after_loading(act_dtype)

        # FlashInfer requires attention sinks to be float32
        if (self.backend == _Backend.FLASHINFER
                and hasattr(self.impl, 'sinks')):
            from vllm.v1.attention.backends.flashinfer import FlashInferImpl
            assert isinstance(self.impl, FlashInferImpl)
            if (self.impl.sinks is not None
                    and self.impl.sinks.dtype != torch.float32):
                self.impl.sinks = self.impl.sinks.to(torch.float32)

    def get_attn_backend(self) -> type[AttentionBackend]:
        return self.attn_backend


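To make the dynamic-scale path above concrete, here is a standalone sketch of the arithmetic in calc_kv_scales; Q_RANGE is an assumed stand-in for envs.Q_SCALE_CONSTANT:

# Standalone sketch -- NOT part of the diff. scale = amax(|x|) / range,
# so x / scale has magnitude at most `range`, keeping quantized values
# inside the representable fp8 window.
import torch

Q_RANGE = 200.0                     # assumed stand-in for Q_SCALE_CONSTANT
query = torch.randn(16, 32, 64)     # (num_tokens, num_heads, head_size)

q_scale = query.abs().max() / Q_RANGE
q_fp8 = (query / q_scale).to(torch.float8_e4m3fn)   # quantize
q_dequant = q_fp8.to(torch.float32) * q_scale       # approximate round-trip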
class MultiHeadAttention(nn.Module):
    """Multi-headed attention without any cache, used for ViT."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = scale
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads

        assert self.num_heads % self.num_kv_heads == 0, \
            f"num_heads ({self.num_heads}) is not " \
            f"divisible by num_kv_heads ({self.num_kv_heads})"
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        # During model initialization, the default dtype is set as the model
        # weight and activation dtype.
        dtype = torch.get_default_dtype()

        # Determine the attention backend
        backend = get_vit_attn_backend(head_size=head_size, dtype=dtype)

        # Some auto-selected backends can be upgraded
        # to upstream flash attention if available.
        # If vllm native fa is selected, we use it directly.
        use_upstream_fa = False
        if backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
                dtype):
            backend = _Backend.FLASH_ATTN
            use_upstream_fa = True

        if current_platform.is_rocm() or current_platform.is_xpu():
            # currently, only torch_sdpa is supported on rocm/xpu
            self.attn_backend = _Backend.TORCH_SDPA
        else:
            self.attn_backend = backend if backend in {
                _Backend.TORCH_SDPA,
                _Backend.XFORMERS,
                _Backend.PALLAS,
                _Backend.ROCM_AITER_FA,
                _Backend.FLASH_ATTN,
            } else _Backend.TORCH_SDPA

            if (self.attn_backend == _Backend.XFORMERS
                    and not check_xformers_availability()):
                self.attn_backend = _Backend.TORCH_SDPA

        if self.attn_backend == _Backend.FLASH_ATTN:
            if use_upstream_fa:
                from flash_attn import flash_attn_varlen_func
                self._flash_attn_varlen_func = flash_attn_varlen_func
            else:
                from vllm.vllm_flash_attn import flash_attn_varlen_func
                self._flash_attn_varlen_func = flash_attn_varlen_func

        logger.info_once(
            f"MultiHeadAttention attn_backend: {self.attn_backend}, "
            f"use_upstream_fa: {use_upstream_fa}")

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
    ) -> torch.Tensor:
        """Input shape:
        (batch_size x seq_len x hidden_size) or
        (batch_size x seq_len x num_heads x head_size)
        """
        bsz, q_len = query.size()[:2]
        kv_len = key.size(1)

        query = query.view(bsz, q_len, self.num_heads, self.head_size)
        key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size)
        value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size)

        if (num_repeat := self.num_queries_per_kv) > 1:
            # Handle MQA and GQA
            key = torch.repeat_interleave(key, num_repeat, dim=2)
            value = torch.repeat_interleave(value, num_repeat, dim=2)

        if self.attn_backend == _Backend.FLASH_ATTN:
            cu_seqlens_q = torch.arange(0, (bsz + 1) * q_len,
                                        step=q_len,
                                        dtype=torch.int32,
                                        device=query.device)
            cu_seqlens_k = torch.arange(0, (bsz + 1) * kv_len,
                                        step=kv_len,
                                        dtype=torch.int32,
                                        device=key.device)

            out = self._flash_attn_varlen_func(
                query.flatten(0, 1),
                key.flatten(0, 1),
                value.flatten(0, 1),
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=q_len,
                max_seqlen_k=kv_len,
                softmax_scale=self.scale,
            )
        elif self.attn_backend == _Backend.XFORMERS:
            from xformers import ops as xops

            out = xops.memory_efficient_attention_forward(query,
                                                          key,
                                                          value,
                                                          scale=self.scale)
        elif self.attn_backend == _Backend.TORCH_SDPA:
            query, key, value = (x.transpose(1, 2)
                                 for x in (query, key, value))
            out = F.scaled_dot_product_attention(query,
                                                 key,
                                                 value,
                                                 scale=self.scale)
            out = out.transpose(1, 2)
        elif self.attn_backend == _Backend.PALLAS:
            query, key, value = (x.transpose(1, 2)
                                 for x in (query, key, value))
            from torch_xla.experimental.custom_kernel import flash_attention
            out = flash_attention(query, key, value, sm_scale=self.scale)
            out = out.transpose(1, 2)
        elif self.attn_backend == _Backend.ROCM_AITER_FA:
            from aiter import flash_attn_varlen_func

            # ROCm Flash Attention expects (batch, seq, heads, head_dim)
            out = flash_attn_varlen_func(query,
                                         key,
                                         value,
                                         softmax_scale=self.scale)
        else:
            # This backend is not yet supported by ViT attention.
            raise NotImplementedError(
                f"ViT attention does not support the "
                f"{self.attn_backend} backend yet.")

        return out.reshape(bsz, q_len, -1)


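A worked example (not part of the diff) of the MQA/GQA expansion used in MultiHeadAttention.forward above: with 32 query heads and 8 KV heads, each KV head serves 32 // 8 = 4 query heads.

import torch

bsz, seq, head_size = 2, 5, 64
num_heads, num_kv_heads = 32, 8
num_queries_per_kv = num_heads // num_kv_heads   # 4

key = torch.randn(bsz, seq, num_kv_heads, head_size)
key = torch.repeat_interleave(key, num_queries_per_kv, dim=2)
assert key.shape == (bsz, seq, num_heads, head_size)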
def wait_for_kv_layer_from_connector(layer_name: str):
    if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
        return

    connector = get_kv_transfer_group()

    forward_context: ForwardContext = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    if attn_metadata is None:
        return
    assert isinstance(attn_metadata, dict)
    connector.wait_for_layer_load(layer_name)


def maybe_save_kv_layer_to_connector(
    layer_name: str,
    kv_cache_layer: List[torch.Tensor],
):
    if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
        return

    connector = get_kv_transfer_group()

    forward_context: ForwardContext = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    if attn_metadata is None:
        return
    assert isinstance(attn_metadata, dict)
    connector.save_kv_layer(layer_name, kv_cache_layer,
                            attn_metadata[layer_name])


def unified_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    wait_for_kv_layer_from_connector(layer_name)

    forward_context: ForwardContext = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    if isinstance(attn_metadata, dict):
        attn_metadata = attn_metadata[layer_name]
    self = forward_context.no_compile_layers[layer_name]
    kv_cache = self.kv_cache[forward_context.virtual_engine]
    output = self.impl.forward(self, query, key, value, kv_cache,
                               attn_metadata)

    maybe_save_kv_layer_to_connector(layer_name, kv_cache)
    return output


def unified_attention_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    return torch.empty_like(query).contiguous()


direct_register_custom_op(
    op_name="unified_attention",
    op_func=unified_attention,
    fake_impl=unified_attention_fake,
    tags=tag_cudagraph_unsafe,
)


def unified_attention_with_output(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
    output_scale: Optional[torch.Tensor] = None,
    output_block_scale: Optional[torch.Tensor] = None,
) -> None:
    wait_for_kv_layer_from_connector(layer_name)
    forward_context: ForwardContext = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    if isinstance(attn_metadata, dict):
        attn_metadata = attn_metadata[layer_name]
    self = forward_context.no_compile_layers[layer_name]
    kv_cache = self.kv_cache[forward_context.virtual_engine]
    self.impl.forward(self,
                      query,
                      key,
                      value,
                      kv_cache,
                      attn_metadata,
                      output=output,
                      output_scale=output_scale,
                      output_block_scale=output_block_scale)

    maybe_save_kv_layer_to_connector(layer_name, kv_cache)


def unified_attention_with_output_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
    output_scale: Optional[torch.Tensor] = None,
    output_block_scale: Optional[torch.Tensor] = None,
) -> None:
    return


direct_register_custom_op(
    op_name="unified_attention_with_output",
    op_func=unified_attention_with_output,
    mutates_args=["output", "output_block_scale"],
    fake_impl=unified_attention_with_output_fake,
    tags=tag_cudagraph_unsafe,
)
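A sketch (not part of the diff) of how the op registered above is invoked: Attention.forward dispatches to it when use_direct_call is False, and the fake impl only describes shapes and aliasing so torch.compile can trace the call. Running this requires an active vLLM forward context, and the layer name below is hypothetical.

import torch

query = torch.randn(10, 32, 64)   # (num_tokens, num_heads, head_size)
key = torch.randn(10, 8, 64)      # (num_tokens, num_kv_heads, head_size)
value = torch.randn(10, 8, 64)
output = torch.zeros_like(query)  # written in place by the op

torch.ops.vllm.unified_attention_with_output(
    query, key, value, output, "model.layers.0.self_attn.attn")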