2026-04-02 04:53:13 +00:00
parent 80932c96e5
commit 24df76db9d
1987 changed files with 447445 additions and 0 deletions

torch_vacc/__init__.py Normal file

@@ -0,0 +1,124 @@
import atexit
import ctypes
import os
import sys
import types
import torch
import torch.distributed
from .version import __version__
def register_runtime_libraries() -> None:
try:
libpython_so = f"libpython{sys.version_info.major}.{sys.version_info.minor}.so"
base_prefix = getattr(sys, "base_prefix", sys.prefix)
if not base_prefix.startswith("/usr"): # like conda or virtualenv
ctypes.CDLL(os.path.join(base_prefix, "lib", libpython_so))
this_path = os.path.dirname(os.path.realpath(__file__))
rt_dll_dpath = os.path.join(this_path, "_vacc_libs")
ctypes.CDLL(os.path.join(rt_dll_dpath, "libodsp.so"))
ctypes.CDLL(os.path.join(rt_dll_dpath, "libvaccrt.so"))
ctypes.CDLL(os.path.join(rt_dll_dpath, "libvnnl.so"))
ctypes.CDLL(os.path.join(rt_dll_dpath, "libvccl.so"))
ctypes.CDLL(os.path.join(rt_dll_dpath, "libvacc_core.so"))
except Exception as e:
raise RuntimeError("Vastai runtime library not loaded.") from e
register_runtime_libraries()
from ._vacc_libs import _torch_vacc as _C
try:
_C._init_torch_vacc_module()
except Exception as e:
raise RuntimeError("Failed to init torch_vacc.") from e
def _apply_patches(monkey_patches):
def _getattr(module_list, root_module=torch):
if len(module_list) <= 1:
return root_module
if hasattr(root_module, module_list[0]):
return _getattr(module_list[1:], getattr(root_module, module_list[0]))
else:
empty_module_name = f"{root_module.__name__}.{module_list[0]}"
sys.modules[empty_module_name] = types.ModuleType(empty_module_name)
setattr(root_module, module_list[0], sys.modules.get(empty_module_name))
return _getattr(module_list[1:], getattr(root_module, module_list[0]))
for patch_pair in monkey_patches:
dest, patch = patch_pair
dest_module = _getattr(dest.split("."), root_module=torch)
last_module_level = dest.split(".")[-1]
if not isinstance(patch, types.ModuleType):
setattr(dest_module, last_module_level, patch)
continue
if not hasattr(dest_module, last_module_level) or not hasattr(patch, "__all__"):
setattr(dest_module, last_module_level, patch)
sys.modules[f"{dest_module.__name__}.{last_module_level}"] = patch
continue
assert hasattr(patch, "__all__"), "Patch module must have __all__ definition."
dest_module = getattr(dest_module, last_module_level)
for attr in patch.__all__:
setattr(dest_module, attr, getattr(patch, attr))
import torch_vacc.vacc as vacc
# register "vacc" module/functions to torch
torch._register_device_module("vacc", vacc)
unsupported_dtype = [
torch.quint8,
torch.quint4x2,
torch.quint2x4,
torch.qint32,
torch.qint8,
]
torch.utils.generate_methods_for_privateuse1_backend(
for_tensor=True,
for_module=True,
for_storage=True, # TODO(qingsong): do we support storage?
unsupported_dtype=unsupported_dtype,
)
# register legacy *DtypeTensor into torch.vacc
_C._initialize_python_bindings()
# init seed generators, vacc default generator
vacc.init()
def is_vccl_available() -> bool:
return True
torch.distributed.is_vccl_available = is_vccl_available
def set_global_log_level(log_level):
_C.set_global_log_level(log_level.upper())
def print_vacc_ops():
_C._print_vacc_ops()
def vacc_ops_list():
return _C._vacc_ops_list().split(",")
def print_vacc_selective_ops():
_C._print_vacc_selective_ops()
def _vacc_shutdown():
_C._vacc_module_shutdown()
atexit.register(_vacc_shutdown)
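
A minimal usage sketch for the package entry point above. It assumes a host with at least one VACC device and that the runtime libraries load successfully; importing torch_vacc is all that is needed for the registrations to take effect.

import torch
import torch_vacc  # loads the VACC runtime libraries and registers the "vacc" backend

if torch.vacc.is_available():
    x = torch.randn(4, 4, device="vacc")   # tensor on the default VACC device
    y = (x + 1.0).cpu()                     # copy the result back to host memory
    print(torch.vacc.device_count(), y.shape)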

BIN
torch_vacc/_vacc_libs/libodsp.so Executable file

Binary file not shown.

BIN
torch_vacc/_vacc_libs/libvccl.so Executable file

Binary file not shown.

BIN
torch_vacc/_vacc_libs/libvnnl.so Executable file

Binary file not shown.

@@ -0,0 +1,398 @@
import os
import warnings
import logging as logger
from functools import wraps
import torch
import torch_vacc
'''
try:
import torchair
except ImportError:
IS_TORCHAIR_INSTALLED = False
else:
IS_TORCHAIR_INSTALLED = True
'''
warnings.filterwarnings(action="once")
torch_fn_white_list = [
"_cudnn_init_dropout_state",
"_empty_affine_quantized",
"_empty_per_channel_affine_quantized",
"_pin_memory",
"_sparse_coo_tensor_unsafe",
"_sparse_csr_tensor_unsafe",
"logspace",
"randint",
"hann_window",
"rand",
"full_like",
"ones_like",
"rand_like",
"randperm",
"arange",
"frombuffer",
"normal",
"empty_strided",
"empty_like",
"scalar_tensor",
"tril_indices",
"bartlett_window",
"ones",
"sparse_coo_tensor",
"randn",
"kaiser_window",
"tensor",
"triu_indices",
"as_tensor",
"zeros",
"randint_like",
"full",
"eye",
"empty",
"blackman_window",
"zeros_like",
"range",
"sparse_csr_tensor",
"randn_like",
"from_file",
"linspace",
"hamming_window",
"empty_quantized",
"autocast",
"load",
]
torch_tensor_fn_white_list = [
"new_empty",
"new_empty_strided",
"new_full",
"new_ones",
"new_tensor",
"new_zeros",
"to",
]
torch_module_fn_white_list = ["to", "to_empty"]
torch_cuda_fn_white_list = [
"get_device_properties",
"get_device_name",
"get_device_capability",
"list_gpu_processes",
"set_device",
"synchronize",
"mem_get_info",
"memory_stats",
"memory_summary",
"memory_allocated",
"max_memory_allocated",
"reset_max_memory_allocated",
"memory_reserved",
"max_memory_reserved",
"reset_max_memory_cached",
"reset_peak_memory_stats",
"current_stream",
"default_stream",
]
torch_profiler_fn_white_list = ["profile"]
torch_distributed_fn_white_list = ["__init__"]
device_kwargs_list = ["device", "device_type", "map_location"]
def wrapper_cuda(fn):
@wraps(fn)
def decorated(*args, **kwargs):
replace_int = fn.__name__ in ["to", "to_empty"]
if args:
args_new = list(args)
args = replace_cuda_to_vacc_in_list(args_new, replace_int)
if kwargs:
for device_arg in device_kwargs_list:
device = kwargs.get(device_arg, None)
if device is not None:
replace_cuda_to_vacc_in_kwargs(kwargs, device_arg, device)
device_ids = kwargs.get("device_ids", None)
if type(device_ids) == list:
device_ids = replace_cuda_to_vacc_in_list(device_ids, replace_int)
return fn(*args, **kwargs)
return decorated
def replace_cuda_to_vacc_in_kwargs(kwargs, device_arg, device):
if type(device) == str and "cuda" in device:
kwargs[device_arg] = device.replace("cuda", "vacc")
elif type(device) == torch.device and "cuda" in device.type:
device_info = (
"vacc:{}".format(device.index) if device.index is not None else "vacc"
)
kwargs[device_arg] = torch.device(device_info)
elif type(device) == int:
kwargs[device_arg] = f"vacc:{device}"
elif type(device) == dict:
kwargs[device_arg] = replace_cuda_to_vacc_in_dict(device)
def replace_cuda_to_vacc_in_list(args_list, replace_int):
for idx, arg in enumerate(args_list):
if isinstance(arg, str) and "cuda" in arg:
args_list[idx] = arg.replace("cuda", "vacc")
elif isinstance(arg, torch.device) and "cuda" in arg.type:
device_info = (
"vacc:{}".format(arg.index) if arg.index is not None else "vacc"
)
args_list[idx] = torch.device(device_info)
elif replace_int and not isinstance(arg, bool) and isinstance(arg, int):
args_list[idx] = f"vacc:{arg}"
elif isinstance(arg, dict):
args_list[idx] = replace_cuda_to_vacc_in_dict(arg)
return args_list
def replace_cuda_to_vacc_in_dict(device_dict):
new_dict = {}
for key, value in device_dict.items():
if isinstance(key, str):
key = key.replace("cuda", "vacc")
if isinstance(value, str):
value = value.replace("cuda", "vacc")
new_dict[key] = value
return new_dict
def device_wrapper(enter_fn, white_list):
for fn_name in white_list:
fn = getattr(enter_fn, fn_name, None)
if fn:
setattr(enter_fn, fn_name, wrapper_cuda(fn))
def wrapper_vccl(fn):
@wraps(fn)
def decorated(*args, **kwargs):
if args:
args_new = list(args)
for idx, arg in enumerate(args_new):
if type(arg) == str and "nccl" in arg:
args_new[idx] = arg.replace("nccl", "vccl")
args = args_new
if kwargs:
if type(kwargs.get("backend", None)) == str:
kwargs["backend"] = "vccl"
return fn(*args, **kwargs)
return decorated
def wrapper_data_loader(fn):
@wraps(fn)
def decorated(*args, **kwargs):
if kwargs:
pin_memory = kwargs.get("pin_memory", False)
pin_memory_device = kwargs.get("pin_memory_device", None)
if pin_memory and not pin_memory_device:
kwargs["pin_memory_device"] = "vacc"
if (
pin_memory
and type(pin_memory_device) == str
and "cuda" in pin_memory_device
):
kwargs["pin_memory_device"] = pin_memory_device.replace("cuda", "vacc")
return fn(*args, **kwargs)
return decorated
def wrapper_get_available_device_type(fn):
@wraps(fn)
def decorated(*args, **kwargs):
try:
if (torch.vacc.is_available()):
return 'vacc'
except Exception as e:
msg = "vacc device is not available."
warnings.warn(msg, RuntimeWarning)
return fn(*args, **kwargs)
return decorated
'''
def wrapper_profiler(fn):
@wraps(fn)
def decorated(*args, **kwargs):
if kwargs:
if (
"experimental_config" in kwargs.keys()
and type(kwargs.get("experimental_config"))
!= torch_vacc.profiler._ExperimentalConfig
):
logger.warning(
"The parameter experimental_config of torch.profiler.profile has been deleted by the tool "
"because it can only be used in cuda, please manually modify the code "
"and use the experimental_config parameter adapted to vacc."
)
del kwargs["experimental_config"]
return fn(*args, **kwargs)
return decorated
def wrapper_compile(fn):
@wraps(fn)
def decorated(*args, **kwargs):
vacc_backend = torchair.get_vacc_backend()
if kwargs:
backend = kwargs.get("backend", None)
if (
not backend
or not isinstance(backend, functools.partial)
or not isinstance(backend.func, type(vacc_backend.func))
):
kwargs["backend"] = vacc_backend
else:
kwargs["backend"] = vacc_backend
return fn(*args, **kwargs)
return decorated
'''
def jit_script(obj, optimize=None, _frames_up=0, _rcb=None, example_inputs=None):
msg = "torch.jit.script will be disabled by transfer_to_vacc, which currently does not support it."
warnings.warn(msg, RuntimeWarning)
return obj
def patch_cuda():
patchs = [
["cuda", torch_vacc.vacc],
["cuda.amp", torch_vacc.vacc.amp],
["cuda.random", torch_vacc.vacc.random],
["cuda.amp.autocast_mode", torch_vacc.vacc.amp.autocast_mode],
["cuda.amp.common", torch_vacc.vacc.amp.common],
["cuda.amp.grad_scaler", torch_vacc.vacc.amp.grad_scaler],
]
torch_vacc._apply_patches(patchs)
'''
def patch_profiler():
patchs = [
["profiler.profile", torch_vacc.profiler.profile],
["profiler.schedule", torch_vacc.profiler.schedule],
[
"profiler.tensorboard_trace_handler",
torch_vacc.profiler.tensorboard_trace_handler,
],
["profiler.ProfilerAction", torch_vacc.profiler.ProfilerAction],
["profiler.ProfilerActivity.CUDA", torch_vacc.profiler.ProfilerActivity.VACC],
["profiler.ProfilerActivity.CPU", torch_vacc.profiler.ProfilerActivity.CPU],
]
torch_vacc._apply_patches(patchs)
'''
def warning_fn(msg, rank0=True):
is_distributed = (
torch.distributed.is_available()
and torch.distributed.is_initialized()
and torch.distributed.get_world_size() > 1
)
env_rank = os.getenv("RANK", None)
if rank0 and is_distributed:
if torch.distributed.get_rank() == 0:
warnings.warn(msg, ImportWarning)
elif rank0 and env_rank:
if env_rank == "0":
warnings.warn(msg, ImportWarning)
else:
warnings.warn(msg, ImportWarning)
def init():
warning_fn(
"""
*************************************************************************************************************
torch.Tensor.cuda and torch.nn.Module.cuda are now replaced with torch.Tensor.vacc and torch.nn.Module.vacc.
torch.cuda.DoubleTensor is replaced with torch.vacc.FloatTensor because the double type is not supported.
The backend in torch.distributed.init_process_group is now set to vccl.
torch.cuda.* and torch.cuda.amp.* are replaced with torch.vacc.* and torch.vacc.amp.*.
The device parameters have been replaced with vacc in the functions below:
{}
If you notice that a function you use is not included in the above list, feel free to contact the torch-vacc development team.
*************************************************************************************************************
""".format(
", ".join(
["torch." + i for i in torch_fn_white_list]
+ ["torch.Tensor." + i for i in torch_tensor_fn_white_list]
+ ["torch.nn.Module." + i for i in torch_module_fn_white_list]
)
)
)
# torch.cuda.*
patch_cuda()
device_wrapper(torch.cuda, torch_cuda_fn_white_list)
# torch.profiler.*
# TODO(qingsong): profiler not implemented yet
# patch_profiler()
# device_wrapper(torch.profiler, torch_profiler_fn_white_list)
# torch.*
device_wrapper(torch, torch_fn_white_list)
# torch.Tensor.*
device_wrapper(torch.Tensor, torch_tensor_fn_white_list)
torch.Tensor.cuda = torch.Tensor.vacc
torch.Tensor.is_cuda = torch.Tensor.is_vacc
for dtype_tensor in [
"ByteTensor",
"CharTensor",
"DoubleTensor",
"FloatTensor",
"IntTensor",
"LongTensor",
"ShortTensor",
"HalfTensor",
"BoolTensor",
]:
setattr(
torch.cuda,
dtype_tensor,
getattr(torch.vacc, dtype_tensor),
)
# TODO(qingsong): do we need this? should we add LongTensor=IntTensor?
torch.cuda.DoubleTensor = torch.vacc.FloatTensor
# torch.nn.Module.*
device_wrapper(torch.nn.Module, torch_module_fn_white_list)
torch.nn.Module.cuda = torch.nn.Module.vacc
# torch.distributed.init_process_group
torch.distributed.init_process_group = wrapper_vccl(
torch.distributed.init_process_group
)
torch.distributed.is_nccl_available = torch.distributed.is_vccl_available
# torch.nn.parallel.DistributedDataParallel
device_wrapper(
torch.nn.parallel.DistributedDataParallel, torch_distributed_fn_white_list
)
# torch.utils.data.DataLoader
torch.utils.data.DataLoader.__init__ = wrapper_data_loader(
torch.utils.data.DataLoader.__init__
)
torch.jit.script = jit_script
torch._utils._get_available_device_type = wrapper_get_available_device_type(
torch._utils._get_available_device_type
)
'''
if IS_TORCHAIR_INSTALLED:
torch.compile = wrapper_compile(torch.compile)
'''
init()
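
A sketch of the effect of the patch module above, assuming it has been imported (its import path is suppressed in this view) and a VACC device is present: once init() has run, code written against CUDA is transparently redirected to the VACC backend.

import torch
import torch_vacc
# Assumes the patch module above has already been imported, so init() has run.

t = torch.ones(2, 3, device="cuda")   # wrapper_cuda rewrites "cuda" to "vacc"
m = torch.nn.Linear(3, 3).cuda()      # nn.Module.cuda is now an alias of nn.Module.vacc
print(t.device, next(m.parameters()).device)   # both report a vacc device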

@@ -0,0 +1,42 @@
import torch
import torch_vacc
from torch_vacc._vacc_libs import _torch_vacc
class FusedRMSNormFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6):
output, rsigma, var = torch.ops.vacc.rms_norm_forward(input, weight, eps)
ctx.save_for_backward(input, weight, rsigma, var)
ctx.eps = eps
return output
@staticmethod
def backward(ctx, grad_output: torch.Tensor):
input, weight, rsigma, var = ctx.saved_tensors
grad_input, grad_weight = _torch_vacc.rms_norm_backward(
grad_output, input, weight, rsigma, var, ctx.eps
)
return grad_input, grad_weight, None
def rms_norm(input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6):
return FusedRMSNormFunction.apply(input, weight, eps)
class FusedRMSNorm(torch.nn.Module):
def __init__(self, hidden_size, eps: float = 1e-6):
super(FusedRMSNorm, self).__init__()
self.eps = eps
self.weight = torch.nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
output = FusedRMSNormFunction.apply(hidden_states, self.weight, self.eps)
output = output.to(dtype)
return output
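
A short usage sketch for the fused RMSNorm wrappers above, assuming a VACC device and that FusedRMSNorm is imported from wherever this module lives in the package (its path is suppressed in this view).

import torch
import torch_vacc
# from torch_vacc... import FusedRMSNorm   # actual module path is not shown in this view

hidden = torch.randn(2, 16, 64, device="vacc", dtype=torch.float16, requires_grad=True)
norm = FusedRMSNorm(hidden_size=64).to("vacc")
out = norm(hidden)       # upcasts to fp32, runs the fused forward, casts back to fp16
out.sum().backward()     # gradients come from _torch_vacc.rms_norm_backward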

@@ -0,0 +1,32 @@
import torch
import torch_vacc
from torch_vacc._vacc_libs import _torch_vacc
class FusedRopeEmbFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, q: torch.Tensor, k: torch.Tensor, offset: int):
qemb, kemb = _torch_vacc.rope_forward(q, k, offset)
ctx.offset = offset
return qemb, kemb
@staticmethod
def backward(ctx, q_out_grad: torch.Tensor, k_out_grad: torch.Tensor):
grad_input, grad_rope = _torch_vacc.rope_backward(
q_out_grad, k_out_grad, ctx.offset
)
return grad_input, grad_rope, None
def rope_emb(q: torch.Tensor, k: torch.Tensor, offset: int):
# return FusedRopeEmbFunction.apply(q, k, offset)
return torch_vacc.vacc.custom_ops.RotaryPosEmbedding(q=q, k=k, offset=offset)
class RopeEmb(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, q: torch.Tensor, k: torch.Tensor, offset: int):
return rope_emb(q, k, offset)
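
A usage sketch for the rotary position embedding wrapper above, assuming a VACC device; the exact q/k layout is defined by the underlying kernel, so the shapes below are illustrative only.

import torch
import torch_vacc
# from torch_vacc... import rope_emb   # actual module path is not shown in this view

# Illustrative (batch, seq_len, num_heads, head_dim) layout; check the kernel's contract.
q = torch.randn(1, 128, 8, 64, device="vacc", dtype=torch.float16)
k = torch.randn(1, 128, 8, 64, device="vacc", dtype=torch.float16)
q_rot, k_rot = rope_emb(q, k, offset=0)   # dispatches to custom_ops.RotaryPosEmbedding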

@@ -0,0 +1,62 @@
from contextlib import contextmanager
import os
import sys
import torch
from torch.testing import make_tensor
from functools import partial, wraps
import torch.testing._internal.common_device_type as cdt
from torch.testing._internal.common_device_type import (
DeviceTypeTestBase,
dtypes,
instantiate_device_type_tests,
onlyOn,
onlyPRIVATEUSE1,
ops,
)
if sys.version_info > (3, 8):
from torch.testing._internal.common_distributed import (
MultiProcessTestCase,
init_multigpu_helper,
skip_if_lt_x_gpu,
get_timeout,
#skip_if_rocm,
with_dist_debug_levels,
)
else:
from torch.testing._internal.common_distributed import (
MultiProcessTestCase,
init_multigpu_helper,
skip_if_lt_x_gpu,
get_timeout,
skip_if_rocm,
with_dist_debug_levels,
)
from torch.testing._internal.common_utils import (
TestCase,
load_tests,
parametrize,
run_tests,
subtest,
retry_on_connect_failures,
instantiate_parametrized_tests,
)
onlyVacc = onlyPRIVATEUSE1
class VaccTestBase(DeviceTypeTestBase):
device_type = "vacc"
if VaccTestBase not in cdt.device_type_test_bases:
cdt.device_type_test_bases.append(VaccTestBase)
@contextmanager
def freeze_rng_state():
rng_state = torch.get_rng_state()
yield
torch.set_rng_state(rng_state)
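
A sketch of how these helpers are typically used to generate device-parametrized tests for the vacc backend; the test itself is illustrative.

import torch
from torch.testing._internal.common_device_type import dtypes, instantiate_device_type_tests
from torch.testing._internal.common_utils import TestCase, run_tests
# Assumes the testing helper module above has been imported so VaccTestBase is registered.

class TestAddExample(TestCase):
    @dtypes(torch.float32)
    def test_add(self, device, dtype):
        a = torch.ones(4, device=device, dtype=dtype)
        self.assertEqual((a + a).sum().item(), 8.0)

# Instantiates the test against every registered base whose device_type matches,
# including VaccTestBase above.
instantiate_device_type_tests(TestAddExample, globals(), only_for="vacc")

if __name__ == "__main__":
    run_tests()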

@@ -0,0 +1,103 @@
"""
Tool to summarize unit test XML reports. It summarizes:
* the number of tests, and of failures/errors/skipped tests
* the 10 slowest tests
Usage:
python -m torch_vacc.testing.summarize_report --report report.xml
"""
import argparse
from dataclasses import dataclass
from xml.etree import ElementTree as ET
import sys
from torch_vacc import set_global_log_level
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--report", type=str)
return parser.parse_args()
def summarize_testsuites(suites):
summary = {
"errors": int,
"failures": int,
"skipped": int,
"skips": int,
"tests": int,
"time": float,
}
attribs = [s.attrib for s in suites]
for key in summary:
summary[key] = sum(summary[key](a[key]) for a in attribs if key in a)
assert not (summary["skipped"] and summary["skips"])
if summary["skips"]:
summary["skipped"] = summary["skips"]
return summary
def format_summary(summary):
template = "Ran {tests} tests in {time:.3f}s (errors={errors}, failures={failures}, skipped={skipped})"
msg = template.format(**summary)
if summary["errors"] > 0 or summary["failures"] > 0:
msg = "FAILED. " + msg
return msg
@dataclass
class TestCaseInfo:
test_class_name: str
test_name: str
time: float
timestamp: str
success: bool
def __lt__(self, other):
return self.time < other.time
def sort_cases_by_time(suites):
test_cases = [
TestCaseInfo(
s.attrib["classname"],
s.attrib["name"],
s.attrib["time"],
s.attrib["timestamp"],
s.attrib.get("failure") is None,
)
for s in suites
]
test_cases.sort(reverse=True)
return test_cases
def read_report(fpath):
with open(fpath) as report:
try:
report = ET.parse(report)
except ET.ParseError:
print(f"{sys.argv[0]}: Cannot parse file {fpath}", file=sys.stderr)
return
root = report.getroot()
suites = [root] if root.tag == "testsuite" else root.findall("testsuite")
summary = summarize_testsuites(suites)
summary_msg = format_summary(summary)
print(summary_msg)
for suite in suites:
cases = sort_cases_by_time(suite.findall("testcase"))
for case in cases[:10]:
print(case)
def main():
set_global_log_level("ERROR")
args = parse_args()
read_report(args.report)
if __name__ == "__main__":
main()
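
For reference, a tiny self-contained run of the summarizer above on an illustrative report (the XML content and numbers below are made up).

# Assumes read_report is imported from the module above.
report_xml = """<testsuite tests="2" failures="1" errors="0" skipped="0" time="1.500">
  <testcase classname="TestAdd" name="test_ok" time="0.500" timestamp="2026-04-02T00:00:00"/>
  <testcase classname="TestAdd" name="test_bad" time="1.000" timestamp="2026-04-02T00:00:01">
    <failure message="boom"/>
  </testcase>
</testsuite>"""
with open("report.xml", "w") as f:
    f.write(report_xml)
read_report("report.xml")
# -> FAILED. Ran 2 tests in 1.500s (errors=0, failures=1, skipped=0), then the slowest cases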

torch_vacc/vacc/__init__.py Normal file

@@ -0,0 +1,184 @@
from __future__ import annotations
from typing import Tuple
import torch
from ._device import (
current_device,
device,
device_count,
get_device_capability,
get_device_name,
get_device_properties,
is_available,
is_bf16_supported,
set_device,
synchronize,
)
from .amp import (
get_amp_supported_dtype,
get_autocast_dtype,
is_autocast_enabled,
set_autocast_dtype,
set_autocast_enabled,
)
from .lazy_initialize import _is_in_bad_fork, _lazy_call, _lazy_init
from .memory import ( # caching_allocator_alloc,; caching_allocator_delete,
empty_cache,
get_allocator_backend,
max_memory_allocated,
max_memory_cached,
max_memory_reserved,
mem_get_info,
memory_allocated,
memory_cached,
memory_reserved,
memory_snapshot,
memory_stats,
memory_stats_as_nested_dict,
memory_summary,
reset_accumulated_memory_stats,
reset_max_memory_allocated,
reset_max_memory_cached,
reset_peak_memory_stats,
set_per_process_memory_fraction,
)
from .streams import Event, Stream, current_stream, default_stream, set_stream, stream
def init():
r"""Initialize PyTorch's VACC state. You may need to call
this explicitly if you are interacting with PyTorch via
its C API, as Python bindings for VACC functionality will not
be available until this initialization takes place. Ordinary users
should not need this, as all of PyTorch's VACC methods
automatically initialize VACC state on-demand.
Does nothing if the VACC state is already initialized.
"""
_lazy_init()
# default_generators is empty until _lazy_init() is called
default_generators: Tuple[torch._C.Generator] = () # type: ignore[assignment]
from .custom_ops import *
from .custom_qwen3_ops import *
from .random import * # noqa: F403
__all__ = [
"device",
"is_available",
"is_bf16_supported",
"current_device",
"set_device",
"device_count",
"get_device_properties",
"get_device_name",
"get_device_capability",
"synchronize",
"amp",
"get_amp_supported_dtype",
"is_autocast_enabled",
"set_autocast_enabled",
"get_autocast_dtype",
"set_autocast_dtype",
"_is_in_bad_fork",
"_lazy_call",
"get_rng_state",
"get_rng_state_all",
"set_rng_state",
"set_rng_state_all",
"manual_seed",
"manual_seed_all",
"seed",
"seed_all",
"initial_seed",
"set_stream",
"current_stream",
"default_generators",
"default_stream",
"stream",
"Stream",
"Event",
"mem_get_info",
"set_per_process_memory_fraction",
"empty_cache",
"memory_stats",
"memory_stats_as_nested_dict",
"reset_accumulated_memory_stats",
"reset_peak_memory_stats",
"reset_max_memory_allocated",
"reset_max_memory_cached",
"memory_allocated",
"max_memory_allocated",
"memory_reserved",
"max_memory_reserved",
"memory_cached",
"max_memory_cached",
"memory_snapshot",
"memory_summary",
"get_allocator_backend",
"rms_norm",
"RotaryPosEmbedding",
"scaled_dot_product_attention",
"scaled_dot_product_attention_cp_forward",
"scaled_dot_product_attention_cp_backward",
"swiglu",
"paged_attention",
"reshape_and_cache_attention",
"concat_and_cache_attention",
"w8a8_block_fp8_matmul",
"moe_expert_token_group_reassign",
"fused_mlp_mm_fp8",
"fused_mlp_fp8",
"fused_moe_preprocess",
"fused_residual_rmsnorm",
"parallel_embedding",
"all_reduce",
"all_gather",
"broadcast",
"fused_mlp_moe_with_rmsnorm",
"fuse_moe_decode_v2_allreduce",
"topk_topp",
"fused_mla",
"fused_mla_allreduce",
"fused_mlp_with_rmsnorm",
"fused_mlp_allreduce",
"ds3_sampler",
"sampler_v1",
"rejection_sampler",
"rejection_sampler_update_hidden_states",
"rejection_sampler_v1",
"fused_matmul_allgather",
"fused_mla_v2",
"fused_mla_allreduce_v2",
"mla_matmul_scale",
"mla_matmul",
"fused_mla_prefill_stage0",
"fused_mla_prefill_stage1",
"fused_mla_prefill_stage0_allreduce",
"fuse_moe_prefill_stage0",
"fuse_mla_mlp_v2_allreduce_decode",
"fuse_mla_moe_v2_allreduce_decode",
"fuse_mla_mlp_v2_allreduce_decode_layers",
"fuse_mla_moe_v2_allreduce_decode_layers",
"fuse_mla_mlp_v2_allreduce_decode_layers_v2",
"fuse_mla_moe_v2_allreduce_decode_layers_v2",
"fuse_mlp_qwen_int4",
"fuse_mlp_qwen_int4_reduce",
"w4a8_block_int4_matmul",
"fuse_atten_qwen3",
"fuse_atten_qwen2",
"qwen3_fuse_attention_moe_decode",
"fuse_mtp_stage0",
"fuse_mtp_allreduce",
"roll_out",
"fused_experts_int4_prefill",
"fuse_bge_embedding_stage1",
"l2_norm",
"fuse_mlp_vision",
"patch_merger_vision",
"fuse_atten_vit",
"apply_penalties",
]

torch_vacc/vacc/_device.py Normal file

@@ -0,0 +1,106 @@
# Device information
# replacing `torch.cuda.func` with `torch_vacc.vacc.func`.
# see https://pytorch.org/docs/stable/cuda.html
from typing import Any
import warnings
import torch
import torch_vacc
from torch._utils import _get_device_index
from torch_vacc._vacc_libs import _torch_vacc
from .lazy_initialize import _lazy_init
if hasattr(_torch_vacc, "_exchange_device"):
_exchange_device = _torch_vacc._exchange_device
else:
def _exchange_device(device: int) -> int:
# Emulate the missing binding using current_device()/set_device().
if device < 0:
return -1
prev_device = current_device()
if device != prev_device:
set_device(device)
return prev_device
class device(object):
"""Context-manager that changes the selected device.
Args:
device (torch.device or int): device index to select. It's a no-op if
this argument is a negative integer or ``None``.
"""
def __init__(self, device: Any):
self.idx = _get_device_index(device, optional=True)
self.prev_idx = -1
def __enter__(self):
self.prev_idx = _exchange_device(self.idx)
def __exit__(self, *args):
_exchange_device(self.prev_idx)
return False
def is_available() -> bool:
r"""Returns whether vacc is available."""
return device_count() > 0
def is_bf16_supported() -> bool:
r"""Returns a bool indicating if the current vacc device supports dtype bfloat16"""
return True
def current_device() -> int:
r"""Returns the index of a currently selected vacc device."""
_lazy_init()
return _torch_vacc._current_device()
def set_device(device: torch.device):
device_index = _get_device_index(device, optional=True)
if device_index >= 0:
_torch_vacc._set_device(device_index)
def get_device_capability(device=None):
r"""Query the minor and major data of device. Cann does not
have a corresponding concept and is not supported. By default, it returns None
"""
_infos = "torch.vacc.get_device_capability isn't implemented! Please do the version check in other ways, Unlike CUDA major,min"
raise AssertionError(_infos)
def get_device_name(device_name=None):
device_id = _get_device_index(device_name, optional=True)
if device_id < 0 or device_id >= device_count():
raise AssertionError("Invalid device id")
_lazy_init()
device_prop = _torch_vacc._vacc_getDeviceProperties(device_id)
return device_prop.name
def get_device_properties(device_name=None):
device_id = _get_device_index(device_name, optional=True)
if device_id < 0 or device_id >= device_count():
raise AssertionError("Invalid device id")
_lazy_init()
return _torch_vacc._vacc_getDeviceProperties(device_id)
def device_count():
r"""Returns the number of available vacc devices"""
return _torch_vacc._device_count()
def synchronize(device=None) -> None:
"""Waits for all operations in all streams on a VACC device to complete."""
_lazy_init()
with torch_vacc.vacc.device(device):
return _torch_vacc._device_synchronize()
# Memory management (https://pytorch.org/docs/stable/cuda.html#memory-management)
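
A short sketch exercising the device helpers above; the context-manager part assumes at least two VACC devices.

import torch
import torch_vacc

if torch.vacc.is_available():
    print(torch.vacc.device_count(), torch.vacc.current_device())
    print(torch.vacc.get_device_name(0))
    with torch_vacc.vacc.device(1):      # temporarily switch to device 1
        assert torch.vacc.current_device() == 1
    torch.vacc.synchronize()             # wait for outstanding work on the current device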

@@ -0,0 +1,26 @@
from typing import List
import torch
from torch_vacc._vacc_libs import _torch_vacc
from .grad_scaler import OptState, GradScaler
from .autocast_mode import autocast, custom_fwd, custom_bwd
def get_amp_supported_dtype() -> List[torch.dtype]:
return [torch.float16, torch.bfloat16]
def is_autocast_enabled() -> bool:
return _torch_vacc.is_autocast_enabled()
def set_autocast_enabled(enable: bool):
_torch_vacc.set_autocast_enabled(enable)
def get_autocast_dtype() -> torch.dtype:
return _torch_vacc.get_autocast_dtype()
def set_autocast_dtype(dtype: torch.dtype):
return _torch_vacc.set_autocast_dtype(dtype)
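
A short sketch of the AMP entry points above, assuming a VACC device.

import torch
import torch_vacc
from torch_vacc.vacc import amp

print(amp.get_amp_supported_dtype())      # [torch.float16, torch.bfloat16]
with amp.autocast(dtype=torch.bfloat16):
    a = torch.randn(8, 8, device="vacc")
    b = torch.randn(8, 8, device="vacc")
    c = a @ b                              # runs under the vacc autocast policy
print(c.dtype)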

@@ -0,0 +1,144 @@
import collections
import functools
from typing import Any
import torch
try:
import numpy as np
HAS_NUMPY = True
except ModuleNotFoundError:
np = None # type: ignore[assignment]
__all__ = ["autocast", "custom_fwd", "custom_bwd"]
class autocast(torch.amp.autocast_mode.autocast):
r"""See :class:`torch.autocast`.
``torch.vacc.amp.autocast(args...)`` is equivalent to ``torch.autocast("vacc", args...)``
"""
def __init__(
self,
enabled: bool = True,
dtype: torch.dtype = torch.float16,
cache_enabled: bool = True,
):
if torch._jit_internal.is_scripting():
self._enabled = enabled
self.device = "vacc"
self.fast_dtype = dtype
return
super().__init__(
"vacc", enabled=enabled, dtype=dtype, cache_enabled=cache_enabled
)
def __enter__(self):
if torch._jit_internal.is_scripting():
return self
return super().__enter__()
# TODO: discuss a unified TorchScript-friendly API for autocast
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any): # type: ignore[override]
if torch._jit_internal.is_scripting():
return
return super().__exit__(exc_type, exc_val, exc_tb)
def __call__(self, func):
if torch._jit_internal.is_scripting():
return func
return super().__call__(func)
# Casts Tensors and containers of Tensors. Special-cases passthroughs for strings and np.ndarrays, which
# may be falsely detected as "Iterables."
def _cast(value, dtype):
if isinstance(value, torch.Tensor):
is_eligible = (
value.is_floating_point()
and value.is_vacc
and (value.dtype is not torch.float64)
)
return value.to(dtype) if is_eligible else value
elif isinstance(value, (str, bytes)):
return value
elif HAS_NUMPY and isinstance(value, np.ndarray):
return value
elif isinstance(value, collections.abc.Mapping):
return {_cast(k, dtype): _cast(v, dtype) for k, v in value.items()}
elif isinstance(value, collections.abc.Iterable):
iterable = (_cast(v, dtype) for v in value)
if isinstance(value, (list, tuple)):
return type(value)(iterable)
else:
return iterable
else:
return value
# custom_fwd is a decorator that may or may not be used with arguments, following
# https://github.com/dabeaz/python-cookbook/tree/master/src/9/defining_a_decorator_that_takes_an_optional_argument.
# this works:
# @custom_fwd
# def forward(...):
# this also works:
# @custom_fwd(cast_inputs=torch.float)
# def forward(...):
def custom_fwd(fwd=None, *, cast_inputs=None):
"""
Create a helper decorator for ``forward`` methods of custom autograd functions.
Autograd functions are subclasses of :class:`torch.autograd.Function`.
See the :ref:`example page<amp-custom-examples>` for more detail.
Args:
cast_inputs (:class:`torch.dtype` or None, optional, default=None): If not ``None``,
when ``forward`` runs in an autocast-enabled region, casts incoming
floating-point VACC Tensors to the target dtype (non-floating-point Tensors are not affected),
then executes ``forward`` with autocast disabled.
If ``None``, ``forward``'s internal ops execute with the current autocast state.
.. note::
If the decorated ``forward`` is called outside an autocast-enabled region,
:func:`custom_fwd<custom_fwd>` is a no-op and ``cast_inputs`` has no effect.
"""
if fwd is None:
return functools.partial(custom_fwd, cast_inputs=cast_inputs)
@functools.wraps(fwd)
def decorate_fwd(*args, **kwargs):
args[0]._dtype = torch.get_autocast_gpu_dtype()
if cast_inputs is None:
args[0]._fwd_used_autocast = torch.is_autocast_enabled()
return fwd(*args, **kwargs)
else:
autocast_context = torch.is_autocast_enabled()
args[0]._fwd_used_autocast = False
if autocast_context:
with autocast(enabled=False):
return fwd(*_cast(args, cast_inputs), **_cast(kwargs, cast_inputs))
else:
return fwd(*args, **kwargs)
return decorate_fwd
# Autograd ensures incoming gradients are the same type as forward outputs. Allowing a separate
# cast_inputs argument on custom_bwd is unnecessary and could cause errors if it doesn't match
# cast_inputs supplied to custom_fwd.
def custom_bwd(bwd):
"""Create a helper decorator for backward methods of custom autograd functions.
Autograd functions are subclasses of :class:`torch.autograd.Function`.
Ensures that ``backward`` executes with the same autocast state as ``forward``.
See the :ref:`example page<amp-custom-examples>` for more detail.
"""
@functools.wraps(bwd)
def decorate_bwd(*args, **kwargs):
with autocast(enabled=args[0]._fwd_used_autocast, dtype=args[0]._dtype):
return bwd(*args, **kwargs)
return decorate_bwd
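
A minimal sketch of a custom autograd function using the decorators above, in the spirit of the upstream AMP examples the docstrings reference; it assumes a VACC device.

import torch
import torch_vacc
from torch_vacc.vacc.amp import autocast, custom_bwd, custom_fwd

class ScaledLinear(torch.autograd.Function):
    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)   # run the body in fp32 even inside autocast
    def forward(ctx, x, w):
        ctx.save_for_backward(x, w)
        return x @ w.t()

    @staticmethod
    @custom_bwd                              # backward reuses forward's autocast state
    def backward(ctx, grad_out):
        x, w = ctx.saved_tensors
        return grad_out @ w, grad_out.t() @ x

x = torch.randn(8, 16, device="vacc", requires_grad=True)
w = torch.randn(4, 16, device="vacc", requires_grad=True)
with autocast(dtype=torch.float16):
    out = ScaledLinear.apply(x, w)
out.sum().backward()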

@@ -0,0 +1,7 @@
import torch
__all__ = ["amp_definitely_not_available"]
def amp_definitely_not_available():
return not torch.vacc.is_available()

@@ -0,0 +1,667 @@
import inspect
import warnings
from collections import abc, defaultdict
from enum import Enum
from typing import Any, cast, Dict, List, Optional, Tuple
import torch
from .common import amp_definitely_not_available
__all__ = ["OptState", "GradScaler"]
class _MultiDeviceReplicator:
"""
Lazily serves copies of a tensor to requested devices. Copies are cached per-device.
"""
def __init__(self, master_tensor: torch.Tensor) -> None:
assert (
master_tensor.is_cuda
or master_tensor.device.type == "xla"
or master_tensor.device.type == "vacc"
)
self.master = master_tensor
self._per_device_tensors: Dict[torch.device, torch.Tensor] = {}
def get(self, device) -> torch.Tensor:
retval = self._per_device_tensors.get(device, None)
if retval is None:
retval = self.master.to(device=device, non_blocking=True, copy=True)
self._per_device_tensors[device] = retval
return retval
# Defines default_factory for GradScaler's _per_optimizer_states defaultdict,
# as well as associated "enum" values. Prefers defining these at top level because
# - Lambdas can't be pickled, so we don't want to supply a lambda as the factory.
# - Defining READY, UNSCALED, STEPPED and _refresh_per_optimizer_state within GradScaler
# causes a circular reference, which we'd rather avoid.
class OptState(Enum):
READY = 0
UNSCALED = 1
STEPPED = 2
def _refresh_per_optimizer_state():
return {"stage": OptState.READY, "found_inf_per_device": {}}
class GradScaler:
_scale: Optional[torch.Tensor]
_growth_tracker: Optional[torch.Tensor]
_per_optimizer_states: Dict[int, Dict[str, Any]]
"""
An instance ``scaler`` of :class:`GradScaler` helps perform the steps of gradient scaling
conveniently.
* ``scaler.scale(loss)`` multiplies a given loss by ``scaler``'s current scale factor.
* ``scaler.step(optimizer)`` safely unscales gradients and calls ``optimizer.step()``.
* ``scaler.update()`` updates ``scaler``'s scale factor.
Example::
# Creates a GradScaler once at the beginning of training.
scaler = GradScaler()
for epoch in epochs:
for input, target in data:
optimizer.zero_grad()
output = model(input)
loss = loss_fn(output, target)
# Scales loss. Calls backward() on scaled loss to create scaled gradients.
scaler.scale(loss).backward()
# scaler.step() first unscales gradients of the optimizer's params.
# If gradients don't contain infs/NaNs, optimizer.step() is then called,
# otherwise, optimizer.step() is skipped.
scaler.step(optimizer)
# Updates the scale for next iteration.
scaler.update()
See the :ref:`Automatic Mixed Precision examples<amp-examples>` for usage
(along with autocasting) in more complex cases like gradient clipping, gradient accumulation, gradient penalty,
and multiple losses/optimizers.
``scaler`` dynamically estimates the scale factor each iteration. To minimize gradient underflow,
a large scale factor should be used. However, ``float16`` values can "overflow" (become inf or NaN) if
the scale factor is too large. Therefore, the optimal scale factor is the largest factor that can be used
without incurring inf or NaN gradient values.
``scaler`` approximates the optimal scale factor over time by checking the gradients for infs and NaNs during every
``scaler.step(optimizer)`` (or optional separate ``scaler.unscale_(optimizer)``, see :meth:`unscale_`).
* If infs/NaNs are found, ``scaler.step(optimizer)`` skips the underlying ``optimizer.step()`` (so the params
themselves remain uncorrupted) and ``update()`` multiplies the scale by ``backoff_factor``.
* If no infs/NaNs are found, ``scaler.step(optimizer)`` runs the underlying ``optimizer.step()`` as usual.
If ``growth_interval`` unskipped iterations occur consecutively, ``update()`` multiplies the scale by
``growth_factor``.
The scale factor often causes infs/NaNs to appear in gradients for the first few iterations as its
value calibrates. ``scaler.step`` will skip the underlying ``optimizer.step()`` for these
iterations. After that, step skipping should occur rarely (once every few hundred or thousand iterations).
Args:
init_scale (float, optional, default=2.**16): Initial scale factor.
growth_factor (float, optional, default=2.0): Factor by which the scale is multiplied during
:meth:`update` if no inf/NaN gradients occur for ``growth_interval`` consecutive iterations.
backoff_factor (float, optional, default=0.5): Factor by which the scale is multiplied during
:meth:`update` if inf/NaN gradients occur in an iteration.
growth_interval (int, optional, default=2000): Number of consecutive iterations without inf/NaN gradients
that must occur for the scale to be multiplied by ``growth_factor``.
enabled (bool, optional): If ``False``, disables gradient scaling. :meth:`step` simply
invokes the underlying ``optimizer.step()``, and other methods become no-ops.
Default: ``True``
"""
def __init__(
self,
init_scale=2.0**16,
growth_factor=2.0,
backoff_factor=0.5,
growth_interval=2000,
enabled=True,
):
if enabled and amp_definitely_not_available():
warnings.warn(
"torch.vacc.amp.GradScaler is enabled, but VACC device is not available. Disabling."
)
self._enabled = False
else:
self._enabled = enabled
if self._enabled:
assert growth_factor > 1.0, "The growth factor must be > 1.0."
assert backoff_factor < 1.0, "The backoff factor must be < 1.0."
self._init_scale = init_scale
# self._scale will be lazily initialized during the first call to scale()
self._scale = None
self._growth_factor = growth_factor
self._backoff_factor = backoff_factor
self._growth_interval = growth_interval
self._init_growth_tracker = 0
# self._growth_tracker will be lazily initialized during the first call to scale()
self._growth_tracker = None
self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
def _check_scale_growth_tracker(
self, funcname
) -> Tuple[torch.Tensor, torch.Tensor]:
fix = "This may indicate your script did not use scaler.scale(loss or outputs) earlier in the iteration."
assert self._scale is not None, (
f"Attempted {funcname} but _scale is None. " + fix
)
assert self._growth_tracker is not None, (
f"Attempted {funcname} but _growth_tracker is None. " + fix
)
return (self._scale, self._growth_tracker)
def _lazy_init_scale_growth_tracker(self, dev):
assert self._growth_tracker is None, "_growth_tracker initialized before _scale"
self._scale = torch.full((), self._init_scale, dtype=torch.float32, device=dev)
self._growth_tracker = torch.full(
(), self._init_growth_tracker, dtype=torch.int32, device=dev
)
def scale(self, outputs):
"""
Multiplies ('scales') a tensor or list of tensors by the scale factor.
Returns scaled outputs. If this instance of :class:`GradScaler` is not enabled, outputs are returned
unmodified.
Args:
outputs (Tensor or iterable of Tensors): Outputs to scale.
"""
if not self._enabled:
return outputs
# Short-circuit for the common case.
if isinstance(outputs, torch.Tensor):
assert (
outputs.is_cuda
or outputs.device.type == "xla"
or outputs.device.type == "vacc"
)
if self._scale is None:
self._lazy_init_scale_growth_tracker(outputs.device)
assert self._scale is not None
return outputs * self._scale.to(device=outputs.device, non_blocking=True)
# Invoke the more complex machinery only if we're treating multiple outputs.
stash: List[
_MultiDeviceReplicator
] = [] # holds a reference that can be overwritten by apply_scale
def apply_scale(val):
if isinstance(val, torch.Tensor):
assert (
val.is_cuda or val.device.type == "xla" or val.device.type == "vacc"
)
if len(stash) == 0:
if self._scale is None:
self._lazy_init_scale_growth_tracker(val.device)
assert self._scale is not None
stash.append(_MultiDeviceReplicator(self._scale))
return val * stash[0].get(val.device)
elif isinstance(val, abc.Iterable):
iterable = map(apply_scale, val)
if isinstance(val, (list, tuple)):
return type(val)(iterable)
else:
return iterable
else:
raise ValueError("outputs must be a Tensor or an iterable of Tensors")
return apply_scale(outputs)
def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16):
per_device_inv_scale = _MultiDeviceReplicator(inv_scale)
per_device_found_inf = _MultiDeviceReplicator(found_inf)
# To set up _amp_foreach_non_finite_check_and_unscale_, split grads by device and dtype.
# There could be hundreds of grads, so we'd like to iterate through them just once.
# However, we don't know their devices or dtypes in advance.
# https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict
# Google says mypy struggles with defaultdicts type annotations.
per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list)) # type: ignore[var-annotated]
with torch.no_grad():
for group in optimizer.param_groups:
for param in group["params"]:
if param.grad is None:
continue
if (not allow_fp16) and param.grad.dtype == torch.float16:
raise ValueError("Attempting to unscale FP16 gradients.")
if param.grad.is_sparse:
# is_coalesced() == False means the sparse grad has values with duplicate indices.
# coalesce() deduplicates indices and adds all values that have the same index.
# For scaled fp16 values, there's a good chance coalescing will cause overflow,
# so we should check the coalesced _values().
if param.grad.dtype is torch.float16:
param.grad = param.grad.coalesce()
to_unscale = param.grad._values()
else:
to_unscale = param.grad
# TODO: is there a way to split by device and dtype without appending in the inner loop?
per_device_and_dtype_grads[to_unscale.device][
to_unscale.dtype
].append(to_unscale)
for device, per_dtype_grads in per_device_and_dtype_grads.items():
for grads in per_dtype_grads.values():
torch._amp_foreach_non_finite_check_and_unscale_(
grads,
per_device_found_inf.get(device),
per_device_inv_scale.get(device),
)
return per_device_found_inf._per_device_tensors
def unscale_(self, optimizer):
"""
Divides ("unscales") the optimizer's gradient tensors by the scale factor.
:meth:`unscale_` is optional, serving cases where you need to
:ref:`modify or inspect gradients<working-with-unscaled-gradients>`
between the backward pass(es) and :meth:`step`.
If :meth:`unscale_` is not called explicitly, gradients will be unscaled automatically during :meth:`step`.
Simple example, using :meth:`unscale_` to enable clipping of unscaled gradients::
...
scaler.scale(loss).backward()
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
scaler.step(optimizer)
scaler.update()
Args:
optimizer (torch.optim.Optimizer): Optimizer that owns the gradients to be unscaled.
.. note::
:meth:`unscale_` does not incur a CPU-GPU sync.
.. warning::
:meth:`unscale_` should only be called once per optimizer per :meth:`step` call,
and only after all gradients for that optimizer's assigned parameters have been accumulated.
Calling :meth:`unscale_` twice for a given optimizer between each :meth:`step` triggers a RuntimeError.
.. warning::
:meth:`unscale_` may unscale sparse gradients out of place, replacing the ``.grad`` attribute.
"""
if not self._enabled:
return
self._check_scale_growth_tracker("unscale_")
optimizer_state = self._per_optimizer_states[id(optimizer)]
if optimizer_state["stage"] is OptState.UNSCALED:
raise RuntimeError(
"unscale_() has already been called on this optimizer since the last update()."
)
elif optimizer_state["stage"] is OptState.STEPPED:
raise RuntimeError("unscale_() is being called after step().")
# FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64.
assert self._scale is not None
inv_scale = self._scale.double().reciprocal().float()
found_inf = torch.full((), 0.0, dtype=torch.float32, device=self._scale.device)
optimizer_state["found_inf_per_device"] = self._unscale_grads_(
optimizer, inv_scale, found_inf, False
)
optimizer_state["stage"] = OptState.UNSCALED
def _maybe_opt_step(self, optimizer, optimizer_state, *args, **kwargs):
retval = None
if not sum(v.item() for v in optimizer_state["found_inf_per_device"].values()):
retval = optimizer.step(*args, **kwargs)
return retval
def step(self, optimizer, *args, **kwargs):
"""
:meth:`step` carries out the following two operations:
1. Internally invokes ``unscale_(optimizer)`` (unless :meth:`unscale_` was explicitly called for ``optimizer``
earlier in the iteration). As part of the :meth:`unscale_`, gradients are checked for infs/NaNs.
2. If no inf/NaN gradients are found, invokes ``optimizer.step()`` using the unscaled
gradients. Otherwise, ``optimizer.step()`` is skipped to avoid corrupting the params.
``*args`` and ``**kwargs`` are forwarded to ``optimizer.step()``.
Returns the return value of ``optimizer.step(*args, **kwargs)``.
Args:
optimizer (torch.optim.Optimizer): Optimizer that applies the gradients.
args: Any arguments.
kwargs: Any keyword arguments.
.. warning::
Closure use is not currently supported.
"""
if not self._enabled:
return optimizer.step(*args, **kwargs)
if "closure" in kwargs:
raise RuntimeError(
"Closure use is not currently supported if GradScaler is enabled."
)
self._check_scale_growth_tracker("step")
optimizer_state = self._per_optimizer_states[id(optimizer)]
if optimizer_state["stage"] is OptState.STEPPED:
raise RuntimeError(
"step() has already been called since the last update()."
)
retval = None
if (
hasattr(optimizer, "_step_supports_amp_scaling")
and optimizer._step_supports_amp_scaling
):
# This optimizer has customized scale-handling logic, so we can call optimizer.step() directly.
# The contract with custom optimizers is that their step() should accept an additional,
# optional grad_scaler kwarg. We append self to the kwargs so the custom optimizer has full information:
# it can query its own state, invoke unscale_ on itself, etc
# The contract above is being deprecated to avoid introducing `grad_scaler: GradScaler` argument
# to `Optimizer.step`. The new behavior is going to add two Tensor attributes of `grad_scale`
# and `found_inf` to the passed optimizer so that the optimizer can utilize those
# to skip the parameter updates or unscale gradients before updating parameters in
# the fused kernel, e.g. `FusedAdamMathFunctor`.
# In this behavior, `GradScaler._check_inf_per_device` is called if `OptState.READY`,
# while the method is expected to be called by users side, i.e. their optimizers.
kwargs_ = kwargs
has_grad_scaler_kwarg = (
"grad_scaler" in inspect.signature(optimizer.step).parameters
)
if has_grad_scaler_kwarg:
warnings.warn(
"GradScaler is going to stop passing itself as a keyword argument to the passed "
"optimizer. In the near future GradScaler registers `grad_scale: Tensor` and "
"`found_inf: Tensor` to the passed optimizer and let the optimizer use them directly.",
FutureWarning,
)
kwargs_.update({"grad_scaler": self})
else:
if optimizer_state["stage"] is OptState.READY:
self._check_inf_per_device(optimizer)
scaler = self._get_scale_async()
found_inf = cast(
torch.Tensor,
sum(
[
t.to(scaler.device, non_blocking=True)
for t in optimizer_state["found_inf_per_device"].values()
]
),
)
optimizer.grad_scale = (
None if optimizer_state["stage"] == OptState.UNSCALED else scaler
)
optimizer.found_inf = found_inf
retval = optimizer.step(*args, **kwargs_)
optimizer_state["stage"] = OptState.STEPPED
if not has_grad_scaler_kwarg:
del optimizer.grad_scale
del optimizer.found_inf
return retval
if optimizer_state["stage"] is OptState.READY:
self.unscale_(optimizer)
assert (
len(optimizer_state["found_inf_per_device"]) > 0
), "No inf checks were recorded for this optimizer."
retval = self._maybe_opt_step(optimizer, optimizer_state, *args, **kwargs)
optimizer_state["stage"] = OptState.STEPPED
return retval
def update(self, new_scale=None):
"""
Updates the scale factor.
If any optimizer steps were skipped the scale is multiplied by ``backoff_factor``
to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively,
the scale is multiplied by ``growth_factor`` to increase it.
Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not
used directly, it's used to fill GradScaler's internal scale tensor. So if
``new_scale`` was a tensor, later in-place changes to that tensor will not further
affect the scale GradScaler uses internally.)
Args:
new_scale (float or :class:`torch.vacc.FloatTensor`, optional, default=None): New scale factor.
.. warning::
:meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has
been invoked for all optimizers used this iteration.
.. warning::
For performance reasons, we do not check the scale factor value to avoid synchronizations,
so the scale factor is not guaranteed to be above 1. If the scale falls below 1 and/or
you are seeing NaNs in your gradients or loss, something is likely wrong. For example,
bf16-pretrained models are often incompatible with AMP/fp16 due to differing dynamic ranges.
"""
if not self._enabled:
return
_scale, _growth_tracker = self._check_scale_growth_tracker("update")
if new_scale is not None:
# Accept a new user-defined scale.
if isinstance(new_scale, float):
self._scale.fill_(new_scale) # type: ignore[union-attr]
else:
reason = "new_scale should be a float or a 1-element torch.vacc.FloatTensor with requires_grad=False."
# assert isinstance(new_scale, torch.vacc.FloatTensor), reason # type: ignore[attr-defined]
assert (
isinstance(new_scale, torch.Tensor)
and new_scale.dtype == torch.float32
), reason
assert new_scale.numel() == 1, reason
assert new_scale.requires_grad is False, reason
self._scale.copy_(new_scale) # type: ignore[union-attr]
else:
# Consume shared inf/nan data collected from optimizers to update the scale.
# If all found_inf tensors are on the same device as self._scale, this operation is asynchronous.
found_infs = [
found_inf.to(device=_scale.device, non_blocking=True)
for state in self._per_optimizer_states.values()
for found_inf in state["found_inf_per_device"].values()
]
assert len(found_infs) > 0, "No inf checks were recorded prior to update."
found_inf_combined = found_infs[0]
if len(found_infs) > 1:
for i in range(1, len(found_infs)):
found_inf_combined += found_infs[i]
torch._amp_update_scale_(
_scale,
_growth_tracker,
found_inf_combined,
self._growth_factor,
self._backoff_factor,
self._growth_interval,
)
# To prepare for next iteration, clear the data collected from optimizers this iteration.
self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
def _get_scale_async(self):
return self._scale
def get_scale(self):
"""
Returns a Python float containing the current scale, or 1.0 if scaling is disabled.
.. warning::
:meth:`get_scale` incurs a CPU-GPU sync.
"""
if self._enabled:
return (
self._init_scale
if self._scale is None
else self._get_scale_async().item()
)
else:
return 1.0
def get_growth_factor(self):
r"""
Returns a Python float containing the scale growth factor.
"""
return self._growth_factor
def set_growth_factor(self, new_factor):
r"""
Args:
new_scale (float): Value to use as the new scale growth factor.
"""
self._growth_factor = new_factor
def get_backoff_factor(self):
r"""
Returns a Python float containing the scale backoff factor.
"""
return self._backoff_factor
def set_backoff_factor(self, new_factor):
r"""
Args:
new_scale (float): Value to use as the new scale backoff factor.
"""
self._backoff_factor = new_factor
def get_growth_interval(self):
r"""
Returns a Python int containing the growth interval.
"""
return self._growth_interval
def set_growth_interval(self, new_interval):
r"""
Args:
new_interval (int): Value to use as the new growth interval.
"""
self._growth_interval = new_interval
def _get_growth_tracker(self):
if self._enabled:
return (
self._init_growth_tracker
if self._growth_tracker is None
else self._growth_tracker.item()
)
else:
return 0
def is_enabled(self):
r"""
Returns a bool indicating whether this instance is enabled.
"""
return self._enabled
def state_dict(self):
r"""
Returns the state of the scaler as a :class:`dict`. It contains five entries:
* ``"scale"`` - a Python float containing the current scale
* ``"growth_factor"`` - a Python float containing the current growth factor
* ``"backoff_factor"`` - a Python float containing the current backoff factor
* ``"growth_interval"`` - a Python int containing the current growth interval
* ``"_growth_tracker"`` - a Python int containing the number of recent consecutive unskipped steps.
If this instance is not enabled, returns an empty dict.
.. note::
If you wish to checkpoint the scaler's state after a particular iteration, :meth:`state_dict`
should be called after :meth:`update`.
"""
return (
{
"scale": self.get_scale(),
"growth_factor": self._growth_factor,
"backoff_factor": self._backoff_factor,
"growth_interval": self._growth_interval,
"_growth_tracker": self._get_growth_tracker(),
}
if self._enabled
else {}
)
def load_state_dict(self, state_dict):
r"""
Loads the scaler state. If this instance is disabled, :meth:`load_state_dict` is a no-op.
Args:
state_dict(dict): scaler state. Should be an object returned from a call to :meth:`state_dict`.
"""
if not self._enabled:
return
if len(state_dict) == 0:
raise RuntimeError(
"The source state dict is empty, possibly because it was saved "
"from a disabled instance of GradScaler."
)
self._init_scale = state_dict["scale"]
if self._scale is not None:
self._scale.fill_(state_dict["scale"])
self._growth_factor = state_dict["growth_factor"]
self._backoff_factor = state_dict["backoff_factor"]
self._growth_interval = state_dict["growth_interval"]
self._init_growth_tracker = state_dict["_growth_tracker"]
if self._growth_tracker is not None:
self._growth_tracker.fill_(state_dict["_growth_tracker"])
def __getstate__(self):
state = self.__dict__.copy()
if self._enabled:
assert len(self._per_optimizer_states) == 0, (
"A GradScaler instance may only be pickled at the beginning "
"of an iteration, or at the end after scaler.update()."
)
# Pickling _scale and _growth_tracker Tensors directly triggers
# "warnings.warn("pickle support for Storage will be removed in 1.5..."
# so instead, we set the unpickled instance up to reinitialize them lazily.
state["_init_scale"] = self.get_scale()
state["_init_growth_tracker"] = self._get_growth_tracker()
state["_scale"] = None
state["_growth_tracker"] = None
return state
def __setstate__(self, state):
self.__dict__.update(state)
def _check_inf_per_device(self, optimizer):
_scale, _ = self._check_scale_growth_tracker("_check_inf_per_device")
dummy_inv_scale = torch.full((), 1.0, dtype=torch.float32, device=_scale.device)
found_inf = torch.full((), 0.0, dtype=torch.float32, device=_scale.device)
self._per_optimizer_states[id(optimizer)][
"found_inf_per_device"
] = self._unscale_grads_(optimizer, dummy_inv_scale, found_inf, True)
return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"]
def _found_inf_per_device(self, optimizer):
return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"]

File diff suppressed because it is too large.

@@ -0,0 +1,306 @@
from typing import Tuple, Union, Optional, List
import torch
from torch.nn import functional as F
def split_last_two_dims_into_blocks(x, h, w):
leading_dims = x.shape[:-2]
H, W = x.shape[-2:]
assert (
H % h == 0 and W % w == 0
), "The last two dimensions must be divisible by block size."
x_reshaped = x.view(-1, 1, H, W)
unfolded = F.unfold(x_reshaped, kernel_size=(h, w), stride=(h, w))
unfolded = unfolded.view(-1, 1, h, w, H // h, W // w)
unfolded = unfolded.permute(0, 1, 4, 5, 2, 3)
final_shape = leading_dims + (H // h, W // w, h, w)
result = unfolded.view(final_shape)
return result
def merge_blocks_to_original_layout(x, h, w):
leading_dims = x.shape[:-4]
H_div_h, W_div_w, h, w = x.shape[-4:]
H = H_div_h * h
W = W_div_w * w
x_reshaped = x.view(-1, 1, H_div_h, W_div_w, h, w)
x_reshaped = x_reshaped.permute(0, 1, 4, 5, 2, 3)
x_reshaped = x_reshaped.view(-1, h * w, H_div_h * W_div_w)
folded = F.fold(x_reshaped, output_size=(H, W), kernel_size=(h, w), stride=(h, w))
final_shape = leading_dims + (H, W)
result = folded.view(final_shape)
return result
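
A small CPU-only round trip for the two block-layout helpers above, checking that merging undoes splitting; the helpers are assumed to be imported from this module.

import torch
# Assumes split_last_two_dims_into_blocks / merge_blocks_to_original_layout from the module above.

x = torch.arange(2 * 8 * 6, dtype=torch.float32).reshape(2, 8, 6)
blocks = split_last_two_dims_into_blocks(x, h=4, w=3)
assert blocks.shape == (2, 2, 2, 4, 3)          # (..., H//h, W//w, h, w)
restored = merge_blocks_to_original_layout(blocks, h=4, w=3)
assert torch.equal(restored, x)                  # the two helpers are inverses
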
def w8a8_block_fp8_matmul(
input: torch.Tensor,
weight: torch.Tensor,
input_scale: Optional[torch.Tensor],
weight_scale: Optional[torch.Tensor],
block_size: List[int],
is_linear_weight: bool = False,
output_opt: Optional[torch.Tensor] = None,
**kwargs
):
b0, b1 = block_size
dim0, dim1 = weight.shape
dim0pad, dim1pad = 0, 0
if dim0 % b0 != 0:
dim0pad = b0 - dim0 % b0
if dim1 % b1 != 0:
dim1pad = b1 - dim1 % b1
dim0_origin, dim1_origin = dim0, dim1
dim0 += dim0pad
dim1 += dim1pad
bs0, bs1 = dim0 // b0, dim1 // b1
weight_dequant = torch.nn.functional.pad(weight, (0, dim1pad, 0, dim0pad), value=0)
weight_dequant = weight_dequant.cpu().view(bs0, b0, bs1, b1).permute(
0, 2, 1, 3
).reshape(bs0, bs1, -1).float().to(input.device) * weight_scale.unsqueeze(-1)
weight_dequant = (
weight_dequant.reshape(bs0, bs1, b0, b1)
.permute(0, 2, 1, 3)
.reshape(dim0, dim1)
.to(input.dtype)
)
weight_dequant = weight_dequant[:dim0_origin, :dim1_origin]
output = torch.matmul(
input, weight_dequant.T if is_linear_weight else weight_dequant
)
if output_opt is not None:
output = output_opt.copy_(output)
return output
def w8a8_block_fp8_linear(
input: torch.Tensor,
weight: torch.Tensor,
input_scale: Optional[torch.Tensor],
weight_scale: Optional[torch.Tensor],
block_size: List[int],
**kwargs
):
assert input_scale is None, "w8a8_block_fp8_matmul only supports quantized weights for now"
return w8a8_block_fp8_matmul(
input, weight, None, weight_scale, block_size, is_linear_weight=True
)
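# Illustrative CPU sketch (not part of the original file): calling the block-wise
# dequantizing linear with float32 stand-ins for the FP8 weight. weight is
# (out_features, in_features); weight_scale holds one scale per (b0, b1) block and
# dimensions that are not multiples of the block size are zero-padded internally.
def _example_w8a8_block_fp8_linear():
    torch.manual_seed(0)
    in_features, out_features, block = 8, 6, [4, 4]
    x = torch.randn(3, in_features)
    weight = torch.randn(out_features, in_features)   # stand-in for FP8 weight data
    weight_scale = torch.rand(2, 2)                   # ceil(6/4) x ceil(8/4) blocks
    y = w8a8_block_fp8_linear(x, weight, None, weight_scale, block)
    assert y.shape == (3, out_features)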
def fused_experts(
hidden_states: torch.Tensor,
w13_weight: torch.Tensor,
w2_weight: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
use_fp8_w8a8: bool = True,
w13_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a13_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
decode_with_batch: bool = False,
) -> torch.Tensor:
batch_seq_all, hidden_dims = hidden_states.shape
intermediate_size = w2_weight.shape[-1]
num_experts = w13_weight.shape[0]
w13_weight = w13_weight.contiguous()
w2_weight = w2_weight.contiguous()
w13_scale = w13_scale.contiguous()
w2_scale = w2_scale.contiguous()
final_hidden_states = torch.zeros_like(hidden_states)
import torch.nn.functional as F
w1_scale = w13_scale
w2_scale = w2_scale
_, bs0_w13, bs1_w13 = w1_scale.shape
_, bs0_w2, bs1_w2 = w2_scale.shape
sel_experts = topk_ids.shape[1]
if hidden_states.shape[0] == 1:
for id in range(sel_experts):
expert_idx = topk_ids[0][id]
expert_w1 = w13_weight[expert_idx].contiguous()
expert_w2 = w2_weight[expert_idx].contiguous()
ws1 = w1_scale[expert_idx].unsqueeze(2).contiguous()
ws2 = w2_scale[expert_idx].unsqueeze(2).contiguous()
dim0, dim1 = expert_w1.shape
b0, b1 = dim0 // bs0_w13, dim1 // bs1_w13
# assert (bs0, bs1, 1)==ws1.shape, f"bs0, bs1, 1 is {bs0},{bs1}, 1, <==> {ws1.shape}"
expert_w1 = (
expert_w1
.view(bs0_w13, b0, bs1_w13, b1)
.permute(0, 2, 1, 3)
.reshape(bs0_w13, bs1_w13, -1)
.float()
.to(hidden_states.device)
* ws1
)
expert_w1 = (
expert_w1.reshape(bs0_w13, bs1_w13, b0, b1)
.permute(0, 2, 1, 3)
.reshape(dim0, dim1)
.to(hidden_states.dtype)
)
dim0, dim1 = expert_w2.shape
b0, b1 = dim0 // bs0_w2, dim1 // bs1_w2
# assert (bs0, bs1, 1)==ws2.shape
expert_w2 = (
expert_w2
.view(bs0_w2, b0, bs1_w2, b1)
.permute(0, 2, 1, 3)
.reshape(bs0_w2, bs1_w2, -1)
.float()
.to(hidden_states.device)
* ws2
)
expert_w2 = (
expert_w2.reshape(bs0_w2, bs1_w2, b0, b1)
.permute(0, 2, 1, 3)
.reshape(dim0, dim1)
.to(hidden_states.dtype)
)
expert_weights = topk_weights[0][id].to(hidden_states.dtype)
x = hidden_states
x = F.linear(x, expert_w1)
gate = F.silu(x[:, :intermediate_size])
x = x[:, intermediate_size:] * gate
x = F.linear(x, expert_w2)
current_hidden_states = x * expert_weights
current_hidden_states = current_hidden_states.to(x.dtype)
final_hidden_states += current_hidden_states
else:
for expert_idx in range(num_experts):
# topk_ids [tokens, experts] => sample:[10, 8]
# expert_mask [tokens, experts] => sample:[10, 8]
expert_mask = topk_ids == expert_idx
idx = torch.where(expert_mask)[0]
if idx.numel() == 0:
continue
expert_w1 = w13_weight[expert_idx].contiguous()
expert_w2 = w2_weight[expert_idx].contiguous()
ws1 = w1_scale[expert_idx].unsqueeze(2).contiguous()
ws2 = w2_scale[expert_idx].unsqueeze(2).contiguous()
dim0, dim1 = expert_w1.shape
b0, b1 = dim0 // bs0_w13, dim1 // bs1_w13
# assert (bs0, bs1, 1)==ws1.shape, f"bs0, bs1, 1 is {bs0},{bs1}, 1, <==> {ws1.shape}"
expert_w1 = (
expert_w1
.view(bs0_w13, b0, bs1_w13, b1)
.permute(0, 2, 1, 3)
.reshape(bs0_w13, bs1_w13, -1)
.float()
.to(hidden_states.device)
* ws1
)
expert_w1 = (
expert_w1.reshape(bs0_w13, bs1_w13, b0, b1)
.permute(0, 2, 1, 3)
.reshape(dim0, dim1)
.to(hidden_states.dtype)
)
dim0, dim1 = expert_w2.shape
b0, b1 = dim0 // bs0_w2, dim1 // bs1_w2
# assert (bs0, bs1, 1)==ws2.shape
expert_w2 = (
expert_w2
.view(bs0_w2, b0, bs1_w2, b1)
.permute(0, 2, 1, 3)
.reshape(bs0_w2, bs1_w2, -1)
.float()
.to(hidden_states.device)
* ws2
)
expert_w2 = (
expert_w2.reshape(bs0_w2, bs1_w2, b0, b1)
.permute(0, 2, 1, 3)
.reshape(dim0, dim1)
.to(hidden_states.dtype)
)
# [seq, experts]
expert_weights = (
topk_weights.masked_select(expert_mask)
.unsqueeze(1)
.to(hidden_states.dtype)
)
x = hidden_states[idx]
x = F.linear(x, expert_w1)
gate = F.silu(x[:, :intermediate_size])
x = x[:, intermediate_size:] * gate
x = F.linear(x, expert_w2)
current_hidden_states = x * expert_weights
current_hidden_states = current_hidden_states.to(x.dtype)
# final_hidden_states[idx] += current_hidden_states
final_hidden_states.index_add_(0, idx, current_hidden_states)
final_hidden_states = final_hidden_states.reshape(batch_seq_all, hidden_dims)
return final_hidden_states
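# Illustrative CPU sketch (not part of the original file): exercising the
# single-token decode branch of fused_experts with tiny made-up shapes and
# float32 stand-ins for the FP8 expert weights (a 4x4 block size is implied by
# the scale shapes).
def _example_fused_experts_decode():
    torch.manual_seed(0)
    hidden, inter, experts = 8, 4, 2
    hidden_states = torch.randn(1, hidden)
    w13 = torch.randn(experts, 2 * inter, hidden)   # gate+up projections per expert
    w2 = torch.randn(experts, hidden, inter)        # down projection per expert
    w13_scale = torch.rand(experts, 2, 2)           # one scale per 4x4 weight block
    w2_scale = torch.rand(experts, 2, 1)
    topk_ids = torch.tensor([[0, 1]])
    topk_weights = torch.tensor([[0.7, 0.3]])
    out = fused_experts(hidden_states, w13, w2, topk_weights, topk_ids,
                        w13_scale=w13_scale, w2_scale=w2_scale)
    assert out.shape == hidden_states.shape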
def fused_mlp_mm_fp8(
hidden_states: torch.Tensor,
w13_weight: torch.Tensor,
w2_weight: torch.Tensor,
use_fp8_w8a8: bool = True,
w13_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a13_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape_w13: Optional[List[int]] = None,
block_shape_w2: Optional[List[int]] = None,
):
def fp8_to_fp16(inp, scale, block_size, trans_type):
inp_t = inp.to(trans_type)
inp_t = split_last_two_dims_into_blocks(inp_t, block_size[0], block_size[1])
assert scale.size(0) == inp_t.size(-4)
assert scale.size(1) == inp_t.size(-3)
inp_t = inp_t * scale.unsqueeze(-1).unsqueeze(-1)
inp_t = merge_blocks_to_original_layout(inp_t, block_size[0], block_size[1])
return inp_t.to(trans_type)
w13_weight = w13_weight.contiguous()
w2_weight = w2_weight.contiguous()
w13_scale = w13_scale.contiguous()
w2_scale = w2_scale.contiguous()
w13_fp = fp8_to_fp16(w13_weight, w13_scale, block_shape_w13, hidden_states.dtype)
w2_fp = fp8_to_fp16(w2_weight, w2_scale, block_shape_w2, hidden_states.dtype)
out = hidden_states @ w13_fp
out = torch.chunk(out, 2, dim=-1)
out = F.silu(out[0]) * out[1]
out = out @ w2_fp
return out
def mla_matmul_scale(input: torch.Tensor, weight: torch.Tensor, scale: float):
output = torch.matmul(input, weight)
output = output * scale
output = output.to(input.dtype)
return output
def mla_matmul(input: torch.Tensor, weight: torch.Tensor):
output = torch.matmul(input, weight)
output = output.to(input.dtype)
return output

View File

@@ -0,0 +1,146 @@
from typing import List, Optional, Tuple, Union
import torch
from torch import Generator
from torch_vacc._vacc_libs import _torch_vacc
def fuse_moe_prefill_stage0_qwen(
hidden_states,
rms_residual,
rms_weight,
gate_weight,
rms_hidden_state_opt: Optional[torch.Tensor] = None,
zero_moe_hidden_state_opt: Optional[torch.Tensor] = None,
topk_ids_opt: Optional[torch.Tensor] = None,
topk_weight_opt: Optional[torch.Tensor] = None,
):
return _torch_vacc.fuse_moe_prefill_stage0_qwen(
hidden_states,
rms_residual,
rms_weight,
gate_weight,
rms_hidden_state_opt,
zero_moe_hidden_state_opt,
topk_ids_opt,
topk_weight_opt,
)
def fuse_moe_decode_qwen(
hidden_states,
rms_residual,
rms_weight,
moe_weight_13,
moe_weight_2,
moe_weight_13_dequat,
moe_weight_2_dequant,
gate_weight,
block_size_13,
block_size_2,
world_size: int,
rank: int,
group_id: int,
dev_info: List[int] = None,
output: Optional[torch.Tensor] = None,
):
if dev_info is None or len(dev_info) == 0:
dev_info = [i | (i << 16) for i in range(world_size)]
return _torch_vacc.fuse_moe_decode_qwen(
hidden_states,
rms_residual,
rms_weight,
moe_weight_13,
moe_weight_2,
moe_weight_13_dequat,
moe_weight_2_dequant,
gate_weight,
block_size_13,
block_size_2,
world_size,
rank,
group_id,
dev_info,
output,
)
def rot_pos_emb_qwenvl(grid_thw: List[List[int]],
hidden_size: int,
head_num: int,
spatial_merge_size: int,
dtype: torch.dtype,
device: Union[int, str, torch.device] = "vacc"):
#assert out_tensor.device.type == "vacc", f"please target vacc device, now is {out_tensor.device}"
if isinstance(device, str):
device = torch.device(device)
elif isinstance(device, int):
device = torch.device("vacc", device)
thws = []
for i in grid_thw:
thws.extend(i)
return _torch_vacc.rot_pos_emb_qwenvl(thws,
hidden_size,
head_num,
spatial_merge_size,
dtype,
device)
def fast_pos_embed_interpolate_qwenvl(weight: torch.Tensor,
grid_thw: List[List[int]],
num_grid_per_side: int,
spatial_merge_size: int,
hidden_dim: int):
thws = []
for i in grid_thw:
thws.extend(i)
return _torch_vacc.fast_pos_embed_interpolate_qwenvl(weight,
thws,
num_grid_per_side,
spatial_merge_size,
hidden_dim)
# qwen2_vl and qwen3_vl image preprocess ops are the same
def qwen2vl_img_preprocess(
image: "torch.Tensor",
do_resize: bool,
min_pixels: int,
max_pixels: int,
do_rescale: bool,
rescale_factor: float,
do_normalize: bool,
resized_height: int,
resized_width: int,
interpolation: int, #Optional["F.InterpolationMode"],
patch_size: int,
temporal_patch_size: int,
merge_size: int,
image_mean0: float,
image_mean1: float,
image_mean2: float,
image_std0: float,
image_std1: float,
image_std2: float,
# batch_size: int = 1,
# grid_t: int = 1,
# channel: int = 3,
# output: Optional[torch.Tensor] = None
):
assert image.device.type == "vacc", f"image must be on a vacc device, got {image.device}"
return _torch_vacc.qwen2vl_img_preprocess(
image,
do_resize,
min_pixels,
max_pixels,
do_rescale,
rescale_factor,
do_normalize,
resized_height,
resized_width,
interpolation,
patch_size,
temporal_patch_size,
merge_size,
image_mean0, image_mean1, image_mean2,
image_std0, image_std1, image_std2
)

View File

@@ -0,0 +1,107 @@
import threading
import traceback
from typing import List
from .._vacc_libs import _torch_vacc
_initialized = False
_tls = threading.local()
_initialization_lock = threading.Lock()
_queued_calls = []
_is_in_bad_fork = getattr(_torch_vacc, "_vacc_in_bad_fork", lambda: False)
def is_initialized():
r"""Returns whether PyTorch's VACC state has been initialized."""
return _initialized and not _is_in_bad_fork()
class _LazySeedTracker:
# Since seeding is memory-less, only track the latest seed.
# Note: `manual_seed_all` followed by `manual_seed` overwrites
# the seed on current device. We track the order of **latest**
# calls between these two APIs.
def __init__(self):
self.manual_seed_all_cb = None
self.manual_seed_cb = None
self.call_order = []
def queue_seed_all(self, cb, traceback):
self.manual_seed_all_cb = (cb, traceback)
# update seed_all to be latest
self.call_order = [self.manual_seed_cb, self.manual_seed_all_cb]
def queue_seed(self, cb, traceback):
self.manual_seed_cb = (cb, traceback)
# update seed to be latest
self.call_order = [self.manual_seed_all_cb, self.manual_seed_cb]
def get_calls(self) -> List:
return self.call_order
_lazy_seed_tracker = _LazySeedTracker()
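# Illustrative sketch (not part of the original file): the tracker keeps only the
# most recent seed and seed_all callbacks, ordered by which was queued last.
def _example_seed_tracker_order():
    def seed_all_cb():
        pass

    def seed_cb():
        pass

    tracker = _LazySeedTracker()
    tracker.queue_seed_all(seed_all_cb, traceback.format_stack())
    tracker.queue_seed(seed_cb, traceback.format_stack())
    order = tracker.get_calls()
    # The per-device seed was queued last, so it is replayed last and wins on the
    # current device when _lazy_init() drains the queue.
    assert order[0][0] is seed_all_cb and order[1][0] is seed_cb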
def _lazy_call(callable, **kwargs):
if is_initialized():
callable()
else:
# TODO(torch_deploy): this accesses linecache, which attempts to read the
# file system to get traceback info. Patch linecache or do something
# else here if this ends up being important.
global _lazy_seed_tracker
if kwargs.get("seed_all", False):
_lazy_seed_tracker.queue_seed_all(callable, traceback.format_stack())
elif kwargs.get("seed", False):
_lazy_seed_tracker.queue_seed(callable, traceback.format_stack())
else:
# Don't store the actual traceback to avoid memory cycle
_queued_calls.append((callable, traceback.format_stack()))
class DeferredVaccCallError(Exception):
pass
def _lazy_init():
"""Initialize VACC device state."""
global _initialized, _queued_calls
if _initialized or hasattr(_tls, "is_initializing"):
return
with _initialization_lock:
if _initialized:
return
# It is important to prevent other threads from entering _lazy_init
# immediately, while we are still guaranteed to have the GIL, because some
# of the C calls we make below will release the GIL
if _is_in_bad_fork():
raise RuntimeError(
"Cannot re-initialize VACC in forked subprocess. To use VACC with "
"multiprocessing, you must use the 'spawn' start method"
)
_torch_vacc._vacc_init()
_tls.is_initializing = True
for calls in _lazy_seed_tracker.get_calls():
if calls:
_queued_calls.append(calls)
try:
for queued_call, orig_traceback in _queued_calls:
try:
queued_call()
except Exception as e:
msg = (
f"VACC call failed lazily at initialization with error: {str(e)}\n\n"
f"VACC call was originally invoked at:\n\n{''.join(orig_traceback)}"
)
raise DeferredVaccCallError(msg) from e
finally:
delattr(_tls, "is_initializing")
_initialized = True

535
torch_vacc/vacc/memory.py Normal file
View File

@@ -0,0 +1,535 @@
import collections
import contextlib
import warnings
from typing import Tuple
import torch
from torch._utils import _get_device_index
import torch_vacc
from torch_vacc._vacc_libs import _torch_vacc
from .lazy_initialize import is_initialized, _lazy_init
__all__ = [
"mem_get_info",
# "caching_allocator_alloc",
# "caching_allocator_delete",
"set_per_process_memory_fraction",
"empty_cache",
"memory_stats",
"memory_stats_as_nested_dict",
"reset_accumulated_memory_stats",
"reset_peak_memory_stats",
"reset_max_memory_allocated",
"reset_max_memory_cached",
"memory_allocated",
"max_memory_allocated",
"memory_reserved",
"max_memory_reserved",
"memory_cached",
"max_memory_cached",
"memory_snapshot",
"memory_summary",
"get_allocator_backend",
]
@contextlib.contextmanager
def _free_mutex():
_torch_vacc._vacc_lock_mutex()
try:
yield
finally:
_torch_vacc._vacc_unlock_mutex()
# def caching_allocator_alloc(size, device=None, stream=None):
# r"""Performs a memory allocation using the VACC memory allocator.
# Memory is allocated for a given device and a stream, this
# function is intended to be used for interoperability with other
# frameworks. Allocated memory is released through
# :func:`~torch_vacc.vacc.caching_allocator_delete`.
# Arguments:
# size (int): number of bytes to be allocated.
# device (torch.device or int, optional): selected device. If it is
# ``None`` the default VACC device is used.
# stream (torch_vacc.vacc.Stream or int, optional): selected stream. If is ``None`` then
# the default stream for the selected device is used.
# """
# if device is None:
# device = torch_vacc.vacc.current_device()
# device = _get_device_index(device)
# if stream is None:
# stream = torch_vacc.vacc.current_stream(device)
# if isinstance(stream, torch_vacc.vacc.streams.Stream):
# stream = stream.vacc_stream
# if not isinstance(stream, int):
# raise TypeError(
# "Invalid type for stream argument, must be "
# "`torch_vacc.vacc.Stream` or `int` representing a pointer "
# "to a exisiting stream"
# )
# with torch_vacc.vacc.device(device):
# return _torch_vacc._vacc_vaccCachingAllocator_raw_alloc(size, stream)
# def caching_allocator_delete(mem_ptr):
# r"""Deletes memory allocated using the VACC memory allocator.
# Memory allocated with :func:`~torch_vacc.vacc.caching_allocator_alloc`.
# is freed here. The associated device and stream are tracked inside
# the allocator.
# Arguments:
# mem_ptr (int): memory address to be freed by the allocator.
# """
# _torch_vacc._vacc_vaccCachingAllocator_raw_delete(mem_ptr)
def set_per_process_memory_fraction(fraction, device=None) -> None:
r"""Set memory fraction for a process.
The fraction is used to limit the caching allocator to a share of the memory on a VACC device.
The allowed value equals the total visible memory multiplied by the fraction.
If trying to allocate more than the allowed value in a process, will raise an out of
memory error in allocator.
Arguments:
fraction(float): Range: 0~1. Allowed memory equals total_memory * fraction.
device (torch.device or int, optional): selected device. If it is
``None`` the default VACC device is used.
.. note::
In general, the total available free memory is less than the total capacity.
"""
_lazy_init()
if device is None:
device = torch_vacc.vacc.current_device()
device = _get_device_index(device)
if not isinstance(fraction, float):
raise TypeError("Invalid type for fraction argument, must be `float`")
if fraction < 0 or fraction > 1:
raise ValueError(
"Invalid fraction value: {}. " "Allowed range: 0~1".format(fraction)
)
_torch_vacc._vacc_setMemoryFraction(fraction, device)
def empty_cache():
r"""Releases all unoccupied cached memory currently held by the caching
allocator so that it can be used by other VACC applications and is visible to
device monitoring tools (the VACC analogue of `nvidia-smi`).
.. note::
:func:`~torch_vacc.vacc.empty_cache` doesn't increase the amount of VACC
memory available for PyTorch. However, it may help reduce fragmentation
of VACC memory in certain cases.
"""
if is_initialized():
_torch_vacc._vacc_emptyCache()
def memory_stats(device=None):
"""Returns a dictionary of VACC memory allocator statistics for a
given device.
The return value of this function is a dictionary of statistics, each of
which is a non-negative integer.
Core statistics:
- ``"allocated.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
number of allocation requests received by the memory allocator.
- ``"allocated_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
amount of allocated memory.
- ``"segment.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
number of reserved segments from ``vaccMalloc()``.
- ``"reserved_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
amount of reserved memory.
- ``"active.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
number of active memory blocks.
- ``"active_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
amount of active memory.
- ``"inactive_split.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
number of inactive, non-releasable memory blocks.
- ``"inactive_split_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
amount of inactive, non-releasable memory.
For these core statistics, values are broken down as follows.
Pool type:
- ``all``: combined statistics across all memory pools.
- ``large_pool``: statistics for the large allocation pool
(as of October 2019, for size >= 1MB allocations).
- ``small_pool``: statistics for the small allocation pool
(as of October 2019, for size < 1MB allocations).
Metric type:
- ``current``: current value of this metric.
- ``peak``: maximum value of this metric.
- ``allocated``: historical total increase in this metric.
- ``freed``: historical total decrease in this metric.
In addition to the core statistics, we also provide some simple event
counters:
- ``"num_alloc_retries"``: number of failed ``vaccMalloc`` calls that
result in a cache flush and retry.
- ``"num_ooms"``: number of out-of-memory errors thrown.
The caching allocator can be configured via ENV to not split blocks larger than a
defined size (see the Memory Management section of the CUDA semantics documentation).
This helps avoid memory fragmentation but may have a performance
penalty. Additional outputs to assist with tuning and evaluating impact:
- ``"max_split_size"``: blocks above this size will not be split.
- ``"oversize_allocations.{current,peak,allocated,freed}"``:
number of over-size allocation requests received by the memory allocator.
- ``"oversize_segments.{current,peak,allocated,freed}"``:
number of over-size reserved segments from ``vaccMalloc()``.
Arguments:
device (torch.device or int, optional): selected device. Returns
statistics for the current device, given by :func:`~torch_vacc.vacc.current_device`,
if :attr:`device` is ``None`` (default).
"""
result = []
def _recurse_add_to_result(prefix, obj):
if isinstance(obj, dict):
if len(prefix) > 0:
prefix += "."
for k, v in obj.items():
_recurse_add_to_result(prefix + k, v)
else:
result.append((prefix, obj))
stats = memory_stats_as_nested_dict(device=device)
_recurse_add_to_result("", stats)
result.sort()
return collections.OrderedDict(result)
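# Illustrative sketch (not part of the original file) of the key flattening done by
# memory_stats(): a nested dict like the one returned by memory_stats_as_nested_dict()
# becomes a flat, sorted mapping with dotted keys. The sample numbers are made up.
def _example_flattened_stat_keys():
    nested = {"allocated_bytes": {"all": {"current": 1024, "peak": 4096}}}
    flat = []

    def _recurse(prefix, obj):
        if isinstance(obj, dict):
            if prefix:
                prefix += "."
            for k, v in obj.items():
                _recurse(prefix + k, v)
        else:
            flat.append((prefix, obj))

    _recurse("", nested)
    assert sorted(flat) == [
        ("allocated_bytes.all.current", 1024),
        ("allocated_bytes.all.peak", 4096),
    ]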
def memory_stats_as_nested_dict(device=None):
r"""Returns the result of :func:`~torch_vacc.vacc.memory_stats` as a nested dictionary."""
device = _get_device_index(device, optional=True)
return _torch_vacc._vacc_memoryStats(device)
def reset_accumulated_memory_stats(device=None):
r"""Resets the "accumulated" (historical) stats tracked by the VACC memory allocator.
See :func:`~torch_vacc.vacc.memory_stats` for details. Accumulated stats correspond to
the `"allocated"` and `"freed"` keys in each individual stat dict, as well as
`"num_alloc_retries"` and `"num_ooms"`.
Arguments:
device (torch.device or int, optional): selected device. Returns
statistic for the current device, given by :func:`~torch_vacc.vacc.current_device`,
if :attr:`device` is ``None`` (default).
"""
device = _get_device_index(device, optional=True)
return _torch_vacc._vacc_resetAccumulatedMemoryStats(device)
def reset_peak_memory_stats(device=None):
r"""Resets the "peak" stats tracked by the VACC memory allocator.
See :func:`~torch_vacc.vacc.memory_stats` for details. Peak stats correspond to the
`"peak"` key in each individual stat dict.
Arguments:
device (torch.device or int, optional): selected device. Returns
statistic for the current device, given by :func:`~torch_vacc.vacc.current_device`,
if :attr:`device` is ``None`` (default).
"""
device = _get_device_index(device, optional=True)
return _torch_vacc._vacc_resetPeakMemoryStats(device)
def reset_max_memory_allocated(device=None):
r"""Resets the starting point in tracking maximum VACC memory occupied by
tensors for a given device.
See :func:`~torch_vacc.vacc.max_memory_allocated` for details.
Arguments:
device (torch.device or int, optional): selected device. Returns
statistic for the current device, given by :func:`~torch_vacc.vacc.current_device`,
if :attr:`device` is ``None`` (default).
.. warning::
This function now calls :func:`~torch_vacc.vacc.reset_peak_memory_stats`, which resets
/all/ peak memory stats.
"""
# warnings.warn(
# "torch_vacc.vacc.reset_max_memory_allocated now calls torch_vacc.vacc.reset_peak_memory_stats, "
# "which resets /all/ peak memory stats.",
# DeprecationWarning,
# )
return reset_peak_memory_stats(device=device)
def reset_max_memory_cached(device=None):
r"""Resets the starting point in tracking maximum VACC memory managed by the
caching allocator for a given device.
See :func:`~torch_vacc.vacc.max_memory_cached` for details.
Arguments:
device (torch.device or int, optional): selected device. Returns
statistic for the current device, given by :func:`~torch_vacc.vacc.current_device`,
if :attr:`device` is ``None`` (default).
.. warning::
This function now calls :func:`~torch_vacc.vacc.reset_peak_memory_stats`, which resets
/all/ peak memory stats.
"""
# warnings.warn(
# "torch_vacc.vacc.reset_max_memory_cached now calls torch_vacc.vacc.reset_peak_memory_stats, "
# "which resets /all/ peak memory stats.",
# DeprecationWarning,
# )
return reset_peak_memory_stats(device=device)
def memory_allocated(device=None):
r"""Returns the current VACC memory occupied by tensors in bytes for a given
device.
Arguments:
device (torch.device or int, optional): selected device. Returns
statistic for the current device, given by :func:`~torch_vacc.vacc.current_device`,
if :attr:`device` is ``None`` (default).
"""
return memory_stats(device=device)["allocated_bytes.all.current"]
def max_memory_allocated(device=None):
r"""Returns the maximum VACC memory occupied by tensors in bytes for a given
device.
By default, this returns the peak allocated memory since the beginning of
this program. :func:`~torch_vacc.vacc.reset_peak_memory_stats` can be used to
reset the starting point in tracking this metric. For example, these two
functions can measure the peak allocated memory usage of each iteration in a
training loop.
Arguments:
device (torch.device or int, optional): selected device. Returns
statistic for the current device, given by :func:`~torch_vacc.vacc.current_device`,
if :attr:`device` is ``None`` (default).
"""
return memory_stats(device=device)["allocated_bytes.all.peak"]
def memory_reserved(device=None):
r"""Returns the current VACC memory managed by the caching allocator in bytes
for a given device.
Arguments:
device (torch.device or int, optional): selected device. Returns
statistic for the current device, given by :func:`~torch_vacc.vacc.current_device`,
if :attr:`device` is ``None`` (default).
"""
return memory_stats(device=device)["reserved_bytes.all.current"]
def max_memory_reserved(device=None):
r"""Returns the maximum VACC memory managed by the caching allocator in bytes
for a given device.
By default, this returns the peak cached memory since the beginning of this
program. :func:`~torch_vacc.vacc.reset_peak_memory_stats` can be used to reset
the starting point in tracking this metric. For example, these two functions
can measure the peak cached memory amount of each iteration in a training
loop.
Arguments:
device (torch.device or int, optional): selected device. Returns
statistic for the current device, given by :func:`~torch_vacc.vacc.current_device`,
if :attr:`device` is ``None`` (default).
"""
return memory_stats(device=device)["reserved_bytes.all.peak"]
def memory_cached(device=None):
r"""Deprecated; see :func:`~torch_vacc.vacc.memory_reserved`."""
# warnings.warn(
# "torch_vacc.vacc.memory_cached has been renamed to torch_vacc.vacc.memory_reserved",
# DeprecationWarning,
# )
return memory_reserved(device=device)
def max_memory_cached(device=None):
r"""Deprecated; see :func:`~torch_vacc.vacc.max_memory_reserved`."""
# warnings.warn(
# "torch_vacc.vacc.max_memory_cached has been renamed to torch_vacc.vacc.max_memory_reserved",
# DeprecationWarning,
# )
return max_memory_reserved(device=device)
def memory_snapshot():
r"""Returns a snapshot of the VACC memory allocator state across all devices.
Interpreting the output of this function requires familiarity with the
memory allocator internals.
"""
return _torch_vacc._vacc_memorySnapshot()
def _format_size(sz, pref_sz):
prefixes = ["B ", "KB", "MB", "GB", "TB", "PB"]
prefix = prefixes[0]
for new_prefix in prefixes[1:]:
if pref_sz < 768 * 1024:
break
prefix = new_prefix
sz //= 1024
pref_sz /= 1024
return "{:7d} {}".format(sz, prefix)
def _format_count(cnt, pref_cnt):
prefixes = [" ", "K", "M"]
prefix = prefixes[0]
for new_prefix in prefixes[1:]:
if pref_cnt < 750 * 1000:
break
prefix = new_prefix
cnt //= 1000
pref_cnt /= 1000
return "{:7d} {} ".format(cnt, prefix)
def create_metrics_to_display():
metrics_to_display = [
("allocated_bytes", "Allocated memory", _format_size),
("active_bytes", "Active memory", _format_size),
("reserved_bytes", "VACC reserved memory", _format_size),
("inactive_split_bytes", "Non-releasable memory", _format_size),
("allocation", "Allocations", _format_count),
("active", "Active allocs", _format_count),
("segment", "VACC reserved segments", _format_count),
("inactive_split", "Non-releasable allocs", _format_count),
]
lines = []
lines.append("=" * 75)
lines.append(" {_:16} PyTorch VACC memory summary, device ID {device:<18d} ")
lines.append("-" * 75)
lines.append(
" {_:9} VACC OOMs: {num_ooms:<13d} | {_:6} vaccMalloc retries: {num_alloc_retries:<9d} "
)
lines.append("=" * 75)
lines.append(
" Metric | Cur Usage | Peak Usage | Tot Alloc | Tot Freed "
)
return metrics_to_display, lines
def memory_summary(device=None, abbreviated=False):
r"""Returns a human-readable printout of the current memory allocator
statistics for a given device.
This can be useful to display periodically during training, or when
handling out-of-memory exceptions.
Arguments:
device (torch.device or int, optional): selected device. Returns
printout for the current device, given by :func:`~torch_vacc.vacc.current_device`,
if :attr:`device` is ``None`` (default).
abbreviated (bool, optional): whether to return an abbreviated summary
(default: False).
"""
device = _get_device_index(device, optional=True)
stats = memory_stats(device=device)
metrics_to_display, lines = create_metrics_to_display()
for metric_key, metric_name, formatter in metrics_to_display:
lines.append("-" * 75)
submetrics = [("all", metric_name)]
if not abbreviated:
submetrics.append(("large_pool", " from large pool"))
submetrics.append(("small_pool", " from small pool"))
current_prefval, peak_prefval, allocated_prefval, freed_prefval = (
None,
None,
None,
None,
)
for submetric_key, submetric_name in submetrics:
prefix = metric_key + "." + submetric_key + "."
current = stats[prefix + "current"]
peak = stats[prefix + "peak"]
allocated = stats[prefix + "allocated"]
freed = stats[prefix + "freed"]
if current_prefval is None:
current_prefval = current
peak_prefval = peak
allocated_prefval = allocated
freed_prefval = freed
lines.append(
" {:<21} | {} | {} | {} | {} ".format(
submetric_name,
formatter(current, current_prefval),
formatter(peak, peak_prefval),
formatter(allocated, allocated_prefval),
formatter(freed, freed_prefval),
),
)
metrics_to_display = [
("oversize_allocations", "Oversize allocations", _format_count),
("oversize_segments", "Oversize VACC segments", _format_count),
]
for metric_key, metric_name, formatter in metrics_to_display:
lines.append("-" * 75)
prefix = metric_key + "."
current = stats[prefix + "current"]
peak = stats[prefix + "peak"]
allocated = stats[prefix + "allocated"]
freed = stats[prefix + "freed"]
lines.append(
" {:<21} | {} | {} | {} | {} ".format(
metric_name,
formatter(current, current),
formatter(peak, peak),
formatter(allocated, allocated),
formatter(freed, freed),
),
)
lines.append("=" * 75)
fmt_dict = {"_": "", "device": device}
for k, v in stats.items():
fmt_dict[k.replace(".", "-")] = v
return "|" + "|\n|".join(lines).format(**fmt_dict) + "|\n"
def mem_get_info(device=None) -> Tuple[int, int]:
r"""Returns the global free and total VACC memory for a given
device using vaccrtMemGetInfo.
Args:
device (torch.device or int, optional): selected device. Returns
statistic for the current device, given by :func:`~torch_vacc.vacc.current_device`,
if :attr:`device` is ``None`` (default).
"""
_lazy_init()
if device is None:
device = torch_vacc.vacc.current_device()
device = _get_device_index(device)
return _torch_vacc._vacc_getDeviceMemories(device)
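# Illustrative sketch (not part of the original file): reporting device memory
# utilization from mem_get_info(), assuming a VACC device is present.
def _example_memory_utilization(device=None):
    free, total = mem_get_info(device)
    used = total - free
    print(f"{used / total:.1%} of {total} bytes in use")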
def get_allocator_backend() -> str:
r"""Returns a string describing the active allocator backend as set by
``PYTORCH_VACC_ALLOC_CONF``. Currently available backends are
``native`` (PyTorch's native caching allocator).
"""
return _torch_vacc._vacc_getAllocatorBackend()

179
torch_vacc/vacc/random.py Normal file
View File

@@ -0,0 +1,179 @@
from typing import Union, List, Iterable
import torch
from torch import Tensor
from . import _lazy_call, _lazy_init, current_device, device_count
__all__ = [
"get_rng_state",
"get_rng_state_all",
"set_rng_state",
"set_rng_state_all",
"manual_seed",
"manual_seed_all",
"seed",
"seed_all",
"initial_seed",
]
# Random Number Generator related functions (https://pytorch.org/docs/stable/cuda.html#random-number-generator)
def get_rng_state(device: Union[int, str, torch.device] = "vacc") -> Tensor:
r"""Returns the random number generator state of the specified GPU as a ByteTensor.
Args:
device (torch.device or int, optional): The device to return the RNG state of.
Default: ``'vacc'`` (i.e., ``torch.device('vacc')``, the current VACC device).
.. warning::
This function eagerly initializes VACC.
"""
_lazy_init()
if isinstance(device, str):
device = torch.device(device)
elif isinstance(device, int):
device = torch.device("vacc", device)
idx = device.index
if idx is None:
idx = current_device()
default_generator = torch.vacc.default_generators[idx]
return default_generator.get_state()
def get_rng_state_all() -> List[Tensor]:
r"""Returns a list of ByteTensor representing the random number states of all devices."""
results = []
for i in range(device_count()):
results.append(get_rng_state(i))
return results
def set_rng_state(
new_state: Tensor, device: Union[int, str, torch.device] = "vacc"
) -> None:
r"""Sets the random number generator state of the specified GPU.
Args:
new_state (torch.ByteTensor): The desired state
device (torch.device or int, optional): The device to set the RNG state.
Default: ``'vacc'`` (i.e., ``torch.device('vacc')``, the current VACC device).
"""
with torch._C._DisableFuncTorch():
new_state_copy = new_state.clone(memory_format=torch.contiguous_format)
if isinstance(device, str):
device = torch.device(device)
elif isinstance(device, int):
device = torch.device("vacc", device)
def cb():
idx = device.index
if idx is None:
idx = current_device()
default_generator = torch.vacc.default_generators[idx]
default_generator.set_state(new_state_copy)
_lazy_call(cb)
def set_rng_state_all(new_states: Iterable[Tensor]) -> None:
r"""Sets the random number generator state of all devices.
Args:
new_states (Iterable of torch.ByteTensor): The desired state for each device"""
for i, state in enumerate(new_states):
set_rng_state(state, i)
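# Illustrative sketch (not part of the original file): capturing and restoring the
# RNG state of every VACC device around a code block, assuming devices are available.
# `fn` is a hypothetical callable that consumes random numbers.
def _example_rng_checkpoint(fn):
    states = get_rng_state_all()
    try:
        return fn()
    finally:
        set_rng_state_all(states)  # restore, so later sampling is unaffected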
def manual_seed(seed: int) -> None:
r"""Sets the seed for generating random numbers for the current GPU.
It's safe to call this function if VACC is not available; in that
case, it is silently ignored.
Args:
seed (int): The desired seed.
.. warning::
If you are working with a multi-GPU model, this function is insufficient
to get determinism. To seed all GPUs, use :func:`manual_seed_all`.
"""
seed = int(seed)
def cb():
idx = current_device()
default_generator = torch.vacc.default_generators[idx]
default_generator.manual_seed(seed)
_lazy_call(cb, seed=True)
def manual_seed_all(seed: int) -> None:
r"""Sets the seed for generating random numbers on all GPUs.
It's safe to call this function if VACC is not available; in that
case, it is silently ignored.
Args:
seed (int): The desired seed.
"""
seed = int(seed)
def cb():
for i in range(device_count()):
default_generator = torch.vacc.default_generators[i]
default_generator.manual_seed(seed)
_lazy_call(cb, seed_all=True)
def seed() -> None:
r"""Sets the seed for generating random numbers to a random number for the current GPU.
It's safe to call this function if VACC is not available; in that
case, it is silently ignored.
.. warning::
If you are working with a multi-GPU model, this function will only initialize
the seed on one GPU. To initialize all GPUs, use :func:`seed_all`.
"""
def cb():
idx = current_device()
default_generator = torch.vacc.default_generators[idx]
default_generator.seed()
_lazy_call(cb)
def seed_all() -> None:
r"""Sets the seed for generating random numbers to a random number on all GPUs.
It's safe to call this function if VACC is not available; in that
case, it is silently ignored.
"""
def cb():
random_seed = 0
seeded = False
for i in range(device_count()):
default_generator = torch.vacc.default_generators[i]
if not seeded:
default_generator.seed()
random_seed = default_generator.initial_seed()
seeded = True
else:
default_generator.manual_seed(random_seed)
_lazy_call(cb)
def initial_seed() -> int:
r"""Returns the current random seed of the current GPU.
.. warning::
This function eagerly initializes VACC.
"""
_lazy_init()
idx = current_device()
default_generator = torch.vacc.default_generators[idx]
return default_generator.initial_seed()

327
torch_vacc/vacc/streams.py Normal file
View File

@@ -0,0 +1,327 @@
import ctypes
from typing import Any, Optional
import torch
from packaging import version
from torch._utils import _get_device_index
try:
from torch._streambase import _StreamBase, _EventBase
except ImportError:
# torch <= 2.1
_StreamBase = _EventBase = object
import torch_vacc
from torch_vacc._vacc_libs import _torch_vacc
from ._device import device
from .lazy_initialize import _lazy_init
# remove torch version arch-suffix(i.e. +cpu)
torch_version = torch.__version__.split('+')[0]
class _StreamCommon:
"""Wrapper around a VACC stream.
A VACC stream is a linear sequence of execution that belongs to a specific
device, independent from other streams.
Args:
device(torch.device or int, optional): a device on which to allocate
the stream. If :attr:`device` is ``None`` (default) or a negative
integer, this will use the current device.
priority(int, optional): priority of the stream. Can be either
-1 (high priority) or 0 (low priority). By default, streams have
priority 0.
"""
def __new__(cls, device=None, priority=0, **kwargs):
if device is None or ("stream_id" in kwargs and "device_index" in kwargs):
return super(Stream, cls).__new__(cls, priority=priority, **kwargs)
else:
with torch_vacc.vacc.device(device):
return super(Stream, cls).__new__(cls, priority=priority, **kwargs)
def wait_event(self, event):
event.wait(self)
def record_event(self, event=None):
"""Records an event.
Args:
event (torch_vacc.Event, optional): event to record. If not given, a new one
will be allocated.
Returns:
Recorded event.
"""
if event is None:
event = Event()
event.record(self)
return event
def wait_stream(self, stream):
"""Synchronizes with another stream.
All future work submitted to this stream will wait until all kernels
submitted to a given stream at the time of call complete.
Args:
stream (Stream): a stream to synchronize.
"""
self.wait_event(stream.record_event())
def query(self):
return super().query()
def synchronize(self):
super().synchronize()
@property
def _as_parameter_(self):
return ctypes.c_void_p(self.vacc_stream)
def __eq__(self, o):
if isinstance(o, Stream):
return super().__eq__(o)
return False
def __hash__(self):
return hash((self.vacc_stream, self.device))
def __repr__(self):
return f"torch_vacc.vacc.Stream device={self.device} vacc_stream={self.vacc_stream:#x}"
if version.parse(torch_version) <= version.parse("2.1"):
# torch <= 2.1
class Stream(_torch_vacc._VACCStreamBase, _StreamCommon):
pass
elif version.parse(torch_version) < version.parse("2.6"):
# torch < 2.6
class Stream(_torch_vacc._VACCStreamBase, _StreamBase, _StreamCommon):
pass
else:
# torch >= 2.6
class Stream(_torch_vacc._VACCStreamBase, _StreamCommon):
pass
class _EventCommon:
"""Wrapper around a VACC event.
VACC events are synchronization markers that can be used to monitor the
device's progress, to accurately measure timing, and to synchronize VACC
streams.
The underlying VACC events are lazily initialized when the event is first
recorded or exported to another process. After creation, only streams on the
same device may record the event. However, streams on any device can wait on
the event.
Args:
enable_timing (bool, optional): indicates if the event should measure time
(default: ``False``)
blocking (bool, optional): if ``True``, :meth:`wait` will be blocking (default: ``False``)
"""
def __new__(cls, enable_timing=False, blocking=False):
return super(Event, cls).__new__(
cls,
calc_time=enable_timing,
blocking=blocking,
)
def record(self, stream=None):
"""Records the event in a given stream.
Uses ``torch_vacc.vacc.current_stream()`` if no stream is specified. The
stream's device must match the event's device."""
if stream is None:
stream = torch_vacc.vacc.current_stream()
super().record(stream)
def wait(self, stream=None):
"""Makes all future work submitted to the given stream wait for this
event.
Uses ``torch_vacc.vacc.current_stream()`` if no stream is specified.
.. note:: This is a wrapper around ``vaccrtStreamWaitEvent()``
"""
if stream is None:
stream = torch_vacc.vacc.current_stream()
super().wait(stream)
def query(self):
"""Checks if all work currently captured by event has completed.
Returns:
A boolean indicating if all work currently captured by event has
completed.
"""
return super().query()
def elapsed_time(self, end_event):
"""Returns the time elapsed in milliseconds after the event was
recorded and before the end_event was recorded.
"""
return super().elapsed_time(end_event)
def synchronize(self):
r"""Waits for the event to complete.
Waits until the completion of all work currently captured in this event.
This prevents the CPU thread from proceeding until the event completes.
.. note:: This is a wrapper around ``vaccEventSynchronize()``.
"""
super().synchronize()
@property
def _as_parameter_(self):
return ctypes.c_void_p(self.vacc_event)
def __repr__(self):
if self.vacc_event:
return f"<torch_vacc.vacc.Event {self._as_parameter_.value:#x}>"
else:
return "<torch_vacc.vacc.Event uninitialized>"
if version.parse(torch_version) <= version.parse("2.1"):
# torch <= 2.1
class Event(_torch_vacc._VACCEventBase, _EventCommon):
pass
elif version.parse(torch_version) < version.parse("2.6"):
# torch < 2.6
class Event(_torch_vacc._VACCEventBase, _EventBase, _EventCommon):
pass
else:
# torch >= 2.6
class Event(_torch_vacc._VACCEventBase, _EventCommon):
pass
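# Illustrative sketch (not part of the original file): timing device work with a pair
# of events, assuming a VACC device is available. enable_timing must be True for
# elapsed_time() to be meaningful; `fn` is a hypothetical callable doing device work.
def _example_event_timing(fn):
    start = Event(enable_timing=True)
    end = Event(enable_timing=True)
    start.record()
    fn()
    end.record()
    end.synchronize()               # wait until the recorded work has finished
    return start.elapsed_time(end)  # milliseconds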
class StreamContext:
r"""Context-manager that selects a given stream.
All VACC kernels queued within its context will be enqueued on a selected
stream.
Args:
stream (Stream): selected stream. This manager is a no-op if it's
``None``.
.. note:: Streams are per-device.
"""
cur_stream: Optional["torch_vacc.vacc.Stream"]
def __init__(self, stream: Optional["torch_vacc.vacc.Stream"]):
self.stream = stream
self.idx = _get_device_index(None, True)
if not torch.jit.is_scripting():
if self.idx is None:
self.idx = -1
self.src_prev_stream = (
None
if not torch.jit.is_scripting()
else torch_vacc.vacc.default_stream(None)
)
self.dst_prev_stream = (
None
if not torch.jit.is_scripting()
else torch_vacc.vacc.default_stream(None)
)
def __enter__(self):
# Local cur_stream variable for type refinement
cur_stream = self.stream
# Return if stream is None or VACC device not available
if cur_stream is None or self.idx == -1:
return
self.src_prev_stream = torch_vacc.vacc.current_stream(None)
# If the stream is not on the current device, then
# set the current stream on the device
if self.src_prev_stream.device != cur_stream.device:
with device(cur_stream.device):
self.dst_prev_stream = torch_vacc.vacc.current_stream(cur_stream.device)
torch_vacc.vacc.set_stream(cur_stream)
def __exit__(self, type: Any, value: Any, traceback: Any):
# Local cur_stream variable for type refinement
cur_stream = self.stream
# If stream is None or no VACC device available, return
if cur_stream is None or self.idx == -1:
return
# Reset the stream on the original device
# and destination device
if self.src_prev_stream.device != cur_stream.device: # type: ignore[union-attr]
torch_vacc.vacc.set_stream(self.dst_prev_stream) # type: ignore[arg-type]
torch_vacc.vacc.set_stream(self.src_prev_stream) # type: ignore[arg-type]
def stream(stream: Optional["torch_vacc.vacc.Stream"]) -> StreamContext:
r"""Wrapper around the Context-manager StreamContext that
selects a given stream.
Arguments:
stream (Stream): selected stream. This manager is a no-op if it's
``None``.
"""
return StreamContext(stream)
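# Illustrative sketch (not part of the original file): running work on a side stream
# and then making the current stream wait for it, assuming a VACC device is
# available. `fn` is a hypothetical callable that enqueues kernels.
def _example_side_stream(fn):
    side = Stream()
    with stream(side):
        fn()                               # kernels are enqueued on `side`
    current_stream().wait_stream(side)     # rejoin the current stream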
def set_stream(stream: Stream):
r"""Sets the current stream.This is a wrapper API to set the stream.
Usage of this function is discouraged in favor of the ``stream``
context manager.
Args:
stream (Stream): selected stream. This function is a no-op
if this argument is ``None``.
"""
if stream is None:
return
_torch_vacc._vacc_setStream(
stream_id=stream.stream_id,
device_index=stream.device_index,
device_type=stream.device_type,
)
def current_stream(device=None) -> Stream:
r"""Returns the currently selected :class:`Stream` for a given device.
Args:
device (torch.device or int, optional): selected device. Returns
the currently selected :class:`Stream` for the current device, given
by :func:`~torch_vacc.vacc.current_device`, if :attr:`device` is ``None``
(default).
"""
_lazy_init()
streamdata = _torch_vacc._vacc_getCurrentStream(
_get_device_index(device, optional=True)
)
return Stream(
stream_id=streamdata[0], device_index=streamdata[1], device_type=streamdata[2]
)
def default_stream(device=None) -> Stream:
r"""Returns the default :class:`Stream` for a given device.
Args:
device (torch.device or int, optional): selected device. Returns
the default :class:`Stream` for the current device, given by
:func:`~torch_vacc.vacc.current_device`, if :attr:`device` is ``None``
(default).
"""
_lazy_init()
streamdata = _torch_vacc._vacc_getDefaultStream(
_get_device_index(device, optional=True)
)
return Stream(
stream_id=streamdata[0], device_index=streamdata[1], device_type=streamdata[2]
)

2
torch_vacc/version.py Normal file
View File

@@ -0,0 +1,2 @@
__all__ = ['__version__']
__version__ = '1.3.3.777'

269
torch_vacc/vslog.cfg Normal file
View File

@@ -0,0 +1,269 @@
hot_update: true
- channel: 0
sync: sync
priority: error
category: 0
category_extend: 0
-device: 0
disable: false
out_type: file
priority: error
category: 0
category_extend: 0
path: "./log/"
file: "$PNAME-$YEAR_$MON_$DAY_$HOUR_$MIN_$SEC_$PID"
rollback: 5
limit_size: 50 m #only support M byte
-device: 1
disable: false
out_type: screen
category: 0
category_extend: 0
- channel: 1
sync: sync
priority: error
category: 0
category_extend: 0
-device: 0
disable: false
out_type: file
priority: error
category: 0
category_extend: 0
path: "./log/"
file: "vacm-$YEAR_$MON_$DAY_$HOUR_$MIN_$SEC_$PID"
rollback: 5
limit_size: 50 m #only support M byte
-device: 1
disable: false
out_type: screen
category: 0
category_extend: 0
- channel: 2
sync: sync
priority: error
category: 0
category_extend: 0
-device: 0
disable: false
out_type: file
priority: error
category: 0
category_extend: 0
path: "./log/"
file: "vace-$YEAR_$MON_$DAY_$HOUR_$MIN_$SEC_$PID"
rollback: 5
limit_size: 50 m #only support M byte
-device: 1
disable: false
out_type: screen
category: 0
category_extend: 0
- channel: 3
sync: sync
priority: error
category: 0
category_extend: 0
-device: 0
disable: false
out_type: file
priority: error
category: 0
category_extend: 0
path: "./log/"
file: "vacl-$YEAR_$MON_$DAY_$HOUR_$MIN_$SEC_$PID"
rollback: 5
limit_size: 50 m #only support M byte
-device: 1
disable: false
out_type: screen
category: 0
category_extend: 0
- channel: 4
sync: sync
priority: error
category: 0
category_extend: 0
-device: 0
disable: false
out_type: file
priority: error
category: 0
category_extend: 0
path: "./log/"
file: "vame-$YEAR_$MON_$DAY_$HOUR_$MIN_$SEC_$PID"
rollback: 5
limit_size: 50 m #only support M byte
-device: 1
disable: false
out_type: screen
category: 0
category_extend: 0
- channel: 5
sync: sync
priority: error
category: 0
category_extend: 0
-device: 0
disable: false
out_type: file
priority: error
category: 0
category_extend: 0
path: "./log/"
file: "vaml-$YEAR_$MON_$DAY_$HOUR_$MIN_$SEC_$PID"
rollback: 5
limit_size: 50 m #only support M byte
-device: 1
disable: false
out_type: screen
category: 0
category_extend: 0
- channel: 6
sync: sync
priority: error
category: 0
category_extend: 0
append_cr: true
-device: 0
disable: false
out_type: file
priority: error
category: 0
category_extend: 0
path: "./log/"
file: "rt-$YEAR_$MON_$DAY_$HOUR_$MIN_$SEC_$PID"
rollback: 5
limit_size: 50 m #only support M byte
-device: 1
disable: true
out_type: screen
category: 0
category_extend: 0
- channel: 7
sync: sync
priority: error
category: 0
category_extend: 0
-device: 0
disable: false
out_type: file
priority: error
category: 0
category_extend: 0
path: "./log/"
file: "nn-$YEAR_$MON_$DAY_$HOUR_$MIN_$SEC_$PID"
rollback: 5
limit_size: 50 m #only support M byte
-device: 1
disable: true
out_type: screen
category: 0
category_extend: 0
- channel: 8
sync: sync
priority: error
category: 0
category_extend: 0
-device: 0
disable: false
out_type: file
priority: error
category: 0
category_extend: 0
path: "./log/"
file: "tm-$YEAR_$MON_$DAY_$HOUR_$MIN_$SEC_$PID"
rollback: 5
limit_size: 50 m #only support M byte
-device: 1
disable: false
out_type: screen
category: 0
category_extend: 0
- channel: 9
sync: sync
priority: error
category: 0
category_extend: 0
append_cr: true
no_prefix: true
-device: 0
disable: false
out_type: file
priority: error
category: 0
category_extend: 0
path: "./log/"
file: "md-$YEAR_$MON_$DAY_$HOUR_$MIN_$SEC_$PID"
rollback: 5
limit_size: 50 m #only support M byte
-device: 1
disable: true
out_type: screen
category: 0
category_extend: 0
- channel: 10
sync: sync
priority: error
category: 0
category_extend: 0
append_cr: false
no_prefix: true
-device: 0
disable: false
out_type: file
priority: error
category: 0
category_extend: 0
path: "./log/"
file: "rs-$YEAR_$MON_$DAY_$HOUR_$MIN_$SEC_$PID"
rollback: 5
limit_size: 50 m #only support M byte
-device: 1
disable: false
out_type: screen
category: 0
category_extend: 0
- channel: 11
sync: sync
priority: error
category: 0
category_extend: 0
append_cr: false
no_prefix: true
-device: 0
disable: false
out_type: file
priority: error
category: 0
category_extend: 0
path: "./log/"
file: "vaapi-$YEAR_$MON_$DAY_$HOUR_$MIN_$SEC_$PID"
rollback: 5
limit_size: 50 m #only support M byte
-device: 1
disable: false
out_type: screen
category: 0
category_extend: 0
- channel: 12
sync: sync
priority: error
category: 0
category_extend: 0
-device: 0
disable: false
out_type: file
priority: error
category: 0
category_extend: 0
path: "./log/"
file: "vccl-$YEAR_$MON_$DAY_$HOUR_$MIN_$SEC_$PID"
rollback: 5
limit_size: 50 m #only support M byte
-device: 1
disable: true
out_type: screen
category: 0
category_extend: 0