2026-04-02 04:53:13 +00:00
parent 80932c96e5
commit 24df76db9d
1987 changed files with 447445 additions and 0 deletions

torch_vacc/__init__.py Normal file

@@ -0,0 +1,124 @@
import atexit
import ctypes
import os
import sys
import types
import torch
import torch.distributed
from .version import __version__
def register_runtime_libraries() -> None:
try:
libpython_so = f"libpython{sys.version_info.major}.{sys.version_info.minor}.so"
base_prefix = getattr(sys, "base_prefix", sys.prefix)
if not base_prefix.startswith("/usr"): # like conda or virtualenv
ctypes.CDLL(os.path.join(base_prefix, "lib", libpython_so))
this_path = os.path.dirname(os.path.realpath(__file__))
rt_dll_dpath = os.path.join(this_path, "_vacc_libs")
ctypes.CDLL(os.path.join(rt_dll_dpath, "libodsp.so"))
ctypes.CDLL(os.path.join(rt_dll_dpath, "libvaccrt.so"))
ctypes.CDLL(os.path.join(rt_dll_dpath, "libvnnl.so"))
ctypes.CDLL(os.path.join(rt_dll_dpath, "libvccl.so"))
ctypes.CDLL(os.path.join(rt_dll_dpath, "libvacc_core.so"))
except Exception as e:
raise RuntimeError("Vastai runtime library not loaded.") from e
register_runtime_libraries()
from ._vacc_libs import _torch_vacc as _C
try:
_C._init_torch_vacc_module()
except Exception as e:
raise RuntimeError("Failed to init torch_vacc.") from e
def _apply_patches(monkey_patches):
def _getattr(module_list, root_module=torch):
if len(module_list) <= 1:
return root_module
if hasattr(root_module, module_list[0]):
return _getattr(module_list[1:], getattr(root_module, module_list[0]))
else:
empty_module_name = f"{root_module.__name__}.{module_list[0]}"
sys.modules[empty_module_name] = types.ModuleType(empty_module_name)
setattr(root_module, module_list[0], sys.modules.get(empty_module_name))
return _getattr(module_list[1:], getattr(root_module, module_list[0]))
for patch_pair in monkey_patches:
dest, patch = patch_pair
dest_module = _getattr(dest.split("."), root_module=torch)
last_module_level = dest.split(".")[-1]
if not isinstance(patch, types.ModuleType):
setattr(dest_module, last_module_level, patch)
continue
if not hasattr(dest_module, last_module_level) or not hasattr(patch, "__all__"):
setattr(dest_module, last_module_level, patch)
sys.modules[f"{dest_module.__name__}.{last_module_level}"] = patch
continue
assert hasattr(patch, "__all__"), "Patch module must have __all__ definition."
dest_module = getattr(dest_module, last_module_level)
for attr in patch.__all__:
setattr(dest_module, attr, getattr(patch, attr))
import torch_vacc.vacc as vacc
# register "vacc" module/functions to torch
torch._register_device_module("vacc", vacc)
unsupported_dtype = [
torch.quint8,
torch.quint4x2,
torch.quint2x4,
torch.qint32,
torch.qint8,
]
torch.utils.generate_methods_for_privateuse1_backend(
for_tensor=True,
for_module=True,
for_storage=True, # TODO(qingsong): do we support storage?
unsupported_dtype=unsupported_dtype,
)
# register legacy *DtypeTensor into torch.vacc
_C._initialize_python_bindings()
# init seed generators, vacc default generator
vacc.init()
def is_vccl_available() -> bool:
return True
torch.distributed.is_vccl_available = is_vccl_available
def set_global_log_level(log_level):
_C.set_global_log_level(log_level.upper())
def print_vacc_ops():
_C._print_vacc_ops()
def vacc_ops_list():
return _C._vacc_ops_list().split(",")
def print_vacc_selective_ops():
_C._print_vacc_selective_ops()
def _vacc_shutdown():
_C._vacc_module_shutdown()
atexit.register(_vacc_shutdown)
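
A minimal usage sketch of the module above (assuming a VACC device and the bundled runtime libraries are installed; shapes and the log level are illustrative):

import torch
import torch_vacc  # loads the VACC runtime libraries and registers the "vacc" device

torch_vacc.set_global_log_level("warning")      # forwarded as "WARNING" to the C extension
print(torch.vacc.is_available(), torch.vacc.device_count())

x = torch.randn(16, 16, device="vacc")          # allocate directly on the VACC device
y = (x @ x).cpu()                               # ops dispatch through the registered backend

print(len(torch_vacc.vacc_ops_list()))          # number of registered VACC ops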

BIN
torch_vacc/_vacc_libs/libodsp.so Executable file
Binary file not shown.

BIN
torch_vacc/_vacc_libs/libvccl.so Executable file
Binary file not shown.

BIN
torch_vacc/_vacc_libs/libvnnl.so Executable file
Binary file not shown.


@@ -0,0 +1,398 @@
import os
import warnings
import logging as logger
from functools import wraps
import torch
import torch_vacc
'''
try:
import torchair
except ImportError:
IS_TORCHAIR_INSTALLED = False
else:
IS_TORCHAIR_INSTALLED = True
'''
warnings.filterwarnings(action="once")
torch_fn_white_list = [
"_cudnn_init_dropout_state",
"_empty_affine_quantized",
"_empty_per_channel_affine_quantized",
"_pin_memory",
"_sparse_coo_tensor_unsafe",
"_sparse_csr_tensor_unsafe",
"logspace",
"randint",
"hann_window",
"rand",
"full_like",
"ones_like",
"rand_like",
"randperm",
"arange",
"frombuffer",
"normal",
"empty_strided",
"empty_like",
"scalar_tensor",
"tril_indices",
"bartlett_window",
"ones",
"sparse_coo_tensor",
"randn",
"kaiser_window",
"tensor",
"triu_indices",
"as_tensor",
"zeros",
"randint_like",
"full",
"eye",
"empty",
"blackman_window",
"zeros_like",
"range",
"sparse_csr_tensor",
"randn_like",
"from_file",
"linspace",
"hamming_window",
"empty_quantized",
"autocast",
"load",
]
torch_tensor_fn_white_list = [
"new_empty",
"new_empty_strided",
"new_full",
"new_ones",
"new_tensor",
"new_zeros",
"to",
]
torch_module_fn_white_list = ["to", "to_empty"]
torch_cuda_fn_white_list = [
"get_device_properties",
"get_device_name",
"get_device_capability",
"list_gpu_processes",
"set_device",
"synchronize",
"mem_get_info",
"memory_stats",
"memory_summary",
"memory_allocated",
"max_memory_allocated",
"reset_max_memory_allocated",
"memory_reserved",
"max_memory_reserved",
"reset_max_memory_cached",
"reset_peak_memory_stats",
"current_stream",
"default_stream",
]
torch_profiler_fn_white_list = ["profile"]
torch_distributed_fn_white_list = ["__init__"]
device_kwargs_list = ["device", "device_type", "map_location"]
def wrapper_cuda(fn):
@wraps(fn)
def decorated(*args, **kwargs):
replace_int = fn.__name__ in ["to", "to_empty"]
if args:
args_new = list(args)
args = replace_cuda_to_vacc_in_list(args_new, replace_int)
if kwargs:
for device_arg in device_kwargs_list:
device = kwargs.get(device_arg, None)
if device is not None:
replace_cuda_to_vacc_in_kwargs(kwargs, device_arg, device)
device_ids = kwargs.get("device_ids", None)
if type(device_ids) == list:
device_ids = replace_cuda_to_vacc_in_list(device_ids, replace_int)
return fn(*args, **kwargs)
return decorated
def replace_cuda_to_vacc_in_kwargs(kwargs, device_arg, device):
if type(device) == str and "cuda" in device:
kwargs[device_arg] = device.replace("cuda", "vacc")
elif type(device) == torch.device and "cuda" in device.type:
device_info = (
"vacc:{}".format(device.index) if device.index is not None else "vacc"
)
kwargs[device_arg] = torch.device(device_info)
elif type(device) == int:
kwargs[device_arg] = f"vacc:{device}"
elif type(device) == dict:
kwargs[device_arg] = replace_cuda_to_vacc_in_dict(device)
def replace_cuda_to_vacc_in_list(args_list, replace_int):
for idx, arg in enumerate(args_list):
if isinstance(arg, str) and "cuda" in arg:
args_list[idx] = arg.replace("cuda", "vacc")
elif isinstance(arg, torch.device) and "cuda" in arg.type:
device_info = (
"vacc:{}".format(arg.index) if arg.index is not None else "vacc"
)
args_list[idx] = torch.device(device_info)
elif replace_int and not isinstance(arg, bool) and isinstance(arg, int):
args_list[idx] = f"vacc:{arg}"
elif isinstance(arg, dict):
args_list[idx] = replace_cuda_to_vacc_in_dict(arg)
return args_list
def replace_cuda_to_vacc_in_dict(device_dict):
new_dict = {}
for key, value in device_dict.items():
if isinstance(key, str):
key = key.replace("cuda", "vacc")
if isinstance(value, str):
value = value.replace("cuda", "vacc")
new_dict[key] = value
return new_dict
def device_wrapper(enter_fn, white_list):
for fn_name in white_list:
fn = getattr(enter_fn, fn_name, None)
if fn:
setattr(enter_fn, fn_name, wrapper_cuda(fn))
def wrapper_vccl(fn):
@wraps(fn)
def decorated(*args, **kwargs):
if args:
args_new = list(args)
for idx, arg in enumerate(args_new):
if type(arg) == str and "nccl" in arg:
args_new[idx] = arg.replace("nccl", "vccl")
args = args_new
if kwargs:
if type(kwargs.get("backend", None)) == str:
kwargs["backend"] = "vccl"
return fn(*args, **kwargs)
return decorated
def wrapper_data_loader(fn):
@wraps(fn)
def decorated(*args, **kwargs):
if kwargs:
pin_memory = kwargs.get("pin_memory", False)
pin_memory_device = kwargs.get("pin_memory_device", None)
if pin_memory and not pin_memory_device:
kwargs["pin_memory_device"] = "vacc"
if (
pin_memory
and type(pin_memory_device) == str
and "cuda" in pin_memory_device
):
kwargs["pin_memory_device"] = pin_memory_device.replace("cuda", "vacc")
return fn(*args, **kwargs)
return decorated
def wrapper_get_available_device_type(fn):
@wraps(fn)
def decorated(*args, **kwargs):
try:
if (torch.vacc.is_available()):
return 'vacc'
except Exception as e:
msg = "vacc device is not available."
warnings.warn(msg, RuntimeWarning)
return fn(*args, **kwargs)
return decorated
'''
def wrapper_profiler(fn):
@wraps(fn)
def decorated(*args, **kwargs):
if kwargs:
if (
"experimental_config" in kwargs.keys()
and type(kwargs.get("experimental_config"))
!= torch_vacc.profiler._ExperimentalConfig
):
logger.warning(
"The parameter experimental_config of torch.profiler.profile has been deleted by the tool "
"because it can only be used in cuda, please manually modify the code "
"and use the experimental_config parameter adapted to vacc."
)
del kwargs["experimental_config"]
return fn(*args, **kwargs)
return decorated
def wrapper_compile(fn):
@wraps(fn)
def decorated(*args, **kwargs):
vacc_backend = torchair.get_vacc_backend()
if kwargs:
backend = kwargs.get("backend", None)
if (
not backend
or not isinstance(backend, functools.partial)
or not isinstance(backend.func, type(vacc_backend.func))
):
kwargs["backend"] = vacc_backend
else:
kwargs["backend"] = vacc_backend
return fn(*args, **kwargs)
return decorated
'''
def jit_script(obj, optimize=None, _frames_up=0, _rcb=None, example_inputs=None):
msg = "torch.jit.script will be disabled by transfer_to_vacc, which currently does not support it."
warnings.warn(msg, RuntimeWarning)
return obj
def patch_cuda():
patchs = [
["cuda", torch_vacc.vacc],
["cuda.amp", torch_vacc.vacc.amp],
["cuda.random", torch_vacc.vacc.random],
["cuda.amp.autocast_mode", torch_vacc.vacc.amp.autocast_mode],
["cuda.amp.common", torch_vacc.vacc.amp.common],
["cuda.amp.grad_scaler", torch_vacc.vacc.amp.grad_scaler],
]
torch_vacc._apply_patches(patchs)
'''
def patch_profiler():
patchs = [
["profiler.profile", torch_vacc.profiler.profile],
["profiler.schedule", torch_vacc.profiler.schedule],
[
"profiler.tensorboard_trace_handler",
torch_vacc.profiler.tensorboard_trace_handler,
],
["profiler.ProfilerAction", torch_vacc.profiler.ProfilerAction],
["profiler.ProfilerActivity.CUDA", torch_vacc.profiler.ProfilerActivity.VACC],
["profiler.ProfilerActivity.CPU", torch_vacc.profiler.ProfilerActivity.CPU],
]
torch_vacc._apply_patches(patchs)
'''
def warning_fn(msg, rank0=True):
is_distributed = (
torch.distributed.is_available()
and torch.distributed.is_initialized()
and torch.distributed.get_world_size() > 1
)
env_rank = os.getenv("RANK", None)
if rank0 and is_distributed:
if torch.distributed.get_rank() == 0:
warnings.warn(msg, ImportWarning)
elif rank0 and env_rank:
if env_rank == "0":
warnings.warn(msg, ImportWarning)
else:
warnings.warn(msg, ImportWarning)
def init():
warning_fn(
"""
*************************************************************************************************************
torch.Tensor.cuda and torch.nn.Module.cuda are now replaced with torch.Tensor.vacc and torch.nn.Module.vacc.
torch.cuda.DoubleTensor is replaced with torch.vacc.FloatTensor because the double type is not supported.
The backend in torch.distributed.init_process_group is now set to vccl.
torch.cuda.* and torch.cuda.amp.* are replaced with torch.vacc.* and torch.vacc.amp.*.
The device parameters have been replaced with vacc in the functions below:
{}
If you notice any function you use that is not included in the above list, feel free to contact the torch-vacc development team.
*************************************************************************************************************
""".format(
", ".join(
["torch." + i for i in torch_fn_white_list]
+ ["torch.Tensor." + i for i in torch_tensor_fn_white_list]
+ ["torch.nn.Module." + i for i in torch_module_fn_white_list]
)
)
)
# torch.cuda.*
patch_cuda()
device_wrapper(torch.cuda, torch_cuda_fn_white_list)
# torch.profiler.*
# TODO(qingsong): profiler not implemented yet
# patch_profiler()
# device_wrapper(torch.profiler, torch_profiler_fn_white_list)
# torch.*
device_wrapper(torch, torch_fn_white_list)
# torch.Tensor.*
device_wrapper(torch.Tensor, torch_tensor_fn_white_list)
torch.Tensor.cuda = torch.Tensor.vacc
torch.Tensor.is_cuda = torch.Tensor.is_vacc
for dtype_tensor in [
"ByteTensor",
"CharTensor",
"DoubleTensor",
"FloatTensor",
"IntTensor",
"LongTensor",
"ShortTensor",
"HalfTensor",
"BoolTensor",
]:
setattr(
torch.cuda,
dtype_tensor,
getattr(torch.vacc, dtype_tensor),
)
# TODO(qingsong): do we need this? should we add LongTensor=IntTensor?
torch.cuda.DoubleTensor = torch.vacc.FloatTensor
# torch.nn.Module.*
device_wrapper(torch.nn.Module, torch_module_fn_white_list)
torch.nn.Module.cuda = torch.nn.Module.vacc
# torch.distributed.init_process_group
torch.distributed.init_process_group = wrapper_vccl(
torch.distributed.init_process_group
)
torch.distributed.is_nccl_available = torch.distributed.is_vccl_available
# torch.nn.parallel.DistributedDataParallel
device_wrapper(
torch.nn.parallel.DistributedDataParallel, torch_distributed_fn_white_list
)
# torch.utils.data.DataLoader
torch.utils.data.DataLoader.__init__ = wrapper_data_loader(
torch.utils.data.DataLoader.__init__
)
torch.jit.script = jit_script
torch._utils._get_available_device_type = wrapper_get_available_device_type(
torch._utils._get_available_device_type
)
'''
if IS_TORCHAIR_INSTALLED:
torch.compile = wrapper_compile(torch.compile)
'''
init()
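
Because this module rewrites torch's CUDA entry points at import time, CUDA-targeted scripts can typically run unchanged once it is imported. A sketch of the intended effect; the file's import path is not shown in this diff, so the `torch_vacc.transfer_to_vacc` name below is an assumption taken from the warning text in jit_script:

import torch
import torch_vacc
import torch_vacc.transfer_to_vacc  # hypothetical path; importing runs init() defined above

model = torch.nn.Linear(8, 8).cuda()       # nn.Module.cuda now aliases nn.Module.vacc
x = torch.ones(2, 8, device="cuda:0")      # "cuda:0" is rewritten to "vacc:0" by wrapper_cuda
print(model(x).device)                     # vacc:0

# torch.distributed.init_process_group(backend="nccl", ...) would likewise be redirected to "vccl".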


@@ -0,0 +1,42 @@
import torch
import torch_vacc
from torch_vacc._vacc_libs import _torch_vacc
class FusedRMSNormFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6):
output, rsigma, var = torch.ops.vacc.rms_norm_forward(input, weight, eps)
ctx.save_for_backward(input, weight, rsigma, var)
ctx.eps = eps
return output
@staticmethod
def backward(ctx, grad_output: torch.Tensor):
input, weight, rsigma, var = ctx.saved_tensors
grad_input, grad_weight = _torch_vacc.rms_norm_backward(
grad_output, input, weight, rsigma, var, ctx.eps
)
return grad_input, grad_weight, None
def rms_norm(input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6):
return FusedRMSNormFunction.apply(input, weight, eps)
class FusedRMSNorm(torch.nn.Module):
def __init__(self, hidden_size, eps: float = 1e-6):
super(FusedRMSNorm, self).__init__()
self.eps = eps
self.weight = torch.nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
output = FusedRMSNormFunction.apply(hidden_states, self.weight, self.eps)
output = output.to(dtype)
return output
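
A short sketch of dropping the fused module above into a model (assuming the vacc rms_norm kernels are available; FusedRMSNorm comes from this file, whose import path is not shown in the diff):

import torch
# FusedRMSNorm is defined in the file above; its import path is not shown in this diff.

hidden = torch.randn(2, 128, 4096, device="vacc", dtype=torch.float16, requires_grad=True)
norm = FusedRMSNorm(hidden_size=4096, eps=1e-6).to("vacc")

out = norm(hidden)        # upcasts to fp32, runs the fused kernel, casts back to fp16
out.sum().backward()      # gradients flow through rms_norm_backward via autograd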


@@ -0,0 +1,32 @@
import torch
import torch_vacc
from torch_vacc._vacc_libs import _torch_vacc
class FusedRopeEmbFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, q: torch.Tensor, k: torch.Tensor, offset: int):
qemb, kemb = _torch_vacc.rope_forward(q, k, offset)
ctx.offset = offset
return qemb, kemb
@staticmethod
def backward(ctx, q_out_grad: torch.Tensor, k_out_grad: torch.Tensor):
grad_input, grad_rope = _torch_vacc.rope_backward(
q_out_grad, k_out_grad, ctx.offset
)
return grad_input, grad_rope, None
def rope_emb(q: torch.Tensor, k: torch.Tensor, offset: int):
# return FusedRopeEmbFunction.apply(q, k, offset)
return torch_vacc.vacc.custom_ops.RotaryPosEmbedding(q=q, k=k, offset=offset)
class RopeEmb(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, q: torch.Tensor, k: torch.Tensor, offset: int):
return rope_emb(q, k, offset)
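
A usage sketch for the rotary-embedding wrapper above; the (batch, heads, seq_len, head_dim) layout of q and k is an assumed example, as the diff does not show what RotaryPosEmbedding expects:

import torch
# rope_emb / RopeEmb are defined in the file above; its import path is not shown in this diff.

q = torch.randn(1, 32, 128, 128, device="vacc", dtype=torch.float16)
k = torch.randn(1, 32, 128, 128, device="vacc", dtype=torch.float16)

q_rot, k_rot = rope_emb(q, k, offset=0)    # applies rotary position embedding via the custom op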


@@ -0,0 +1,62 @@
from contextlib import contextmanager
import os
import sys
import torch
from torch.testing import make_tensor
from functools import partial, wraps
import torch.testing._internal.common_device_type as cdt
from torch.testing._internal.common_device_type import (
DeviceTypeTestBase,
dtypes,
instantiate_device_type_tests,
onlyOn,
onlyPRIVATEUSE1,
ops,
)
if sys.version_info > (3, 8):
from torch.testing._internal.common_distributed import (
MultiProcessTestCase,
init_multigpu_helper,
skip_if_lt_x_gpu,
get_timeout,
#skip_if_rocm,
with_dist_debug_levels,
)
else:
from torch.testing._internal.common_distributed import (
MultiProcessTestCase,
init_multigpu_helper,
skip_if_lt_x_gpu,
get_timeout,
skip_if_rocm,
with_dist_debug_levels,
)
from torch.testing._internal.common_utils import (
TestCase,
load_tests,
parametrize,
run_tests,
subtest,
retry_on_connect_failures,
instantiate_parametrized_tests,
)
onlyVacc = onlyPRIVATEUSE1
class VaccTestBase(DeviceTypeTestBase):
device_type = "vacc"
if VaccTestBase not in cdt.device_type_test_bases:
cdt.device_type_test_bases.append(VaccTestBase)
@contextmanager
def freeze_rng_state():
    rng_state = torch.get_rng_state()
    try:
        yield
    finally:
        # restore the RNG state even if the body raises
        torch.set_rng_state(rng_state)
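
With VaccTestBase appended to device_type_test_bases, the stock device-generic test machinery can generate "vacc" variants of a test class. A minimal sketch using the symbols re-exported by this module:

import torch
from torch.testing._internal.common_utils import TestCase, run_tests
# instantiate_device_type_tests (and the VaccTestBase registration) come from the module above.

class TestAddVacc(TestCase):
    def test_add(self, device):
        a = torch.ones(4, device=device)
        self.assertEqual((a + a).sum().item(), 8.0)

# Generates a "vacc"-device variant of the test class via the registered VaccTestBase.
instantiate_device_type_tests(TestAddVacc, globals(), only_for="vacc")

if __name__ == "__main__":
    run_tests()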


@@ -0,0 +1,103 @@
"""
Tool to summarize unit test XML reports. It summarizes
* the number of tests, and failures/errors/skipped
* the top 10 slowest tests
Usage:
python -m torch_vacc.testing.summarize_report --report report.xml
"""
import argparse
from dataclasses import dataclass
from xml.etree import ElementTree as ET
import sys
from torch_vacc import set_global_log_level
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--report", type=str)
return parser.parse_args()
def summarize_testsuites(suites):
summary = {
"errors": int,
"failures": int,
"skipped": int,
"skips": int,
"tests": int,
"time": float,
}
attribs = [s.attrib for s in suites]
for key in summary:
summary[key] = sum(summary[key](a[key]) for a in attribs if key in a)
assert not (summary["skipped"] and summary["skips"])
if summary["skips"]:
summary["skipped"] = summary["skips"]
return summary
def format_summary(summary):
template = "Ran {tests} tests in {time:.3f}s (errors={errors}, failures={failures}, skipped={skipped})"
msg = template.format(**summary)
if summary["errors"] > 0 or summary["failures"] > 0:
msg = "FAILED. " + msg
return msg
@dataclass
class TestCaseInfo:
test_class_name: str
test_name: str
time: float
timestamp: str
success: bool
def __lt__(self, other):
return self.time < other.time
def sort_cases_by_time(suites):
test_cases = [
TestCaseInfo(
s.attrib["classname"],
s.attrib["name"],
s.attrib["time"],
s.attrib["timestamp"],
s.attrib.get("failure") is None,
)
for s in suites
]
test_cases.sort(reverse=True)
return test_cases
def read_report(fpath):
with open(fpath) as report:
try:
report = ET.parse(report)
except ET.ParseError:
print(f"{sys.argv[0]}: Cannot parse file {fpath}", file=sys.stderr)
return
root = report.getroot()
suites = [root] if root.tag == "testsuite" else root.findall("testsuite")
summary = summarize_testsuites(suites)
summary_msg = format_summary(summary)
print(summary_msg)
for suite in suites:
cases = sort_cases_by_time(suite.findall("testcase"))
[print(case) for case in cases[:10]]
def main():
set_global_log_level("ERROR")
args = parse_args()
read_report(args.report)
if __name__ == "__main__":
main()

torch_vacc/vacc/__init__.py Normal file

@@ -0,0 +1,184 @@
from __future__ import annotations
from typing import Tuple
import torch
from ._device import (
current_device,
device,
device_count,
get_device_capability,
get_device_name,
get_device_properties,
is_available,
is_bf16_supported,
set_device,
synchronize,
)
from .amp import (
get_amp_supported_dtype,
get_autocast_dtype,
is_autocast_enabled,
set_autocast_dtype,
set_autocast_enabled,
)
from .lazy_initialize import _is_in_bad_fork, _lazy_call, _lazy_init
from .memory import ( # caching_allocator_alloc,; caching_allocator_delete,
empty_cache,
get_allocator_backend,
max_memory_allocated,
max_memory_cached,
max_memory_reserved,
mem_get_info,
memory_allocated,
memory_cached,
memory_reserved,
memory_snapshot,
memory_stats,
memory_stats_as_nested_dict,
memory_summary,
reset_accumulated_memory_stats,
reset_max_memory_allocated,
reset_max_memory_cached,
reset_peak_memory_stats,
set_per_process_memory_fraction,
)
from .streams import Event, Stream, current_stream, default_stream, set_stream, stream
def init():
r"""Initialize PyTorch's VACC state. You may need to call
this explicitly if you are interacting with PyTorch via
its C API, as Python bindings for VACC functionality will not
be available until this initialization takes place. Ordinary users
should not need this, as all of PyTorch's VACC methods
automatically initialize VACC state on-demand.
Does nothing if the VACC state is already initialized.
"""
_lazy_init()
# default_generators is empty until _lazy_init() is called
default_generators: Tuple[torch._C.Generator] = () # type: ignore[assignment]
from .custom_ops import *
from .custom_qwen3_ops import *
from .random import * # noqa: F403
__all__ = [
"device",
"is_available",
"is_bf16_supported",
"current_device",
"set_device",
"device_count",
"get_device_properties",
"get_device_name",
"get_device_capability",
"synchronize",
"amp",
"get_amp_supported_dtype",
"is_autocast_enabled",
"set_autocast_enabled",
"get_autocast_dtype",
"set_autocast_dtype",
"_is_in_bad_fork",
"_lazy_call",
"get_rng_state",
"get_rng_state_all",
"set_rng_state",
"set_rng_state_all",
"manual_seed",
"manual_seed_all",
"seed",
"seed_all",
"initial_seed",
"set_stream",
"current_stream",
"default_generators",
"default_stream",
"stream",
"Stream",
"Event",
"mem_get_info",
"set_per_process_memory_fraction",
"empty_cache",
"memory_stats",
"memory_stats_as_nested_dict",
"reset_accumulated_memory_stats",
"reset_peak_memory_stats",
"reset_max_memory_allocated",
"reset_max_memory_cached",
"memory_allocated",
"max_memory_allocated",
"memory_reserved",
"max_memory_reserved",
"memory_cached",
"max_memory_cached",
"memory_snapshot",
"memory_summary",
"get_allocator_backend",
"rms_norm",
"RotaryPosEmbedding",
"scaled_dot_product_attention",
"scaled_dot_product_attention_cp_forward",
"scaled_dot_product_attention_cp_backward",
"swiglu",
"paged_attention",
"reshape_and_cache_attention",
"concat_and_cache_attention",
"w8a8_block_fp8_matmul",
"moe_expert_token_group_reassign",
"fused_mlp_mm_fp8",
"fused_mlp_fp8",
"fused_moe_preprocess",
"fused_residual_rmsnorm",
"parallel_embedding",
"all_reduce",
"all_gather",
"broadcast",
"fused_mlp_moe_with_rmsnorm",
"fuse_moe_decode_v2_allreduce",
"topk_topp",
"fused_mla",
"fused_mla_allreduce",
"fused_mlp_with_rmsnorm",
"fused_mlp_allreduce",
"ds3_sampler",
"sampler_v1",
"rejection_sampler",
"rejection_sampler_update_hidden_states",
"rejection_sampler_v1",
"fused_matmul_allgather",
"fused_mla_v2",
"fused_mla_allreduce_v2",
"mla_matmul_scale",
"mla_matmul",
"fused_mla_prefill_stage0",
"fused_mla_prefill_stage1",
"fused_mla_prefill_stage0_allreduce",
"fuse_moe_prefill_stage0",
"fuse_mla_mlp_v2_allreduce_decode",
"fuse_mla_moe_v2_allreduce_decode",
"fuse_mla_mlp_v2_allreduce_decode_layers",
"fuse_mla_moe_v2_allreduce_decode_layers",
"fuse_mla_mlp_v2_allreduce_decode_layers_v2",
"fuse_mla_moe_v2_allreduce_decode_layers_v2",
"fuse_mlp_qwen_int4",
"fuse_mlp_qwen_int4_reduce",
"w4a8_block_int4_matmul",
"fuse_atten_qwen3",
"fuse_atten_qwen2",
"qwen3_fuse_attention_moe_decode",
"fuse_mtp_stage0",
"fuse_mtp_allreduce",
"roll_out",
"fused_experts_int4_prefill",
"fuse_bge_embedding_stage1",
"l2_norm",
"fuse_mlp_vision",
"patch_merger_vision",
"fuse_atten_vit",
"apply_penalties",
]


torch_vacc/vacc/_device.py Normal file

@@ -0,0 +1,106 @@
# Device information
# Replaces `torch.cuda.func` with `torch_vacc.vacc.func`.
# See https://pytorch.org/docs/stable/cuda.html
from typing import Any
import warnings
import torch
import torch_vacc
from torch._utils import _get_device_index
from torch_vacc._vacc_libs import _torch_vacc
from .lazy_initialize import _lazy_init
if hasattr(_torch_vacc, "_exchange_device"):
_exchange_device = _torch_vacc._exchange_device
else:
def _exchange_device(device: int) -> int:
if device < 0:
return -1
prev_device = current_device()
if device != prev_device:
set_device(device)
return prev_device
class device(object):
"""Context-manager that changes the selected device.
Args:
device (torch.device or int): device index to select. It's a no-op if
this argument is a negative integer or ``None``.
"""
def __init__(self, device: Any):
self.idx = _get_device_index(device, optional=True)
self.prev_idx = -1
def __enter__(self):
self.prev_idx = _exchange_device(self.idx)
def __exit__(self, *args):
_exchange_device(self.prev_idx)
return False
def is_available() -> bool:
r"""Returns whether vacc is available."""
return device_count() > 0
def is_bf16_supported() -> bool:
r"""Returns a bool indicating if the current vacc device supports dtype bfloat16"""
return True
def current_device() -> int:
r"""Returns the index of a currently selected vacc device."""
_lazy_init()
return _torch_vacc._current_device()
def set_device(device: torch.device):
device_index = _get_device_index(device, optional=True)
if device_index >= 0:
_torch_vacc._set_device(device_index)
def get_device_capability(device=None):
r"""Query the device's major and minor compute capability. VACC does not
have a corresponding concept, so this query is not supported.
"""
_infos = "torch.vacc.get_device_capability isn't implemented! Unlike CUDA's major/minor capability, VACC has no equivalent; please check the device version in another way."
raise AssertionError(_infos)
def get_device_name(device_name=None):
device_id = _get_device_index(device_name, optional=True)
if device_id < 0 or device_id >= device_count():
raise AssertionError("Invalid device id")
_lazy_init()
device_prop = _torch_vacc._vacc_getDeviceProperties(device_id)
return device_prop.name
def get_device_properties(device_name=None):
device_id = _get_device_index(device_name, optional=True)
if device_id < 0 or device_id >= device_count():
raise AssertionError("Invalid device id")
_lazy_init()
return _torch_vacc._vacc_getDeviceProperties(device_id)
def device_count():
r"""Returns the number of available vacc devices"""
return _torch_vacc._device_count()
def synchronize(device=None) -> None:
"""Waits for all operations in all streams on a VACC device to complete."""
_lazy_init()
with torch_vacc.vacc.device(device):
return _torch_vacc._device_synchronize()
# Memory management (https://pytorch.org/docs/stable/cuda.html#memory-management)
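
A usage sketch for the device helpers above (assuming at least two VACC devices are visible):

import torch
import torch_vacc

torch.vacc.set_device(0)
print(torch.vacc.current_device())         # 0

with torch_vacc.vacc.device(1):            # temporarily switch to device 1
    t = torch.zeros(4, device="vacc")      # allocated on device 1
    torch.vacc.synchronize()               # wait for all work on the current device
print(torch.vacc.current_device())         # restored to 0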


@@ -0,0 +1,26 @@
from typing import List
import torch
from torch_vacc._vacc_libs import _torch_vacc
from .grad_scaler import OptState, GradScaler
from .autocast_mode import autocast, custom_fwd, custom_bwd
def get_amp_supported_dtype() -> List[torch.dtype]:
return [torch.float16, torch.bfloat16]
def is_autocast_enabled() -> bool:
return _torch_vacc.is_autocast_enabled()
def set_autocast_enabled(enable: bool):
_torch_vacc.set_autocast_enabled(enable)
def get_autocast_dtype() -> torch.dtype:
return _torch_vacc.get_autocast_dtype()
def set_autocast_dtype(dtype: torch.dtype):
return _torch_vacc.set_autocast_dtype(dtype)
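
A sketch of how these wrappers interact with the autocast context manager from autocast_mode.py (the printed values assume the C extension reports the state set by the context manager):

import torch
import torch_vacc

print(torch.vacc.amp.get_amp_supported_dtype())    # [torch.float16, torch.bfloat16]

with torch.vacc.amp.autocast(dtype=torch.bfloat16):
    print(torch.vacc.amp.is_autocast_enabled())    # expected: True inside the region
    print(torch.vacc.amp.get_autocast_dtype())     # expected: torch.bfloat16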


@@ -0,0 +1,144 @@
import collections
import functools
from typing import Any
import torch
try:
import numpy as np
HAS_NUMPY = True
except ModuleNotFoundError:
np = None # type: ignore[assignment]
__all__ = ["autocast", "custom_fwd", "custom_bwd"]
class autocast(torch.amp.autocast_mode.autocast):
r"""See :class:`torch.autocast`.
``torch.vacc.amp.autocast(args...)`` is equivalent to ``torch.autocast("vacc", args...)``
"""
def __init__(
self,
enabled: bool = True,
dtype: torch.dtype = torch.float16,
cache_enabled: bool = True,
):
if torch._jit_internal.is_scripting():
self._enabled = enabled
self.device = "vacc"
self.fast_dtype = dtype
return
super().__init__(
"vacc", enabled=enabled, dtype=dtype, cache_enabled=cache_enabled
)
def __enter__(self):
if torch._jit_internal.is_scripting():
return self
return super().__enter__()
# TODO: discuss a unified TorchScript-friendly API for autocast
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any): # type: ignore[override]
if torch._jit_internal.is_scripting():
return
return super().__exit__(exc_type, exc_val, exc_tb)
def __call__(self, func):
if torch._jit_internal.is_scripting():
return func
return super().__call__(func)
# Casts Tensors and containers of Tensors. Special-cases passthroughs for strings and np.ndarrays, which
# may be falsely detected as "Iterables."
def _cast(value, dtype):
if isinstance(value, torch.Tensor):
is_eligible = (
value.is_floating_point()
and value.is_vacc
and (value.dtype is not torch.float64)
)
return value.to(dtype) if is_eligible else value
elif isinstance(value, (str, bytes)):
return value
elif HAS_NUMPY and isinstance(value, np.ndarray):
return value
elif isinstance(value, collections.abc.Mapping):
return {_cast(k, dtype): _cast(v, dtype) for k, v in value.items()}
elif isinstance(value, collections.abc.Iterable):
iterable = (_cast(v, dtype) for v in value)
if isinstance(value, (list, tuple)):
return type(value)(iterable)
else:
return iterable
else:
return value
# custom_fwd is a decorator that may or may not be used with arguments, following
# https://github.com/dabeaz/python-cookbook/tree/master/src/9/defining_a_decorator_that_takes_an_optional_argument.
# this works:
# @custom_fwd
# def forward(...):
# this also works:
# @custom_fwd(cast_inputs=torch.float)
# def forward(...):
def custom_fwd(fwd=None, *, cast_inputs=None):
"""
Create a helper decorator for ``forward`` methods of custom autograd functions.
Autograd functions are subclasses of :class:`torch.autograd.Function`.
See the :ref:`example page<amp-custom-examples>` for more detail.
Args:
cast_inputs (:class:`torch.dtype` or None, optional, default=None): If not ``None``,
when ``forward`` runs in an autocast-enabled region, casts incoming
floating-point VACC Tensors to the target dtype (non-floating-point Tensors are not affected),
then executes ``forward`` with autocast disabled.
If ``None``, ``forward``'s internal ops execute with the current autocast state.
.. note::
If the decorated ``forward`` is called outside an autocast-enabled region,
:func:`custom_fwd<custom_fwd>` is a no-op and ``cast_inputs`` has no effect.
"""
if fwd is None:
return functools.partial(custom_fwd, cast_inputs=cast_inputs)
@functools.wraps(fwd)
def decorate_fwd(*args, **kwargs):
args[0]._dtype = torch.get_autocast_gpu_dtype()
if cast_inputs is None:
args[0]._fwd_used_autocast = torch.is_autocast_enabled()
return fwd(*args, **kwargs)
else:
autocast_context = torch.is_autocast_enabled()
args[0]._fwd_used_autocast = False
if autocast_context:
with autocast(enabled=False):
return fwd(*_cast(args, cast_inputs), **_cast(kwargs, cast_inputs))
else:
return fwd(*args, **kwargs)
return decorate_fwd
# Autograd ensures incoming gradients are the same type as forward outputs. Allowing a separate
# cast_inputs argument on custom_bwd is unnecessary and could cause errors if it doesn't match
# cast_inputs supplied to custom_fwd.
def custom_bwd(bwd):
"""Create a helper decorator for backward methods of custom autograd functions.
Autograd functions are subclasses of :class:`torch.autograd.Function`.
Ensures that ``backward`` executes with the same autocast state as ``forward``.
See the :ref:`example page<amp-custom-examples>` for more detail.
"""
@functools.wraps(bwd)
def decorate_bwd(*args, **kwargs):
with autocast(enabled=args[0]._fwd_used_autocast, dtype=args[0]._dtype):
return bwd(*args, **kwargs)
return decorate_bwd
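
A sketch of applying custom_fwd/custom_bwd to a custom autograd Function, mirroring the torch.cuda.amp pattern (the matmul example is illustrative):

import torch
from torch_vacc.vacc.amp import autocast, custom_fwd, custom_bwd

class ScaledMatmul(torch.autograd.Function):
    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)   # cast VACC inputs to fp32 and disable autocast inside
    def forward(ctx, a, b):
        ctx.save_for_backward(a, b)
        return a @ b

    @staticmethod
    @custom_bwd                              # backward replays the autocast state seen in forward
    def backward(ctx, grad_out):
        a, b = ctx.saved_tensors
        return grad_out @ b.T, a.T @ grad_out

with autocast(dtype=torch.float16):
    a = torch.randn(8, 8, device="vacc", requires_grad=True)
    b = torch.randn(8, 8, device="vacc", requires_grad=True)
    ScaledMatmul.apply(a, b).sum().backward()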


@@ -0,0 +1,7 @@
import torch
__all__ = ["amp_definitely_not_available"]
def amp_definitely_not_available():
return not torch.vacc.is_available()


@@ -0,0 +1,667 @@
import inspect
import warnings
from collections import abc, defaultdict
from enum import Enum
from typing import Any, cast, Dict, List, Optional, Tuple
import torch
from .common import amp_definitely_not_available
__all__ = ["OptState", "GradScaler"]
class _MultiDeviceReplicator:
"""
Lazily serves copies of a tensor to requested devices. Copies are cached per-device.
"""
def __init__(self, master_tensor: torch.Tensor) -> None:
assert (
master_tensor.is_cuda
or master_tensor.device.type == "xla"
or master_tensor.device.type == "vacc"
)
self.master = master_tensor
self._per_device_tensors: Dict[torch.device, torch.Tensor] = {}
def get(self, device) -> torch.Tensor:
retval = self._per_device_tensors.get(device, None)
if retval is None:
retval = self.master.to(device=device, non_blocking=True, copy=True)
self._per_device_tensors[device] = retval
return retval
# Defines default_factory for GradScaler's _per_optimizer_states defaultdict,
# as well as associated "enum" values. Prefers defining these at top level because
# - Lambdas can't be pickled, so we don't want to supply a lambda as the factory.
# - Defining READY, UNSCALED, STEPPED and _refresh_per_optimizer_state within GradScaler
# causes a circular reference, which we'd rather avoid.
class OptState(Enum):
READY = 0
UNSCALED = 1
STEPPED = 2
def _refresh_per_optimizer_state():
return {"stage": OptState.READY, "found_inf_per_device": {}}
class GradScaler:
_scale: Optional[torch.Tensor]
_growth_tracker: Optional[torch.Tensor]
_per_optimizer_states: Dict[int, Dict[str, Any]]
"""
An instance ``scaler`` of :class:`GradScaler` helps perform the steps of gradient scaling
conveniently.
* ``scaler.scale(loss)`` multiplies a given loss by ``scaler``'s current scale factor.
* ``scaler.step(optimizer)`` safely unscales gradients and calls ``optimizer.step()``.
* ``scaler.update()`` updates ``scaler``'s scale factor.
Example::
# Creates a GradScaler once at the beginning of training.
scaler = GradScaler()
for epoch in epochs:
for input, target in data:
optimizer.zero_grad()
output = model(input)
loss = loss_fn(output, target)
# Scales loss. Calls backward() on scaled loss to create scaled gradients.
scaler.scale(loss).backward()
# scaler.step() first unscales gradients of the optimizer's params.
# If gradients don't contain infs/NaNs, optimizer.step() is then called,
# otherwise, optimizer.step() is skipped.
scaler.step(optimizer)
# Updates the scale for next iteration.
scaler.update()
See the :ref:`Automatic Mixed Precision examples<amp-examples>` for usage
(along with autocasting) in more complex cases like gradient clipping, gradient accumulation, gradient penalty,
and multiple losses/optimizers.
``scaler`` dynamically estimates the scale factor each iteration. To minimize gradient underflow,
a large scale factor should be used. However, ``float16`` values can "overflow" (become inf or NaN) if
the scale factor is too large. Therefore, the optimal scale factor is the largest factor that can be used
without incurring inf or NaN gradient values.
``scaler`` approximates the optimal scale factor over time by checking the gradients for infs and NaNs during every
``scaler.step(optimizer)`` (or optional separate ``scaler.unscale_(optimizer)``, see :meth:`unscale_`).
* If infs/NaNs are found, ``scaler.step(optimizer)`` skips the underlying ``optimizer.step()`` (so the params
themselves remain uncorrupted) and ``update()`` multiplies the scale by ``backoff_factor``.
* If no infs/NaNs are found, ``scaler.step(optimizer)`` runs the underlying ``optimizer.step()`` as usual.
If ``growth_interval`` unskipped iterations occur consecutively, ``update()`` multiplies the scale by
``growth_factor``.
The scale factor often causes infs/NaNs to appear in gradients for the first few iterations as its
value calibrates. ``scaler.step`` will skip the underlying ``optimizer.step()`` for these
iterations. After that, step skipping should occur rarely (once every few hundred or thousand iterations).
Args:
init_scale (float, optional, default=2.**16): Initial scale factor.
growth_factor (float, optional, default=2.0): Factor by which the scale is multiplied during
:meth:`update` if no inf/NaN gradients occur for ``growth_interval`` consecutive iterations.
backoff_factor (float, optional, default=0.5): Factor by which the scale is multiplied during
:meth:`update` if inf/NaN gradients occur in an iteration.
growth_interval (int, optional, default=2000): Number of consecutive iterations without inf/NaN gradients
that must occur for the scale to be multiplied by ``growth_factor``.
enabled (bool, optional): If ``False``, disables gradient scaling. :meth:`step` simply
invokes the underlying ``optimizer.step()``, and other methods become no-ops.
Default: ``True``
"""
def __init__(
self,
init_scale=2.0**16,
growth_factor=2.0,
backoff_factor=0.5,
growth_interval=2000,
enabled=True,
):
if enabled and amp_definitely_not_available():
warnings.warn(
"torch.vacc.amp.GradScaler is enabled, but VACC device is not available. Disabling."
)
self._enabled = False
else:
self._enabled = enabled
if self._enabled:
assert growth_factor > 1.0, "The growth factor must be > 1.0."
assert backoff_factor < 1.0, "The backoff factor must be < 1.0."
self._init_scale = init_scale
# self._scale will be lazily initialized during the first call to scale()
self._scale = None
self._growth_factor = growth_factor
self._backoff_factor = backoff_factor
self._growth_interval = growth_interval
self._init_growth_tracker = 0
# self._growth_tracker will be lazily initialized during the first call to scale()
self._growth_tracker = None
self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
def _check_scale_growth_tracker(
self, funcname
) -> Tuple[torch.Tensor, torch.Tensor]:
fix = "This may indicate your script did not use scaler.scale(loss or outputs) earlier in the iteration."
assert self._scale is not None, (
f"Attempted {funcname} but _scale is None. " + fix
)
assert self._growth_tracker is not None, (
f"Attempted {funcname} but _growth_tracker is None. " + fix
)
return (self._scale, self._growth_tracker)
def _lazy_init_scale_growth_tracker(self, dev):
assert self._growth_tracker is None, "_growth_tracker initialized before _scale"
self._scale = torch.full((), self._init_scale, dtype=torch.float32, device=dev)
self._growth_tracker = torch.full(
(), self._init_growth_tracker, dtype=torch.int32, device=dev
)
def scale(self, outputs):
"""
Multiplies ('scales') a tensor or list of tensors by the scale factor.
Returns scaled outputs. If this instance of :class:`GradScaler` is not enabled, outputs are returned
unmodified.
Args:
outputs (Tensor or iterable of Tensors): Outputs to scale.
"""
if not self._enabled:
return outputs
# Short-circuit for the common case.
if isinstance(outputs, torch.Tensor):
assert (
outputs.is_cuda
or outputs.device.type == "xla"
or outputs.device.type == "vacc"
)
if self._scale is None:
self._lazy_init_scale_growth_tracker(outputs.device)
assert self._scale is not None
return outputs * self._scale.to(device=outputs.device, non_blocking=True)
# Invoke the more complex machinery only if we're treating multiple outputs.
stash: List[
_MultiDeviceReplicator
] = [] # holds a reference that can be overwritten by apply_scale
def apply_scale(val):
if isinstance(val, torch.Tensor):
assert (
val.is_cuda or val.device.type == "xla" or val.device.type == "vacc"
)
if len(stash) == 0:
if self._scale is None:
self._lazy_init_scale_growth_tracker(val.device)
assert self._scale is not None
stash.append(_MultiDeviceReplicator(self._scale))
return val * stash[0].get(val.device)
elif isinstance(val, abc.Iterable):
iterable = map(apply_scale, val)
if isinstance(val, (list, tuple)):
return type(val)(iterable)
else:
return iterable
else:
raise ValueError("outputs must be a Tensor or an iterable of Tensors")
return apply_scale(outputs)
def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16):
per_device_inv_scale = _MultiDeviceReplicator(inv_scale)
per_device_found_inf = _MultiDeviceReplicator(found_inf)
# To set up _amp_foreach_non_finite_check_and_unscale_, split grads by device and dtype.
# There could be hundreds of grads, so we'd like to iterate through them just once.
# However, we don't know their devices or dtypes in advance.
# https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict
# Google says mypy struggles with defaultdicts type annotations.
per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list)) # type: ignore[var-annotated]
with torch.no_grad():
for group in optimizer.param_groups:
for param in group["params"]:
if param.grad is None:
continue
if (not allow_fp16) and param.grad.dtype == torch.float16:
raise ValueError("Attempting to unscale FP16 gradients.")
if param.grad.is_sparse:
# is_coalesced() == False means the sparse grad has values with duplicate indices.
# coalesce() deduplicates indices and adds all values that have the same index.
# For scaled fp16 values, there's a good chance coalescing will cause overflow,
# so we should check the coalesced _values().
if param.grad.dtype is torch.float16:
param.grad = param.grad.coalesce()
to_unscale = param.grad._values()
else:
to_unscale = param.grad
# TODO: is there a way to split by device and dtype without appending in the inner loop?
per_device_and_dtype_grads[to_unscale.device][
to_unscale.dtype
].append(to_unscale)
for device, per_dtype_grads in per_device_and_dtype_grads.items():
for grads in per_dtype_grads.values():
torch._amp_foreach_non_finite_check_and_unscale_(
grads,
per_device_found_inf.get(device),
per_device_inv_scale.get(device),
)
return per_device_found_inf._per_device_tensors
def unscale_(self, optimizer):
"""
Divides ("unscales") the optimizer's gradient tensors by the scale factor.
:meth:`unscale_` is optional, serving cases where you need to
:ref:`modify or inspect gradients<working-with-unscaled-gradients>`
between the backward pass(es) and :meth:`step`.
If :meth:`unscale_` is not called explicitly, gradients will be unscaled automatically during :meth:`step`.
Simple example, using :meth:`unscale_` to enable clipping of unscaled gradients::
...
scaler.scale(loss).backward()
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
scaler.step(optimizer)
scaler.update()
Args:
optimizer (torch.optim.Optimizer): Optimizer that owns the gradients to be unscaled.
.. note::
:meth:`unscale_` does not incur a CPU-GPU sync.
.. warning::
:meth:`unscale_` should only be called once per optimizer per :meth:`step` call,
and only after all gradients for that optimizer's assigned parameters have been accumulated.
Calling :meth:`unscale_` twice for a given optimizer between each :meth:`step` triggers a RuntimeError.
.. warning::
:meth:`unscale_` may unscale sparse gradients out of place, replacing the ``.grad`` attribute.
"""
if not self._enabled:
return
self._check_scale_growth_tracker("unscale_")
optimizer_state = self._per_optimizer_states[id(optimizer)]
if optimizer_state["stage"] is OptState.UNSCALED:
raise RuntimeError(
"unscale_() has already been called on this optimizer since the last update()."
)
elif optimizer_state["stage"] is OptState.STEPPED:
raise RuntimeError("unscale_() is being called after step().")
# FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64.
assert self._scale is not None
inv_scale = self._scale.double().reciprocal().float()
found_inf = torch.full((), 0.0, dtype=torch.float32, device=self._scale.device)
optimizer_state["found_inf_per_device"] = self._unscale_grads_(
optimizer, inv_scale, found_inf, False
)
optimizer_state["stage"] = OptState.UNSCALED
def _maybe_opt_step(self, optimizer, optimizer_state, *args, **kwargs):
retval = None
if not sum(v.item() for v in optimizer_state["found_inf_per_device"].values()):
retval = optimizer.step(*args, **kwargs)
return retval
def step(self, optimizer, *args, **kwargs):
"""
:meth:`step` carries out the following two operations:
1. Internally invokes ``unscale_(optimizer)`` (unless :meth:`unscale_` was explicitly called for ``optimizer``
earlier in the iteration). As part of the :meth:`unscale_`, gradients are checked for infs/NaNs.
2. If no inf/NaN gradients are found, invokes ``optimizer.step()`` using the unscaled
gradients. Otherwise, ``optimizer.step()`` is skipped to avoid corrupting the params.
``*args`` and ``**kwargs`` are forwarded to ``optimizer.step()``.
Returns the return value of ``optimizer.step(*args, **kwargs)``.
Args:
optimizer (torch.optim.Optimizer): Optimizer that applies the gradients.
args: Any arguments.
kwargs: Any keyword arguments.
.. warning::
Closure use is not currently supported.
"""
if not self._enabled:
return optimizer.step(*args, **kwargs)
if "closure" in kwargs:
raise RuntimeError(
"Closure use is not currently supported if GradScaler is enabled."
)
self._check_scale_growth_tracker("step")
optimizer_state = self._per_optimizer_states[id(optimizer)]
if optimizer_state["stage"] is OptState.STEPPED:
raise RuntimeError(
"step() has already been called since the last update()."
)
retval = None
if (
hasattr(optimizer, "_step_supports_amp_scaling")
and optimizer._step_supports_amp_scaling
):
# This optimizer has customized scale-handling logic, so we can call optimizer.step() directly.
# The contract with custom optimizers is that their step() should accept an additional,
# optional grad_scaler kwarg. We append self to the kwargs so the custom optimizer has full information:
# it can query its own state, invoke unscale_ on itself, etc
# The contract above is being deprecated to avoid introducing `grad_scaler: GradScaler` argument
# to `Optimizer.step`. The new behavior is going to add two Tensor attributes of `grad_scale`
# and `found_inf` to the passed optimizer so that the optimizer can utilize those
# to skip the parameter updates or unscale gradients before updating parameters in
# the fused kernel, e.g. `FusedAdamMathFunctor`.
# In this behavior, `GradScaler._check_inf_per_device` is called if `OptState.READY`,
# while the method is expected to be called by users side, i.e. their optimizers.
kwargs_ = kwargs
has_grad_scaler_kwarg = (
"grad_scaler" in inspect.signature(optimizer.step).parameters
)
if has_grad_scaler_kwarg:
warnings.warn(
"GradScaler is going to stop passing itself as a keyword argument to the passed "
"optimizer. In the near future GradScaler registers `grad_scale: Tensor` and "
"`found_inf: Tensor` to the passed optimizer and let the optimizer use them directly.",
FutureWarning,
)
kwargs_.update({"grad_scaler": self})
else:
if optimizer_state["stage"] is OptState.READY:
self._check_inf_per_device(optimizer)
scaler = self._get_scale_async()
found_inf = cast(
torch.Tensor,
sum(
[
t.to(scaler.device, non_blocking=True)
for t in optimizer_state["found_inf_per_device"].values()
]
),
)
optimizer.grad_scale = (
None if optimizer_state["stage"] == OptState.UNSCALED else scaler
)
optimizer.found_inf = found_inf
retval = optimizer.step(*args, **kwargs_)
optimizer_state["stage"] = OptState.STEPPED
if not has_grad_scaler_kwarg:
del optimizer.grad_scale
del optimizer.found_inf
return retval
if optimizer_state["stage"] is OptState.READY:
self.unscale_(optimizer)
assert (
len(optimizer_state["found_inf_per_device"]) > 0
), "No inf checks were recorded for this optimizer."
retval = self._maybe_opt_step(optimizer, optimizer_state, *args, **kwargs)
optimizer_state["stage"] = OptState.STEPPED
return retval
def update(self, new_scale=None):
"""
Updates the scale factor.
If any optimizer steps were skipped the scale is multiplied by ``backoff_factor``
to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively,
the scale is multiplied by ``growth_factor`` to increase it.
Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not
used directly, it's used to fill GradScaler's internal scale tensor. So if
``new_scale`` was a tensor, later in-place changes to that tensor will not further
affect the scale GradScaler uses internally.)
Args:
new_scale (float or :class:`torch.vacc.FloatTensor`, optional, default=None): New scale factor.
.. warning::
:meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has
been invoked for all optimizers used this iteration.
.. warning::
For performance reasons, we do not check the scale factor value to avoid synchronizations,
so the scale factor is not guaranteed to be above 1. If the scale falls below 1 and/or
you are seeing NaNs in your gradients or loss, something is likely wrong. For example,
bf16-pretrained models are often incompatible with AMP/fp16 due to differing dynamic ranges.
"""
if not self._enabled:
return
_scale, _growth_tracker = self._check_scale_growth_tracker("update")
if new_scale is not None:
# Accept a new user-defined scale.
if isinstance(new_scale, float):
self._scale.fill_(new_scale) # type: ignore[union-attr]
else:
reason = "new_scale should be a float or a 1-element torch.vacc.FloatTensor with requires_grad=False."
# assert isinstance(new_scale, torch.vacc.FloatTensor), reason # type: ignore[attr-defined]
assert (
isinstance(new_scale, torch.Tensor)
and new_scale.dtype == torch.float32
), reason
assert new_scale.numel() == 1, reason
assert new_scale.requires_grad is False, reason
self._scale.copy_(new_scale) # type: ignore[union-attr]
else:
# Consume shared inf/nan data collected from optimizers to update the scale.
# If all found_inf tensors are on the same device as self._scale, this operation is asynchronous.
found_infs = [
found_inf.to(device=_scale.device, non_blocking=True)
for state in self._per_optimizer_states.values()
for found_inf in state["found_inf_per_device"].values()
]
assert len(found_infs) > 0, "No inf checks were recorded prior to update."
found_inf_combined = found_infs[0]
if len(found_infs) > 1:
for i in range(1, len(found_infs)):
found_inf_combined += found_infs[i]
torch._amp_update_scale_(
_scale,
_growth_tracker,
found_inf_combined,
self._growth_factor,
self._backoff_factor,
self._growth_interval,
)
# To prepare for next iteration, clear the data collected from optimizers this iteration.
self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
def _get_scale_async(self):
return self._scale
def get_scale(self):
"""
Returns a Python float containing the current scale, or 1.0 if scaling is disabled.
.. warning::
:meth:`get_scale` incurs a CPU-GPU sync.
"""
if self._enabled:
return (
self._init_scale
if self._scale is None
else self._get_scale_async().item()
)
else:
return 1.0
def get_growth_factor(self):
r"""
Returns a Python float containing the scale growth factor.
"""
return self._growth_factor
def set_growth_factor(self, new_factor):
r"""
Args:
new_scale (float): Value to use as the new scale growth factor.
"""
self._growth_factor = new_factor
def get_backoff_factor(self):
r"""
Returns a Python float containing the scale backoff factor.
"""
return self._backoff_factor
def set_backoff_factor(self, new_factor):
r"""
Args:
new_scale (float): Value to use as the new scale backoff factor.
"""
self._backoff_factor = new_factor
def get_growth_interval(self):
r"""
Returns a Python int containing the growth interval.
"""
return self._growth_interval
def set_growth_interval(self, new_interval):
r"""
Args:
new_interval (int): Value to use as the new growth interval.
"""
self._growth_interval = new_interval
def _get_growth_tracker(self):
if self._enabled:
return (
self._init_growth_tracker
if self._growth_tracker is None
else self._growth_tracker.item()
)
else:
return 0
def is_enabled(self):
r"""
Returns a bool indicating whether this instance is enabled.
"""
return self._enabled
def state_dict(self):
r"""
Returns the state of the scaler as a :class:`dict`. It contains five entries:
* ``"scale"`` - a Python float containing the current scale
* ``"growth_factor"`` - a Python float containing the current growth factor
* ``"backoff_factor"`` - a Python float containing the current backoff factor
* ``"growth_interval"`` - a Python int containing the current growth interval
* ``"_growth_tracker"`` - a Python int containing the number of recent consecutive unskipped steps.
If this instance is not enabled, returns an empty dict.
.. note::
If you wish to checkpoint the scaler's state after a particular iteration, :meth:`state_dict`
should be called after :meth:`update`.
"""
return (
{
"scale": self.get_scale(),
"growth_factor": self._growth_factor,
"backoff_factor": self._backoff_factor,
"growth_interval": self._growth_interval,
"_growth_tracker": self._get_growth_tracker(),
}
if self._enabled
else {}
)
def load_state_dict(self, state_dict):
r"""
Loads the scaler state. If this instance is disabled, :meth:`load_state_dict` is a no-op.
Args:
state_dict(dict): scaler state. Should be an object returned from a call to :meth:`state_dict`.
"""
if not self._enabled:
return
if len(state_dict) == 0:
raise RuntimeError(
"The source state dict is empty, possibly because it was saved "
"from a disabled instance of GradScaler."
)
self._init_scale = state_dict["scale"]
if self._scale is not None:
self._scale.fill_(state_dict["scale"])
self._growth_factor = state_dict["growth_factor"]
self._backoff_factor = state_dict["backoff_factor"]
self._growth_interval = state_dict["growth_interval"]
self._init_growth_tracker = state_dict["_growth_tracker"]
if self._growth_tracker is not None:
self._growth_tracker.fill_(state_dict["_growth_tracker"])
def __getstate__(self):
state = self.__dict__.copy()
if self._enabled:
assert len(self._per_optimizer_states) == 0, (
"A GradScaler instance may only be pickled at the beginning "
"of an iteration, or at the end after scaler.update()."
)
# Pickling _scale and _growth_tracker Tensors directly triggers
# "warnings.warn("pickle support for Storage will be removed in 1.5..."
# so instead, we set the unpickled instance up to reinitialize them lazily.
state["_init_scale"] = self.get_scale()
state["_init_growth_tracker"] = self._get_growth_tracker()
state["_scale"] = None
state["_growth_tracker"] = None
return state
def __setstate__(self, state):
self.__dict__.update(state)
def _check_inf_per_device(self, optimizer):
_scale, _ = self._check_scale_growth_tracker("_check_inf_per_device")
dummy_inv_scale = torch.full((), 1.0, dtype=torch.float32, device=_scale.device)
found_inf = torch.full((), 0.0, dtype=torch.float32, device=_scale.device)
self._per_optimizer_states[id(optimizer)][
"found_inf_per_device"
] = self._unscale_grads_(optimizer, dummy_inv_scale, found_inf, True)
return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"]
def _found_inf_per_device(self, optimizer):
return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"]

File diff suppressed because it is too large.


@@ -0,0 +1,306 @@
from typing import Tuple, Union, Optional, List
import torch
from torch.nn import functional as F
def split_last_two_dims_into_blocks(x, h, w):
leading_dims = x.shape[:-2]
H, W = x.shape[-2:]
assert (
H % h == 0 and W % w == 0
), "The last two dimensions must be divisible by block size."
x_reshaped = x.view(-1, 1, H, W)
unfolded = F.unfold(x_reshaped, kernel_size=(h, w), stride=(h, w))
unfolded = unfolded.view(-1, 1, h, w, H // h, W // w)
unfolded = unfolded.permute(0, 1, 4, 5, 2, 3)
final_shape = leading_dims + (H // h, W // w, h, w)
result = unfolded.view(final_shape)
return result
def merge_blocks_to_original_layout(x, h, w):
leading_dims = x.shape[:-4]
H_div_h, W_div_w, h, w = x.shape[-4:]
H = H_div_h * h
W = W_div_w * w
x_reshaped = x.view(-1, 1, H_div_h, W_div_w, h, w)
x_reshaped = x_reshaped.permute(0, 1, 4, 5, 2, 3)
x_reshaped = x_reshaped.view(-1, h * w, H_div_h * W_div_w)
folded = F.fold(x_reshaped, output_size=(H, W), kernel_size=(h, w), stride=(h, w))
final_shape = leading_dims + (H, W)
result = folded.view(final_shape)
return result
def w8a8_block_fp8_matmul(
input: torch.Tensor,
weight: torch.Tensor,
input_scale: Optional[torch.Tensor],
weight_scale: Optional[torch.Tensor],
block_size: List[int],
is_linear_weight: bool = False,
output_opt: Optional[torch.Tensor] = None,
**kwargs
):
b0, b1 = block_size
dim0, dim1 = weight.shape
dim0pad, dim1pad = 0, 0
if dim0 % b0 != 0:
dim0pad = b0 - dim0 % b0
if dim1 % b1 != 0:
dim1pad = b1 - dim1 % b1
dim0_origin, dim1_origin = dim0, dim1
dim0 += dim0pad
dim1 += dim1pad
bs0, bs1 = dim0 // b0, dim1 // b1
weight_dequant = torch.nn.functional.pad(weight, (0, dim1pad, 0, dim0pad), value=0)
weight_dequant = weight_dequant.cpu().view(bs0, b0, bs1, b1).permute(
0, 2, 1, 3
).reshape(bs0, bs1, -1).float().to(input.device) * weight_scale.unsqueeze(-1)
weight_dequant = (
weight_dequant.reshape(bs0, bs1, b0, b1)
.permute(0, 2, 1, 3)
.reshape(dim0, dim1)
.to(input.dtype)
)
weight_dequant = weight_dequant[:dim0_origin, :dim1_origin]
output = torch.matmul(
input, weight_dequant.T if is_linear_weight else weight_dequant
)
if output_opt is not None:
output = output_opt.copy_(output)
return output
def w8a8_block_fp8_linear(
input: torch.Tensor,
weight: torch.Tensor,
input_scale: Optional[torch.Tensor],
weight_scale: Optional[torch.Tensor],
block_size: List[int],
**kwargs
):
assert input_scale is None, "w8a8_block_fp8_linear only supports quantized weights for now"
return w8a8_block_fp8_matmul(
input, weight, None, weight_scale, block_size, is_linear_weight=True
)
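# A minimal sketch of the blockwise-dequant fallback above, using float tensors as
# stand-ins for the FP8 weight storage (only the blockwise scaling matters here).
# With unit scales the result matches a plain dense linear layer.
def _example_w8a8_block_linear():
    inp = torch.randn(4, 128)            # (tokens, in_features)
    weight = torch.randn(256, 128)       # (out_features, in_features)
    weight_scale = torch.ones(2, 1)      # one scale per (128, 128) block of the weight
    out = w8a8_block_fp8_linear(inp, weight, None, weight_scale, block_size=[128, 128])
    ref = F.linear(inp, weight)
    assert torch.allclose(out, ref, atol=1e-4)
    return out.shape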
def fused_experts(
hidden_states: torch.Tensor,
w13_weight: torch.Tensor,
w2_weight: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
use_fp8_w8a8: bool = True,
w13_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a13_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
decode_with_batch: bool = False,
) -> torch.Tensor:
batch_seq_all, hidden_dims = hidden_states.shape
intermediate_size = w2_weight.shape[-1]
num_experts = w13_weight.shape[0]
w13_weight = w13_weight.contiguous()
w2_weight = w2_weight.contiguous()
w13_scale = w13_scale.contiguous()
w2_scale = w2_scale.contiguous()
final_hidden_states = torch.zeros_like(hidden_states)
import torch.nn.functional as F
w1_scale = w13_scale
w2_scale = w2_scale
_, bs0_w13, bs1_w13 = w1_scale.shape
_, bs0_w2, bs1_w2 = w2_scale.shape
sel_experts = topk_ids.shape[1]
if hidden_states.shape[0] == 1:
for id in range(sel_experts):
expert_idx = topk_ids[0][id]
expert_w1 = w13_weight[expert_idx].contiguous()
expert_w2 = w2_weight[expert_idx].contiguous()
ws1 = w1_scale[expert_idx].unsqueeze(2).contiguous()
ws2 = w2_scale[expert_idx].unsqueeze(2).contiguous()
dim0, dim1 = expert_w1.shape
b0, b1 = dim0 // bs0_w13, dim1 // bs1_w13
# assert (bs0, bs1, 1)==ws1.shape, f"bs0, bs1, 1 is {bs0},{bs1}, 1, <==> {ws1.shape}"
expert_w1 = (
expert_w1
.view(bs0_w13, b0, bs1_w13, b1)
.permute(0, 2, 1, 3)
.reshape(bs0_w13, bs1_w13, -1)
.float()
.to(hidden_states.device)
* ws1
)
expert_w1 = (
expert_w1.reshape(bs0_w13, bs1_w13, b0, b1)
.permute(0, 2, 1, 3)
.reshape(dim0, dim1)
.to(hidden_states.dtype)
)
dim0, dim1 = expert_w2.shape
b0, b1 = dim0 // bs0_w2, dim1 // bs1_w2
# assert (bs0, bs1, 1)==ws2.shape
expert_w2 = (
expert_w2
.view(bs0_w2, b0, bs1_w2, b1)
.permute(0, 2, 1, 3)
.reshape(bs0_w2, bs1_w2, -1)
.float()
.to(hidden_states.device)
* ws2
)
expert_w2 = (
expert_w2.reshape(bs0_w2, bs1_w2, b0, b1)
.permute(0, 2, 1, 3)
.reshape(dim0, dim1)
.to(hidden_states.dtype)
)
expert_weights = topk_weights[0][id].to(hidden_states.dtype)
x = hidden_states
x = F.linear(x, expert_w1)
gate = F.silu(x[:, :intermediate_size])
x = x[:, intermediate_size:] * gate
x = F.linear(x, expert_w2)
current_hidden_states = x * expert_weights
current_hidden_states = current_hidden_states.to(x.dtype)
final_hidden_states += current_hidden_states
else:
for expert_idx in range(num_experts):
# topk_ids [tokens, experts] => sample:[10, 8]
# expert_mask [tokens, experts] => sample:[10, 8]
expert_mask = topk_ids == expert_idx
idx = torch.where(expert_mask)[0]
if idx.numel() == 0:
continue
expert_w1 = w13_weight[expert_idx].contiguous()
expert_w2 = w2_weight[expert_idx].contiguous()
ws1 = w1_scale[expert_idx].unsqueeze(2).contiguous()
ws2 = w2_scale[expert_idx].unsqueeze(2).contiguous()
dim0, dim1 = expert_w1.shape
b0, b1 = dim0 // bs0_w13, dim1 // bs1_w13
# assert (bs0, bs1, 1)==ws1.shape, f"bs0, bs1, 1 is {bs0},{bs1}, 1, <==> {ws1.shape}"
expert_w1 = (
expert_w1
.view(bs0_w13, b0, bs1_w13, b1)
.permute(0, 2, 1, 3)
.reshape(bs0_w13, bs1_w13, -1)
.float()
.to(hidden_states.device)
* ws1
)
expert_w1 = (
expert_w1.reshape(bs0_w13, bs1_w13, b0, b1)
.permute(0, 2, 1, 3)
.reshape(dim0, dim1)
.to(hidden_states.dtype)
)
dim0, dim1 = expert_w2.shape
b0, b1 = dim0 // bs0_w2, dim1 // bs1_w2
# assert (bs0, bs1, 1)==ws2.shape
expert_w2 = (
expert_w2
.view(bs0_w2, b0, bs1_w2, b1)
.permute(0, 2, 1, 3)
.reshape(bs0_w2, bs1_w2, -1)
.float()
.to(hidden_states.device)
* ws2
)
expert_w2 = (
expert_w2.reshape(bs0_w2, bs1_w2, b0, b1)
.permute(0, 2, 1, 3)
.reshape(dim0, dim1)
.to(hidden_states.dtype)
)
# [seq, experts]
expert_weights = (
topk_weights.masked_select(expert_mask)
.unsqueeze(1)
.to(hidden_states.dtype)
)
x = hidden_states[idx]
x = F.linear(x, expert_w1)
gate = F.silu(x[:, :intermediate_size])
x = x[:, intermediate_size:] * gate
x = F.linear(x, expert_w2)
current_hidden_states = x * expert_weights
current_hidden_states = current_hidden_states.to(x.dtype)
# final_hidden_states[idx] += current_hidden_states
final_hidden_states.index_add_(0, idx, current_hidden_states)
final_hidden_states = final_hidden_states.reshape(batch_seq_all, hidden_dims)
return final_hidden_states
def fused_mlp_mm_fp8(
hidden_states: torch.Tensor,
w13_weight: torch.Tensor,
w2_weight: torch.Tensor,
use_fp8_w8a8: bool = True,
w13_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a13_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape_w13: Optional[List[int]] = None,
block_shape_w2: Optional[List[int]] = None,
):
def fp8_to_fp16(inp, scale, block_size, trans_type):
inp_t = inp.to(trans_type)
inp_t = split_last_two_dims_into_blocks(inp_t, block_size[0], block_size[1])
assert scale.size(0) == inp_t.size(-4)
assert scale.size(1) == inp_t.size(-3)
inp_t = inp_t * scale.unsqueeze(-1).unsqueeze(-1)
inp_t = merge_blocks_to_original_layout(inp_t, block_size[0], block_size[1])
return inp_t.to(trans_type)
w13_weight = w13_weight.contiguous()
w2_weight = w2_weight.contiguous()
w13_scale = w13_scale.contiguous()
w2_scale = w2_scale.contiguous()
w13_fp = fp8_to_fp16(w13_weight, w13_scale, block_shape_w13, hidden_states.dtype)
w2_fp = fp8_to_fp16(w2_weight, w2_scale, block_shape_w2, hidden_states.dtype)
out = hidden_states @ w13_fp
out = torch.chunk(out, 2, dim=-1)
out = F.silu(out[0]) * out[1]
out = out @ w2_fp
return out
def mla_matmul_scale(input: torch.Tensor, weight: torch.Tensor, scale: float):
output = torch.matmul(input, weight)
output = output * scale
output = output.to(input.dtype)
return output
def mla_matmul(input: torch.Tensor, weight: torch.Tensor):
output = torch.matmul(input, weight)
output = output.to(input.dtype)
return output

View File

@@ -0,0 +1,146 @@
from typing import List, Optional, Tuple, Union
import torch
from torch import Generator
from torch_vacc._vacc_libs import _torch_vacc
def fuse_moe_prefill_stage0_qwen(
hidden_states,
rms_residual,
rms_weight,
gate_weight,
rms_hidden_state_opt: Optional[torch.Tensor] = None,
zero_moe_hidden_state_opt: Optional[torch.Tensor] = None,
topk_ids_opt: Optional[torch.Tensor] = None,
topk_weight_opt: Optional[torch.Tensor] = None,
):
return _torch_vacc.fuse_moe_prefill_stage0_qwen(
hidden_states,
rms_residual,
rms_weight,
gate_weight,
rms_hidden_state_opt,
zero_moe_hidden_state_opt,
topk_ids_opt,
topk_weight_opt,
)
def fuse_moe_decode_qwen(
hidden_states,
rms_residual,
rms_weight,
moe_weight_13,
moe_weight_2,
moe_weight_13_dequat,
moe_weight_2_dequant,
gate_weight,
block_size_13,
block_size_2,
world_size: int,
rank: int,
group_id: int,
dev_info: List[int] = None,
output: Optional[torch.Tensor] = None,
):
if not dev_info:  # also covers the default dev_info=None
dev_info = [i | (i << 16) for i in range(world_size)]
return _torch_vacc.fuse_moe_decode_qwen(
hidden_states,
rms_residual,
rms_weight,
moe_weight_13,
moe_weight_2,
moe_weight_13_dequat,
moe_weight_2_dequant,
gate_weight,
block_size_13,
block_size_2,
world_size,
rank,
group_id,
dev_info,
output,
)
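# Worked example of the default dev_info encoding above (pure arithmetic): each entry
# stores the rank in both the low and the high 16 bits, so with world_size = 4 the
# generated list is
#
#   >>> [i | (i << 16) for i in range(4)]
#   [0, 65537, 131074, 196611]   # i.e. 0x00000, 0x10001, 0x20002, 0x30003
#
# Callers that map ranks to different physical devices can pass their own encoding.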
def rot_pos_emb_qwenvl(grid_thw: List[List[int]],
hidden_size: int,
head_num: int,
spatial_merge_size: int,
dtype: torch.dtype,
device: Union[int, str, torch.device] = "vacc"):
#assert out_tensor.device.type == "vacc", f"please target vacc device, now is {out_tensor.device}"
if isinstance(device, str):
device = torch.device(device)
elif isinstance(device, int):
device = torch.device("vacc", device)
thws = []
for i in grid_thw:
thws.extend(i)
return _torch_vacc.rot_pos_emb_qwenvl(thws,
hidden_size,
head_num,
spatial_merge_size,
dtype,
device)
def fast_pos_embed_interpolate_qwenvl(weight: torch.Tensor,
grid_thw: List[List[int]],
num_grid_per_side: int,
spatial_merge_size: int,
hidden_dim: int):
thws = []
for i in grid_thw:
thws.extend(i)
return _torch_vacc.fast_pos_embed_interpolate_qwenvl(weight,
thws,
num_grid_per_side,
spatial_merge_size,
hidden_dim)
# qwen2_vl and qwen3_vl image preprocess ops are the same
def qwen2vl_img_preprocess(
image: "torch.Tensor",
do_resize: bool,
min_pixels: int,
max_pixels: int,
do_rescale: bool,
rescale_factor: float,
do_normalize: bool,
resized_height: int,
resized_width: int,
interpolation: int, #Optional["F.InterpolationMode"],
patch_size: int,
temporal_patch_size: int,
merge_size: int,
image_mean0: float,
image_mean1: float,
image_mean2: float,
image_std0: float,
image_std1: float,
image_std2: float,
# batch_size: int = 1,
# grid_t: int = 1,
# channel: int = 3,
# output: Optional[torch.Tensor] = None
):
assert image.device.type == "vacc", f"expected an image tensor on a vacc device, got {image.device}"
return _torch_vacc.qwen2vl_img_preprocess(
image,
do_resize,
min_pixels,
max_pixels,
do_rescale,
rescale_factor,
do_normalize,
resized_height,
resized_width,
interpolation,
patch_size,
temporal_patch_size,
merge_size,
image_mean0, image_mean1, image_mean2,
image_std0, image_std1, image_std2
)

View File

@@ -0,0 +1,107 @@
import threading
import traceback
from typing import List
from .._vacc_libs import _torch_vacc
_initialized = False
_tls = threading.local()
_initialization_lock = threading.Lock()
_queued_calls = []
_is_in_bad_fork = getattr(_torch_vacc, "_vacc_in_bad_fork", lambda: False)
def is_initialized():
r"""Returns whether PyTorch's VACC state has been initialized."""
return _initialized and not _is_in_bad_fork()
class _LazySeedTracker:
# Since seeding is memory-less, only track the latest seed.
# Note: `manual_seed_all` followed by `manual_seed` overwrites
# the seed on the current device. We track the order of the **latest**
# calls between these two APIs.
def __init__(self):
self.manual_seed_all_cb = None
self.manual_seed_cb = None
self.call_order = []
def queue_seed_all(self, cb, traceback):
self.manual_seed_all_cb = (cb, traceback)
# update seed_all to be latest
self.call_order = [self.manual_seed_cb, self.manual_seed_all_cb]
def queue_seed(self, cb, traceback):
self.manual_seed_cb = (cb, traceback)
# update seed to be latest
self.call_order = [self.manual_seed_all_cb, self.manual_seed_cb]
def get_calls(self) -> List:
return self.call_order
_lazy_seed_tracker = _LazySeedTracker()
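# Ordering sketch for the tracker above: only the most recent of the two seeding paths
# wins when the queued calls are replayed at initialization. For example
#   manual_seed_all(1) then manual_seed(2)  -> seed_all replays first, then seed,
#                                              so the current device ends up seeded with 2
#   manual_seed(2) then manual_seed_all(1)  -> seed_all replays last, so every device
#                                              (including the current one) is seeded with 1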
def _lazy_call(callable, **kwargs):
if is_initialized():
callable()
else:
# TODO(torch_deploy): this accesses linecache, which attempts to read the
# file system to get traceback info. Patch linecache or do something
# else here if this ends up being important.
global _lazy_seed_tracker
if kwargs.get("seed_all", False):
_lazy_seed_tracker.queue_seed_all(callable, traceback.format_stack())
elif kwargs.get("seed", False):
_lazy_seed_tracker.queue_seed(callable, traceback.format_stack())
else:
# Don't store the actual traceback to avoid memory cycle
_queued_calls.append((callable, traceback.format_stack()))
class DeferredVaccCallError(Exception):
pass
def _lazy_init():
"""Initialize VACC device state."""
global _initialized, _queued_calls
if _initialized or hasattr(_tls, "is_initializing"):
return
with _initialization_lock:
if _initialized:
return
# It is important to prevent other threads from entering _lazy_init
# immediately, while we are still guaranteed to have the GIL, because some
# of the C calls we make below will release the GIL
if _is_in_bad_fork():
raise RuntimeError(
"Cannot re-initialize VACC in forked subprocess. To use VACC with "
"multiprocessing, you must use the 'spawn' start method"
)
_torch_vacc._vacc_init()
_tls.is_initializing = True
for calls in _lazy_seed_tracker.get_calls():
if calls:
_queued_calls.append(calls)
try:
for queued_call, orig_traceback in _queued_calls:
try:
queued_call()
except Exception as e:
msg = (
f"VACC call failed lazily at initialization with error: {str(e)}\n\n"
f"VACC call was originally invoked at:\n\n{''.join(orig_traceback)}"
)
raise DeferredVaccCallError(msg) from e
finally:
delattr(_tls, "is_initializing")
_initialized = True

535
torch_vacc/vacc/memory.py Normal file
View File

@@ -0,0 +1,535 @@
import collections
import contextlib
import warnings
from typing import Tuple
import torch
from torch._utils import _get_device_index
import torch_vacc
from torch_vacc._vacc_libs import _torch_vacc
from .lazy_initialize import is_initialized, _lazy_init
__all__ = [
"mem_get_info",
# "caching_allocator_alloc",
# "caching_allocator_delete",
"set_per_process_memory_fraction",
"empty_cache",
"memory_stats",
"memory_stats_as_nested_dict",
"reset_accumulated_memory_stats",
"reset_peak_memory_stats",
"reset_max_memory_allocated",
"reset_max_memory_cached",
"memory_allocated",
"max_memory_allocated",
"memory_reserved",
"max_memory_reserved",
"memory_cached",
"max_memory_cached",
"memory_snapshot",
"memory_summary",
"get_allocator_backend",
]
@contextlib.contextmanager
def _free_mutex():
_torch_vacc._vacc_lock_mutex()
try:
yield
finally:
_torch_vacc._vacc_unlock_mutex()
# def caching_allocator_alloc(size, device=None, stream=None):
# r"""Performs a memory allocation using the VACC memory allocator.
# Memory is allocated for a given device and a stream, this
# function is intended to be used for interoperability with other
# frameworks. Allocated memory is released through
# :func:`~torch_vacc.vacc.caching_allocator_delete`.
# Arguments:
# size (int): number of bytes to be allocated.
# device (torch.device or int, optional): selected device. If it is
# ``None`` the default VACC device is used.
# stream (torch_vacc.vacc.Stream or int, optional): selected stream. If is ``None`` then
# the default stream for the selected device is used.
# """
# if device is None:
# device = torch_vacc.vacc.current_device()
# device = _get_device_index(device)
# if stream is None:
# stream = torch_vacc.vacc.current_stream(device)
# if isinstance(stream, torch_vacc.vacc.streams.Stream):
# stream = stream.vacc_stream
# if not isinstance(stream, int):
# raise TypeError(
# "Invalid type for stream argument, must be "
# "`torch_vacc.vacc.Stream` or `int` representing a pointer "
# "to a exisiting stream"
# )
# with torch_vacc.vacc.device(device):
# return _torch_vacc._vacc_vaccCachingAllocator_raw_alloc(size, stream)
# def caching_allocator_delete(mem_ptr):
# r"""Deletes memory allocated using the VACC memory allocator.
# Memory allocated with :func:`~torch_vacc.vacc.caching_allocator_alloc`.
# is freed here. The associated device and stream are tracked inside
# the allocator.
# Arguments:
# mem_ptr (int): memory address to be freed by the allocator.
# """
# _torch_vacc._vacc_vaccCachingAllocator_raw_delete(mem_ptr)
def set_per_process_memory_fraction(fraction, device=None) -> None:
r"""Set memory fraction for a process.
The fraction is used to limit the caching allocator to a portion of the memory on a VACC device.
The allowed value equals the total visible memory multiplied by the fraction.
If a process tries to allocate more than the allowed value, the allocator raises an
out-of-memory error.
Arguments:
fraction(float): Range: 0~1. Allowed memory equals total_memory * fraction.
device (torch.device or int, optional): selected device. If it is
``None`` the default VACC device is used.
.. note::
In general, the total available free memory is less than the total capacity.
"""
_lazy_init()
if device is None:
device = torch_vacc.vacc.current_device()
device = _get_device_index(device)
if not isinstance(fraction, float):
raise TypeError("Invalid type for fraction argument, must be `float`")
if fraction < 0 or fraction > 1:
raise ValueError(
"Invalid fraction value: {}. " "Allowed range: 0~1".format(fraction)
)
_torch_vacc._vacc_setMemoryFraction(fraction, device)
def empty_cache():
r"""Releases all unoccupied cached memory currently held by the caching
allocator so that it can be used by other VACC applications and becomes visible
to device memory monitoring tools.
.. note::
:func:`~torch_vacc.vacc.empty_cache` doesn't increase the amount of VACC
memory available for PyTorch. However, it may help reduce fragmentation
of VACC memory in certain cases.
"""
if is_initialized():
_torch_vacc._vacc_emptyCache()
def memory_stats(device=None):
"""Returns a dictionary of VACC memory allocator statistics for a
given device.
The return value of this function is a dictionary of statistics, each of
which is a non-negative integer.
Core statistics:
- ``"allocated.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
number of allocation requests received by the memory allocator.
- ``"allocated_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
amount of allocated memory.
- ``"segment.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
number of reserved segments from ``vaccMalloc()``.
- ``"reserved_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
amount of reserved memory.
- ``"active.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
number of active memory blocks.
- ``"active_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
amount of active memory.
- ``"inactive_split.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
number of inactive, non-releasable memory blocks.
- ``"inactive_split_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
amount of inactive, non-releasable memory.
For these core statistics, values are broken down as follows.
Pool type:
- ``all``: combined statistics across all memory pools.
- ``large_pool``: statistics for the large allocation pool
(as of October 2019, for size >= 1MB allocations).
- ``small_pool``: statistics for the small allocation pool
(as of October 2019, for size < 1MB allocations).
Metric type:
- ``current``: current value of this metric.
- ``peak``: maximum value of this metric.
- ``allocated``: historical total increase in this metric.
- ``freed``: historical total decrease in this metric.
In addition to the core statistics, we also provide some simple event
counters:
- ``"num_alloc_retries"``: number of failed ``vaccMalloc`` calls that
result in a cache flush and retry.
- ``"num_ooms"``: number of out-of-memory errors thrown.
The caching allocator can be configured via environment variables not to split blocks larger
than a defined size (see the Memory Management section of the CUDA Semantics documentation).
This helps avoid memory fragmentation but may have a performance
penalty. Additional outputs to assist with tuning and evaluating impact:
- ``"max_split_size"``: blocks above this size will not be split.
- ``"oversize_allocations.{current,peak,allocated,freed}"``:
number of over-size allocation requests received by the memory allocator.
- ``"oversize_segments.{current,peak,allocated,freed}"``:
number of over-size reserved segments from ``vaccMalloc()``.
Arguments:
device (torch.device or int, optional): selected device. Returns
statistics for the current device, given by :func:`~torch_vacc.vacc.current_device`,
if :attr:`device` is ``None`` (default).
"""
result = []
def _recurse_add_to_result(prefix, obj):
if isinstance(obj, dict):
if len(prefix) > 0:
prefix += "."
for k, v in obj.items():
_recurse_add_to_result(prefix + k, v)
else:
result.append((prefix, obj))
stats = memory_stats_as_nested_dict(device=device)
_recurse_add_to_result("", stats)
result.sort()
return collections.OrderedDict(result)
def memory_stats_as_nested_dict(device=None):
r"""Returns the result of :func:`~torch_vacc.vacc.memory_stats` as a nested dictionary."""
device = _get_device_index(device, optional=True)
return _torch_vacc._vacc_memoryStats(device)
def reset_accumulated_memory_stats(device=None):
r"""Resets the "accumulated" (historical) stats tracked by the VACC memory allocator.
See :func:`~torch_vacc.vacc.memory_stats` for details. Accumulated stats correspond to
the `"allocated"` and `"freed"` keys in each individual stat dict, as well as
`"num_alloc_retries"` and `"num_ooms"`.
Arguments:
device (torch.device or int, optional): selected device. Returns
statistic for the current device, given by :func:`~torch_vacc.vacc.current_device`,
if :attr:`device` is ``None`` (default).
"""
device = _get_device_index(device, optional=True)
return _torch_vacc._vacc_resetAccumulatedMemoryStats(device)
def reset_peak_memory_stats(device=None):
r"""Resets the "peak" stats tracked by the VACC memory allocator.
See :func:`~torch_vacc.vacc.memory_stats` for details. Peak stats correspond to the
`"peak"` key in each individual stat dict.
Arguments:
device (torch.device or int, optional): selected device. Returns
statistic for the current device, given by :func:`~torch_vacc.vacc.current_device`,
if :attr:`device` is ``None`` (default).
"""
device = _get_device_index(device, optional=True)
return _torch_vacc._vacc_resetPeakMemoryStats(device)
def reset_max_memory_allocated(device=None):
r"""Resets the starting point in tracking maximum VACC memory occupied by
tensors for a given device.
See :func:`~torch_vacc.vacc.max_memory_allocated` for details.
Arguments:
device (torch.device or int, optional): selected device. Returns
statistic for the current device, given by :func:`~torch_vacc.vacc.current_device`,
if :attr:`device` is ``None`` (default).
.. warning::
This function now calls :func:`~torch_vacc.vacc.reset_peak_memory_stats`, which resets
/all/ peak memory stats.
"""
# warnings.warn(
# "torch_vacc.vacc.reset_max_memory_allocated now calls torch_vacc.vacc.reset_peak_memory_stats, "
# "which resets /all/ peak memory stats.",
# DeprecationWarning,
# )
return reset_peak_memory_stats(device=device)
def reset_max_memory_cached(device=None):
r"""Resets the starting point in tracking maximum VACC memory managed by the
caching allocator for a given device.
See :func:`~torch_vacc.vacc.max_memory_cached` for details.
Arguments:
device (torch.device or int, optional): selected device. Returns
statistic for the current device, given by :func:`~torch_vacc.vacc.current_device`,
if :attr:`device` is ``None`` (default).
.. warning::
This function now calls :func:`~torch_vacc.vacc.reset_peak_memory_stats`, which resets
/all/ peak memory stats.
"""
# warnings.warn(
# "torch_vacc.vacc.reset_max_memory_cached now calls torch_vacc.vacc.reset_peak_memory_stats, "
# "which resets /all/ peak memory stats.",
# DeprecationWarning,
# )
return reset_peak_memory_stats(device=device)
def memory_allocated(device=None):
r"""Returns the current VACC memory occupied by tensors in bytes for a given
device.
Arguments:
device (torch.device or int, optional): selected device. Returns
statistic for the current device, given by :func:`~torch_vacc.vacc.current_device`,
if :attr:`device` is ``None`` (default).
"""
return memory_stats(device=device)["allocated_bytes.all.current"]
def max_memory_allocated(device=None):
r"""Returns the maximum VACC memory occupied by tensors in bytes for a given
device.
By default, this returns the peak allocated memory since the beginning of
this program. :func:`~torch_vacc.vacc.reset_peak_memory_stats` can be used to
reset the starting point in tracking this metric. For example, these two
functions can measure the peak allocated memory usage of each iteration in a
training loop.
Arguments:
device (torch.device or int, optional): selected device. Returns
statistic for the current device, given by :func:`~torch_vacc.vacc.current_device`,
if :attr:`device` is ``None`` (default).
"""
return memory_stats(device=device)["allocated_bytes.all.peak"]
def memory_reserved(device=None):
r"""Returns the current VACC memory managed by the caching allocator in bytes
for a given device.
Arguments:
device (torch.device or int, optional): selected device. Returns
statistic for the current device, given by :func:`~torch_vacc.vacc.current_device`,
if :attr:`device` is ``None`` (default).
"""
return memory_stats(device=device)["reserved_bytes.all.current"]
def max_memory_reserved(device=None):
r"""Returns the maximum VACC memory managed by the caching allocator in bytes
for a given device.
By default, this returns the peak cached memory since the beginning of this
program. :func:`~torch_vacc.vacc.reset_peak_memory_stats` can be used to reset
the starting point in tracking this metric. For example, these two functions
can measure the peak cached memory amount of each iteration in a training
loop.
Arguments:
device (torch.device or int, optional): selected device. Returns
statistic for the current device, given by :func:`~torch_vacc.vacc.current_device`,
if :attr:`device` is ``None`` (default).
"""
return memory_stats(device=device)["reserved_bytes.all.peak"]
def memory_cached(device=None):
r"""Deprecated; see :func:`~torch_vacc.vacc.memory_reserved`."""
# warnings.warn(
# "torch_vacc.vacc.memory_cached has been renamed to torch_vacc.vacc.memory_reserved",
# DeprecationWarning,
# )
return memory_reserved(device=device)
def max_memory_cached(device=None):
r"""Deprecated; see :func:`~torch_vacc.vacc.max_memory_reserved`."""
# warnings.warn(
# "torch_vacc.vacc.max_memory_cached has been renamed to torch_vacc.vacc.max_memory_reserved",
# DeprecationWarning,
# )
return max_memory_reserved(device=device)
def memory_snapshot():
r"""Returns a snapshot of the VACC memory allocator state across all devices.
Interpreting the output of this function requires familiarity with the
memory allocator internals.
"""
return _torch_vacc._vacc_memorySnapshot()
def _format_size(sz, pref_sz):
prefixes = ["B ", "KB", "MB", "GB", "TB", "PB"]
prefix = prefixes[0]
for new_prefix in prefixes[1:]:
if pref_sz < 768 * 1024:
break
prefix = new_prefix
sz //= 1024
pref_sz /= 1024
return "{:7d} {}".format(sz, prefix)
def _format_count(cnt, pref_cnt):
prefixes = [" ", "K", "M"]
prefix = prefixes[0]
for new_prefix in prefixes[1:]:
if pref_cnt < 750 * 1000:
break
prefix = new_prefix
cnt //= 1000
pref_cnt /= 1000
return "{:7d} {} ".format(cnt, prefix)
def create_metrics_to_display():
metrics_to_display = [
("allocated_bytes", "Allocated memory", _format_size),
("active_bytes", "Active memory", _format_size),
("reserved_bytes", "VACC reserved memory", _format_size),
("inactive_split_bytes", "Non-releasable memory", _format_size),
("allocation", "Allocations", _format_count),
("active", "Active allocs", _format_count),
("segment", "VACC reserved segments", _format_count),
("inactive_split", "Non-releasable allocs", _format_count),
]
lines = []
lines.append("=" * 75)
lines.append(" {_:16} PyTorch VACC memory summary, device ID {device:<18d} ")
lines.append("-" * 75)
lines.append(
" {_:9} VACC OOMs: {num_ooms:<13d} | {_:6} vaccMalloc retries: {num_alloc_retries:<9d} "
)
lines.append("=" * 75)
lines.append(
" Metric | Cur Usage | Peak Usage | Tot Alloc | Tot Freed "
)
return metrics_to_display, lines
def memory_summary(device=None, abbreviated=False):
r"""Returns a human-readable printout of the current memory allocator
statistics for a given device.
This can be useful to display periodically during training, or when
handling out-of-memory exceptions.
Arguments:
device (torch.device or int, optional): selected device. Returns
printout for the current device, given by :func:`~torch_vacc.vacc.current_device`,
if :attr:`device` is ``None`` (default).
abbreviated (bool, optional): whether to return an abbreviated summary
(default: False).
"""
device = _get_device_index(device, optional=True)
stats = memory_stats(device=device)
metrics_to_display, lines = create_metrics_to_display()
for metric_key, metric_name, formatter in metrics_to_display:
lines.append("-" * 75)
submetrics = [("all", metric_name)]
if not abbreviated:
submetrics.append(("large_pool", " from large pool"))
submetrics.append(("small_pool", " from small pool"))
current_prefval, peak_prefval, allocated_prefval, freed_prefval = (
None,
None,
None,
None,
)
for submetric_key, submetric_name in submetrics:
prefix = metric_key + "." + submetric_key + "."
current = stats[prefix + "current"]
peak = stats[prefix + "peak"]
allocated = stats[prefix + "allocated"]
freed = stats[prefix + "freed"]
if current_prefval is None:
current_prefval = current
peak_prefval = peak
allocated_prefval = allocated
freed_prefval = freed
lines.append(
" {:<21} | {} | {} | {} | {} ".format(
submetric_name,
formatter(current, current_prefval),
formatter(peak, peak_prefval),
formatter(allocated, allocated_prefval),
formatter(freed, freed_prefval),
),
)
metrics_to_display = [
("oversize_allocations", "Oversize allocations", _format_count),
("oversize_segments", "Oversize VACC segments", _format_count),
]
for metric_key, metric_name, formatter in metrics_to_display:
lines.append("-" * 75)
prefix = metric_key + "."
current = stats[prefix + "current"]
peak = stats[prefix + "peak"]
allocated = stats[prefix + "allocated"]
freed = stats[prefix + "freed"]
lines.append(
" {:<21} | {} | {} | {} | {} ".format(
metric_name,
formatter(current, current),
formatter(peak, peak),
formatter(allocated, allocated),
formatter(freed, freed),
),
)
lines.append("=" * 75)
fmt_dict = {"_": "", "device": device}
for k, v in stats.items():
fmt_dict[k.replace(".", "-")] = v
return "|" + "|\n|".join(lines).format(**fmt_dict) + "|\n"
def mem_get_info(device=None) -> Tuple[int, int]:
r"""Returns the global free and total VACC memory for a given
device using vaccrtMemGetInfo.
Args:
device (torch.device or int, optional): selected device. Returns
statistic for the current device, given by :func:`~torch_vacc.vacc.current_device`,
if :attr:`device` is ``None`` (default).
"""
_lazy_init()
if device is None:
device = torch_vacc.vacc.current_device()
device = _get_device_index(device)
return _torch_vacc._vacc_getDeviceMemories(device)
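# A minimal usage sketch (assumes a reachable VACC device): combine the queries above
# to report headroom before attempting a large allocation.
def _example_memory_headroom(device=None):
    free_bytes, total_bytes = mem_get_info(device)
    return {
        "free_gib": free_bytes / 1024**3,
        "total_gib": total_bytes / 1024**3,
        "reserved_gib": memory_reserved(device) / 1024**3,
    }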
def get_allocator_backend() -> str:
r"""Returns a string describing the active allocator backend as set by
``PYTORCH_VACC_ALLOC_CONF``. Currently available backends are
``native`` (PyTorch's native caching allocator).
"""
return _torch_vacc._vacc_getAllocatorBackend()

179
torch_vacc/vacc/random.py Normal file
View File

@@ -0,0 +1,179 @@
from typing import Union, List, Iterable
import torch
from torch import Tensor
from . import _lazy_call, _lazy_init, current_device, device_count
__all__ = [
"get_rng_state",
"get_rng_state_all",
"set_rng_state",
"set_rng_state_all",
"manual_seed",
"manual_seed_all",
"seed",
"seed_all",
"initial_seed",
]
# Random Number Generator related functions (https://pytorch.org/docs/stable/cuda.html#random-number-generator)
def get_rng_state(device: Union[int, str, torch.device] = "vacc") -> Tensor:
r"""Returns the random number generator state of the specified GPU as a ByteTensor.
Args:
device (torch.device or int, optional): The device to return the RNG state of.
Default: ``'vacc'`` (i.e., ``torch.device('vacc')``, the current VACC device).
.. warning::
This function eagerly initializes VACC.
"""
_lazy_init()
if isinstance(device, str):
device = torch.device(device)
elif isinstance(device, int):
device = torch.device("vacc", device)
idx = device.index
if idx is None:
idx = current_device()
default_generator = torch.vacc.default_generators[idx]
return default_generator.get_state()
def get_rng_state_all() -> List[Tensor]:
r"""Returns a list of ByteTensor representing the random number states of all devices."""
results = []
for i in range(device_count()):
results.append(get_rng_state(i))
return results
def set_rng_state(
new_state: Tensor, device: Union[int, str, torch.device] = "vacc"
) -> None:
r"""Sets the random number generator state of the specified GPU.
Args:
new_state (torch.ByteTensor): The desired state
device (torch.device or int, optional): The device to set the RNG state.
Default: ``'vacc'`` (i.e., ``torch.device('vacc')``, the current VACC device).
"""
with torch._C._DisableFuncTorch():
new_state_copy = new_state.clone(memory_format=torch.contiguous_format)
if isinstance(device, str):
device = torch.device(device)
elif isinstance(device, int):
device = torch.device("vacc", device)
def cb():
idx = device.index
if idx is None:
idx = current_device()
default_generator = torch.vacc.default_generators[idx]
default_generator.set_state(new_state_copy)
_lazy_call(cb)
def set_rng_state_all(new_states: Iterable[Tensor]) -> None:
r"""Sets the random number generator state of all devices.
Args:
new_states (Iterable of torch.ByteTensor): The desired state for each device"""
for i, state in enumerate(new_states):
set_rng_state(state, i)
def manual_seed(seed: int) -> None:
r"""Sets the seed for generating random numbers for the current GPU.
It's safe to call this function if VACC is not available; in that
case, it is silently ignored.
Args:
seed (int): The desired seed.
.. warning::
If you are working with a multi-device model, this function is insufficient
to get determinism. To seed all VACC devices, use :func:`manual_seed_all`.
"""
seed = int(seed)
def cb():
idx = current_device()
default_generator = torch.vacc.default_generators[idx]
default_generator.manual_seed(seed)
_lazy_call(cb, seed=True)
def manual_seed_all(seed: int) -> None:
r"""Sets the seed for generating random numbers on all GPUs.
It's safe to call this function if VACC is not available; in that
case, it is silently ignored.
Args:
seed (int): The desired seed.
"""
seed = int(seed)
def cb():
for i in range(device_count()):
default_generator = torch.vacc.default_generators[i]
default_generator.manual_seed(seed)
_lazy_call(cb, seed_all=True)
def seed() -> None:
r"""Sets the seed for generating random numbers to a random number for the current GPU.
It's safe to call this function if VACC is not available; in that
case, it is silently ignored.
.. warning::
If you are working with a multi-device model, this function will only initialize
the seed on one VACC device. To initialize all devices, use :func:`seed_all`.
"""
def cb():
idx = current_device()
default_generator = torch.vacc.default_generators[idx]
default_generator.seed()
_lazy_call(cb)
def seed_all() -> None:
r"""Sets the seed for generating random numbers to a random number on all GPUs.
It's safe to call this function if VACC is not available; in that
case, it is silently ignored.
"""
def cb():
random_seed = 0
seeded = False
for i in range(device_count()):
default_generator = torch.vacc.default_generators[i]
if not seeded:
default_generator.seed()
random_seed = default_generator.initial_seed()
seeded = True
else:
default_generator.manual_seed(random_seed)
_lazy_call(cb)
def initial_seed() -> int:
r"""Returns the current random seed of the current GPU.
.. warning::
This function eagerly initializes VACC.
"""
_lazy_init()
idx = current_device()
default_generator = torch.vacc.default_generators[idx]
return default_generator.initial_seed()
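# A minimal reproducibility sketch (assumes a reachable VACC device, and that random
# ops on it consume the default generators managed above):
def _example_rng_roundtrip():
    manual_seed(1234)
    state = get_rng_state()
    first = torch.randn(4, device="vacc")
    set_rng_state(state)
    second = torch.randn(4, device="vacc")
    return torch.equal(first, second)   # expected: True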

327
torch_vacc/vacc/streams.py Normal file
View File

@@ -0,0 +1,327 @@
import ctypes
from typing import Any, Optional
import torch
from packaging import version
from torch._utils import _get_device_index
try:
from torch._streambase import _StreamBase, _EventBase
except ImportError:
# torch <= 2.1
_StreamBase = _EventBase = object
import torch_vacc
from torch_vacc._vacc_libs import _torch_vacc
from ._device import device
from .lazy_initialize import _lazy_init
# remove torch version arch-suffix(i.e. +cpu)
torch_version = torch.__version__.split('+')[0]
class _StreamCommon:
"""Wrapper around a VACC stream.
A VACC stream is a linear sequence of execution that belongs to a specific
device, independent from other streams.
Args:
device(torch.device or int, optional): a device on which to allocate
the stream. If :attr:`device` is ``None`` (default) or a negative
integer, this will use the current device.
priority(int, optional): priority of the stream. Can be either
-1 (high priority) or 0 (low priority). By default, streams have
priority 0.
"""
def __new__(cls, device=None, priority=0, **kwargs):
if device is None or ("stream_id" in kwargs and "device_index" in kwargs):
return super(Stream, cls).__new__(cls, priority=priority, **kwargs)
else:
with torch_vacc.vacc.device(device):
return super(Stream, cls).__new__(cls, priority=priority, **kwargs)
def wait_event(self, event):
event.wait(self)
def record_event(self, event=None):
"""Records an event.
Args:
event (torch_vacc.Event, optional): event to record. If not given, a new one
will be allocated.
Returns:
Recorded event.
"""
if event is None:
event = Event()
event.record(self)
return event
def wait_stream(self, stream):
"""Synchronizes with another stream.
All future work submitted to this stream will wait until all kernels
submitted to a given stream at the time of call complete.
Args:
stream (Stream): a stream to synchronize.
"""
self.wait_event(stream.record_event())
def query(self):
return super().query()
def synchronize(self):
super().synchronize()
@property
def _as_parameter_(self):
return ctypes.c_void_p(self.vacc_stream)
def __eq__(self, o):
if isinstance(o, Stream):
return super().__eq__(o)
return False
def __hash__(self):
return hash((self.vacc_stream, self.device))
def __repr__(self):
return f"torch_vacc.vacc.Stream device={self.device} vacc_stream={self.vacc_stream:#x}"
if version.parse(torch_version) <= version.parse("2.1"):
# torch <= 2.1
class Stream(_torch_vacc._VACCStreamBase, _StreamCommon):
pass
elif version.parse(torch_version) < version.parse("2.6"):
# torch < 2.6
class Stream(_torch_vacc._VACCStreamBase, _StreamBase, _StreamCommon):
pass
else:
# torch >= 2.6
class Stream(_torch_vacc._VACCStreamBase, _StreamCommon):
pass
class _EventCommon:
"""Wrapper around a VACC event.
VACC events are synchronization markers that can be used to monitor the
device's progress, to accurately measure timing, and to synchronize VACC
streams.
The underlying VACC events are lazily initialized when the event is first
recorded or exported to another process. After creation, only streams on the
same device may record the event. However, streams on any device can wait on
the event.
Args:
enable_timing (bool, optional): indicates if the event should measure time
(default: ``False``)
blocking (bool, optional): if ``True``, :meth:`wait` will be blocking (default: ``False``)
"""
def __new__(cls, enable_timing=False, blocking=False):
return super(Event, cls).__new__(
cls,
calc_time=enable_timing,
blocking=blocking,
)
def record(self, stream=None):
"""Records the event in a given stream.
Uses ``torch_vacc.vacc.current_stream()`` if no stream is specified. The
stream's device must match the event's device."""
if stream is None:
stream = torch_vacc.vacc.current_stream()
super().record(stream)
def wait(self, stream=None):
"""Makes all future work submitted to the given stream wait for this
event.
Uses ``torch_vacc.vacc.current_stream()`` if no stream is specified.
.. note:: This is a wrapper around ``vaccrtStreamWaitEvent()``
"""
if stream is None:
stream = torch_vacc.vacc.current_stream()
super().wait(stream)
def query(self):
"""Checks if all work currently captured by event has completed.
Returns:
A boolean indicating if all work currently captured by event has
completed.
"""
return super().query()
def elapsed_time(self, end_event):
"""Returns the time elapsed in milliseconds after the event was
recorded and before the end_event was recorded.
"""
return super().elapsed_time(end_event)
def synchronize(self):
r"""Waits for the event to complete.
Waits until the completion of all work currently captured in this event.
This prevents the CPU thread from proceeding until the event completes.
.. note:: This is a wrapper around ``vaccEventSynchronize()``.
"""
super().synchronize()
@property
def _as_parameter_(self):
return ctypes.c_void_p(self.vacc_event)
def __repr__(self):
if self.vacc_event:
return f"<torch_vacc.vacc.Event {self._as_parameter_.value:#x}>"
else:
return "<torch_vacc.vacc.Event uninitialized>"
if version.parse(torch_version) <= version.parse("2.1"):
# torch <= 2.1
class Event(_torch_vacc._VACCEventBase, _EventCommon):
pass
elif version.parse(torch_version) < version.parse("2.6"):
# torch < 2.6
class Event(_torch_vacc._VACCEventBase, _EventBase, _EventCommon):
pass
else:
# torch >= 2.6
class Event(_torch_vacc._VACCEventBase, _EventCommon):
pass
class StreamContext:
r"""Context-manager that selects a given stream.
All VACC kernels queued within its context will be enqueued on a selected
stream.
Args:
stream (stream): selected stream. This manager is a no-op if it's
``None``.
.. note:: Streams are per-device.
"""
cur_stream: Optional["torch_vacc.vacc.Stream"]
def __init__(self, stream: Optional["torch_vacc.vacc.Stream"]):
self.stream = stream
self.idx = _get_device_index(None, True)
if not torch.jit.is_scripting():
if self.idx is None:
self.idx = -1
self.src_prev_stream = (
None
if not torch.jit.is_scripting()
else torch_vacc.vacc.default_stream(None)
)
self.dst_prev_stream = (
None
if not torch.jit.is_scripting()
else torch_vacc.vacc.default_stream(None)
)
def __enter__(self):
# Local cur_stream variable for type refinement
cur_stream = self.stream
# Return if stream is None or VACC device not available
if cur_stream is None or self.idx == -1:
return
self.src_prev_stream = torch_vacc.vacc.current_stream(None)
# If the stream is not on the current device, then
# set the current stream on the device
if self.src_prev_stream.device != cur_stream.device:
with device(cur_stream.device):
self.dst_prev_stream = torch_vacc.vacc.current_stream(cur_stream.device)
torch_vacc.vacc.set_stream(cur_stream)
def __exit__(self, type: Any, value: Any, traceback: Any):
# Local cur_stream variable for type refinement
cur_stream = self.stream
# If stream is None or no VACC device available, return
if cur_stream is None or self.idx == -1:
return
# Reset the stream on the original device
# and destination device
if self.src_prev_stream.device != cur_stream.device: # type: ignore[union-attr]
torch_vacc.vacc.set_stream(self.dst_prev_stream) # type: ignore[arg-type]
torch_vacc.vacc.set_stream(self.src_prev_stream) # type: ignore[arg-type]
def stream(stream: Optional["torch_vacc.vacc.Stream"]) -> StreamContext:
r"""Wrapper around the Context-manager StreamContext that
selects a given stream.
Arguments:
stream (Stream): selected stream. This manager is a no-op if it's
``None``.
"""
return StreamContext(stream)
def set_stream(stream: Stream):
r"""Sets the current stream.This is a wrapper API to set the stream.
Usage of this function is discouraged in favor of the ``stream``
context manager.
Args:
stream (Stream): selected stream. This function is a no-op
if this argument is ``None``.
"""
if stream is None:
return
_torch_vacc._vacc_setStream(
stream_id=stream.stream_id,
device_index=stream.device_index,
device_type=stream.device_type,
)
def current_stream(device=None) -> Stream:
r"""Returns the currently selected :class:`Stream` for a given device.
Args:
device (torch.device or int, optional): selected device. Returns
the currently selected :class:`Stream` for the current device, given
by :func:`~torch_vacc.vacc.current_device`, if :attr:`device` is ``None``
(default).
"""
_lazy_init()
streamdata = _torch_vacc._vacc_getCurrentStream(
_get_device_index(device, optional=True)
)
return Stream(
stream_id=streamdata[0], device_index=streamdata[1], device_type=streamdata[2]
)
def default_stream(device=None) -> Stream:
r"""Returns the default :class:`Stream` for a given device.
Args:
device (torch.device or int, optional): selected device. Returns
the default :class:`Stream` for the current device, given by
:func:`~torch_vacc.vacc.current_device`, if :attr:`device` is ``None``
(default).
"""
_lazy_init()
streamdata = _torch_vacc._vacc_getDefaultStream(
_get_device_index(device, optional=True)
)
return Stream(
stream_id=streamdata[0], device_index=streamdata[1], device_type=streamdata[2]
)
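# A minimal usage sketch (assumes a reachable VACC device): run work on a side stream
# and make the current stream wait for it before the result is reused elsewhere.
def _example_side_stream(x):
    side = Stream()
    with stream(side):
        y = x * 2                       # enqueued on `side`
    current_stream().wait_stream(side)  # current stream waits for the side stream
    return y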

2
torch_vacc/version.py Normal file
View File

@@ -0,0 +1,2 @@
__all__ = ['__version__']
__version__ = '1.3.3.777'

269
torch_vacc/vslog.cfg Normal file
View File

@@ -0,0 +1,269 @@
hot_update: true
- channel: 0
sync: sync
priority: error
category: 0
category_extend: 0
-device: 0
disable: false
out_type: file
priority: error
category: 0
category_extend: 0
path: "./log/"
file: "$PNAME-$YEAR_$MON_$DAY_$HOUR_$MIN_$SEC_$PID"
rollback: 5
limit_size: 50 m #only support M byte
-device: 1
disable: false
out_type: screen
category: 0
category_extend: 0
- channel: 1
sync: sync
priority: error
category: 0
category_extend: 0
-device: 0
disable: false
out_type: file
priority: error
category: 0
category_extend: 0
path: "./log/"
file: "vacm-$YEAR_$MON_$DAY_$HOUR_$MIN_$SEC_$PID"
rollback: 5
limit_size: 50 m #only support M byte
-device: 1
disable: false
out_type: screen
category: 0
category_extend: 0
- channel: 2
sync: sync
priority: error
category: 0
category_extend: 0
-device: 0
disable: false
out_type: file
priority: error
category: 0
category_extend: 0
path: "./log/"
file: "vace-$YEAR_$MON_$DAY_$HOUR_$MIN_$SEC_$PID"
rollback: 5
limit_size: 50 m #only support M byte
-device: 1
disable: false
out_type: screen
category: 0
category_extend: 0
- channel: 3
sync: sync
priority: error
category: 0
category_extend: 0
-device: 0
disable: false
out_type: file
priority: error
category: 0
category_extend: 0
path: "./log/"
file: "vacl-$YEAR_$MON_$DAY_$HOUR_$MIN_$SEC_$PID"
rollback: 5
limit_size: 50 m #only support M byte
-device: 1
disable: false
out_type: screen
category: 0
category_extend: 0
- channel: 4
sync: sync
priority: error
category: 0
category_extend: 0
-device: 0
disable: false
out_type: file
priority: error
category: 0
category_extend: 0
path: "./log/"
file: "vame-$YEAR_$MON_$DAY_$HOUR_$MIN_$SEC_$PID"
rollback: 5
limit_size: 50 m #only support M byte
-device: 1
disable: false
out_type: screen
category: 0
category_extend: 0
- channel: 5
sync: sync
priority: error
category: 0
category_extend: 0
-device: 0
disable: false
out_type: file
priority: error
category: 0
category_extend: 0
path: "./log/"
file: "vaml-$YEAR_$MON_$DAY_$HOUR_$MIN_$SEC_$PID"
rollback: 5
limit_size: 50 m #only support M byte
-device: 1
disable: false
out_type: screen
category: 0
category_extend: 0
- channel: 6
sync: sync
priority: error
category: 0
category_extend: 0
append_cr: true
-device: 0
disable: false
out_type: file
priority: error
category: 0
category_extend: 0
path: "./log/"
file: "rt-$YEAR_$MON_$DAY_$HOUR_$MIN_$SEC_$PID"
rollback: 5
limit_size: 50 m #only support M byte
-device: 1
disable: true
out_type: screen
category: 0
category_extend: 0
- channel: 7
sync: sync
priority: error
category: 0
category_extend: 0
-device: 0
disable: false
out_type: file
priority: error
category: 0
category_extend: 0
path: "./log/"
file: "nn-$YEAR_$MON_$DAY_$HOUR_$MIN_$SEC_$PID"
rollback: 5
limit_size: 50 m #only support M byte
-device: 1
disable: true
out_type: screen
category: 0
category_extend: 0
- channel: 8
sync: sync
priority: error
category: 0
category_extend: 0
-device: 0
disable: false
out_type: file
priority: error
category: 0
category_extend: 0
path: "./log/"
file: "tm-$YEAR_$MON_$DAY_$HOUR_$MIN_$SEC_$PID"
rollback: 5
limit_size: 50 m #only support M byte
-device: 1
disable: false
out_type: screen
category: 0
category_extend: 0
- channel: 9
sync: sync
priority: error
category: 0
category_extend: 0
append_cr: true
no_prefix: true
-device: 0
disable: false
out_type: file
priority: error
category: 0
category_extend: 0
path: "./log/"
file: "md-$YEAR_$MON_$DAY_$HOUR_$MIN_$SEC_$PID"
rollback: 5
limit_size: 50 m #only support M byte
-device: 1
disable: true
out_type: screen
category: 0
category_extend: 0
- channel: 10
sync: sync
priority: error
category: 0
category_extend: 0
append_cr: false
no_prefix: true
-device: 0
disable: false
out_type: file
priority: error
category: 0
category_extend: 0
path: "./log/"
file: "rs-$YEAR_$MON_$DAY_$HOUR_$MIN_$SEC_$PID"
rollback: 5
limit_size: 50 m #only support M byte
-device: 1
disable: false
out_type: screen
category: 0
category_extend: 0
- channel: 11
sync: sync
priority: error
category: 0
category_extend: 0
append_cr: false
no_prefix: true
-device: 0
disable: false
out_type: file
priority: error
category: 0
category_extend: 0
path: "./log/"
file: "vaapi-$YEAR_$MON_$DAY_$HOUR_$MIN_$SEC_$PID"
rollback: 5
limit_size: 50 m #only support M byte
-device: 1
disable: false
out_type: screen
category: 0
category_extend: 0
- channel: 12
sync: sync
priority: error
category: 0
category_extend: 0
-device: 0
disable: false
out_type: file
priority: error
category: 0
category_extend: 0
path: "./log/"
file: "vccl-$YEAR_$MON_$DAY_$HOUR_$MIN_$SEC_$PID"
rollback: 5
limit_size: 50 m #only support M byte
-device: 1
disable: true
out_type: screen
category: 0
category_extend: 0

31
vacc_tools/__init__.py Normal file
View File

@@ -0,0 +1,31 @@
from functools import partial
from datetime import datetime
from typing import Union, Tuple, Type
import torch
import torch.distributed
_module_time = {}
def print_module_time(
model: torch.nn.Module, module: Union[Tuple[Type[torch.nn.Module], ...], Type[torch.nn.Module]]
):
def now_as_us():
return int(datetime.now().timestamp() * 1e6) # in us
def _pre_forward(suffix, m, inputs):
name = f"{type(m).__name__}.{suffix}"
_module_time[name] = now_as_us()
def _post_forward(suffix, m, inputs, outputs):
name = f"{type(m).__name__}.{suffix}"
start_time = _module_time.pop(name)
print(f"{name}: {now_as_us() - start_time} us")
for name, m in model.named_modules():
if isinstance(m, module):
m.register_forward_pre_hook(partial(_pre_forward, "forward"))
m.register_forward_hook(partial(_post_forward, "forward"))
m.register_full_backward_pre_hook(partial(_pre_forward, "backward"))
m.register_full_backward_hook(partial(_post_forward, "backward"))
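# A minimal CPU-runnable sketch for the hook-based timer above; the `module` argument
# is matched with isinstance, so a single class or a tuple of classes both work.
def _example_print_module_time():
    model = torch.nn.Sequential(torch.nn.Linear(8, 16), torch.nn.ReLU(), torch.nn.Linear(16, 4))
    print_module_time(model, (torch.nn.Linear, torch.nn.ReLU))
    loss = model(torch.randn(2, 8)).sum()
    loss.backward()                     # backward hooks print timings as well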

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -0,0 +1,214 @@
"""Generating tracing json files from log files.
Usage:
python -m vacc_tools.generate_trace --log-dir <directory of log files> --out-file-prefix <prefix of output file>
"""
import argparse
import json
import os
import re
import numpy as np
import tabulate
from glob import glob
from collections import defaultdict
from multiprocessing import Pool
def run_stats_on_traces(timelines):
op_cat_list = ["ODSP", "DLC", "VCCL", "CPU", "CPU_OP"]
op_stats = {op: {} for op in op_cat_list}
for line in timelines:
if '"E"' not in line: # optim 3, skip everything if not `"E"`
continue
# optim 2: using `[:-2]` instead of replace()
line = line[:-2] # remove ',\n'
try:
values = json.loads(line)
except json.decoder.JSONDecodeError:
# some log may not ends properly, just skip it
continue
if values["ph"] == "E" and values["cat"] in op_cat_list:
cat = values["cat"]
if values["name"] not in op_stats[cat]:
op_stats[cat][values["name"]] = []
if "dur" in values["args"]:
# optim 1: using `[:-2]` instead of replace()
op_stats[cat][values["name"]].append(
int(values["args"]["dur"][:-2]) # strip `us`
)
elif "values(us)" in values["args"]:
op_stats[cat][values["name"]].append(values["args"]["value(us)"])
op_tables = {}
for cat, stats in op_stats.items():
# optim 4: aggregate per-op durations into a summary table
table = []
for name, dur in stats.items():
dur = np.array(dur)
t = [
name,
np.min(dur),
np.max(dur),
np.sum(dur),
np.mean(dur),
np.percentile(dur, 90),
len(dur),
]
table.append(t)
table = sorted(table, key=lambda x: x[-1], reverse=True)
op_tables[cat] = tabulate.tabulate(
table,
headers=["op", "min", "max", "sum", "avg", "p90", "count"],
tablefmt="plain",
)
if cat in ["VCCL", "ODSP", "DLC"]:
op_tables["VACC-ALL"] = op_tables.get("VACC-ALL", []) + [
t + [cat] for t in table
]
total = sum([x[3] for x in op_tables["VACC-ALL"]])
op_tables["VACC-ALL"] = [t + [t[3] / total * 100] for t in op_tables["VACC-ALL"]]
op_tables["VACC-ALL"] = tabulate.tabulate(
sorted(op_tables["VACC-ALL"], key=lambda x: x[-1], reverse=True),
headers=["op", "min", "max", "sum", "avg", "p90", "count", "cat", "percent(%)"],
tablefmt="plain",
)
return op_tables
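# Shape of the timeline entries consumed above (one JSON object per line, trailing
# comma included); only phase "E" events from the listed categories are aggregated,
# and the duration comes from args["dur"] (with a trailing "us") or the per-op value
# field. Other chrome-trace fields (ts, pid, tid, ...) are carried through untouched.
# The op name below is illustrative:
#
#   {"ph":"E","cat":"DLC","name":"matmul","tid":11,"args":{"dur":"153us"}},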
def get_rank_info(files):
# using pattern rank-<rank> in file name to get rank
for fpath in files:
rank = re.findall(r"rank-(\d+)", fpath)
if rank:
return int(rank[0])
return 0
def extract_traces(arg):
files, target_file_path, group_name, trace_token = arg
entries = [
(0, "scheduler"),
(1, "megatron"),
(2, "deepspeed"),
(3, "nn.Module"),
(10, "vacc-odsp"),
(11, "vacc-dlc"),
(12, "vacc-vccl"),
(13, "vacc-cpu"),
(14, "vacc-fallback"),
(15, "vacc-ddr"),
(20, "lib-vccl"),
]
with open(target_file_path, "w", encoding="utf-8") as trace_file:
trace_file.write("[")
for tid, thread_name in entries:
line = f'{{"cat":"__metadata","pid":{group_name},"tid":{tid},"ts":0,"ph":"M","name":"thread_name","args":{{"name":"{thread_name}"}}}},\n'
trace_file.write(line)
timelines = []
for fpath in files:
with open(fpath, "r", encoding="utf-8") as file:
# timelines += [line.split(trace_token)[1] for line in file if trace_token in line]
for line in file:
if trace_token in line:
# found the trace token; keep everything after it
timelines.append(line.split(trace_token)[1])
try:
json.loads(timelines[-1][:-2]) # remove ',\n'
except json.decoder.JSONDecodeError:
# some log may not ends properly, just skip it
# chrome:://tracing stops reading following lines if an error encountered
# so must remove lines with error
timelines.pop()
for line in timelines[:-1]:
trace_file.write(line)
# fixing JSON format error by removing last comma in a list
trace_file.write(timelines[-1].replace(",\n", "\n"))
trace_file.write("]")
op_stats = run_stats_on_traces(timelines)
with open(
target_file_path.replace(".json", ".txt"), "w", encoding="utf-8"
) as op_stats_file:
for cat, tables in op_stats.items():
op_stats_file.write(f"{cat}".center(80, "-") + "\n")
op_stats_file.write(tables + "\n\n")
def merge_schedule(out_file_prefix):
scheduler_data = []
for file in glob(f"{out_file_prefix}*.json"):
if file.endswith("schedule.json"):
continue
assert "rank" in file
rank = file.split("rank_")[-1].split("_")[0]
pid = None
with open(file, "r", encoding="utf-8") as f:
for line in f:
# set all schedule's pid to 0 and set all schedule's tid to rank id
if '"tid":0,' in line and "__metadata" not in line:
if pid is None:
pid = line.split('"pid":')[1].split(",")[0]
line = line.replace(f'"pid":{pid}', f'"pid":0')
line = line.replace('"tid":0,', f'"tid":{rank},')
scheduler_data.append(line)
out_file = f"{out_file_prefix}schedule.json"
with open(out_file, "w", encoding="utf-8") as f:
f.write("[\n")
f.writelines(scheduler_data[:-1])
f.write(scheduler_data[-1].replace(",\n", "\n"))
f.write("]\n")
def scan_and_generate_trace(args, trace_token):
grouped_files = defaultdict(list)
for root, dirs, files in os.walk(args.log_dir):
for filename in files:
fpath = os.path.join(root, filename)
file_size = os.path.getsize(fpath)
if file_size != 0:
group_name = filename.rsplit("_", 1)[1].split(".")[0]
grouped_files[group_name].append(fpath)
pool_args = []
for group_name, files in grouped_files.items():
rank = get_rank_info(files)
out_file = f"{args.out_file_prefix}rank_{rank}_{group_name}.json"
pool_args.append((files, out_file, group_name, trace_token))
with Pool(len(grouped_files)) as p:
p.map(extract_traces, pool_args)
if args.merge_schedule:
merge_schedule(args.out_file_prefix)
if __name__ == "__main__":
TRACE_TOKEN = "LOG_TRACE:"
current_file_path = os.path.abspath(__file__)
parent_directory = os.path.dirname(os.path.dirname(current_file_path))
find_directory = os.path.join(parent_directory, "log")
parser = argparse.ArgumentParser()
parser.add_argument(
"--log-dir", default=find_directory, type=str, help="directory of log files"
)
parser.add_argument("--out-file-prefix", default="timeline_", type=str)
parser.add_argument("--merge-schedule", action="store_true")
args = parser.parse_args()
scan_and_generate_trace(args, TRACE_TOKEN)
print("Scan and trace generation done!")

View File

@@ -0,0 +1,151 @@
from contextlib import contextmanager
from dataclasses import fields
from typing import Dict, Tuple, List, Optional
import torch
NUM_BYTES_IN_MB = 1024**2
NUM_BYTES_IN_GB = 1024**3
class MemoryAnalyzer:
def __init__(
self, model: torch.nn.Module, optimizer: Optional[torch.optim.Optimizer] = None
):
"""This memory usage analyzer will be mostly acurate only if you initialize
at the beginning and insert `get_memory_usage_in_gb` at the end of your
forward pass.
NOTE: It can have a negative impact if not used properly, as it stores
activations of every nn.Module's forward function and relies on the user to
reset it every time the forward pass ends.
Limitations:
1. does not work with customized operators
2. does not work with functional operators
3. it approximates activations as nn.Module.forward's outputs (if they are
inside the graph and require gradients), so it may not be exactly accurate.
"""
self.model = model
self.optimizer = optimizer
self.activ_addrs = set()
self.activ_memory = 0
@staticmethod
def _is_activation(x):
return torch.is_tensor(x) and x.requires_grad and x.device.type != "cpu"
def _get_weight_grads_addrs(self):
weights = set([p.untyped_storage().data_ptr() for p in self.model.parameters()])
grads = set(
[
p.grad.untyped_storage().data_ptr()
for p in self.model.parameters()
if p.grad is not None
]
)
return weights.union(grads)
def pack_hook(self):
def _pack_hook(x):
if self._is_activation(x):
weight_grads = self._get_weight_grads_addrs()
# NOTE: storage is more accurate than using x.nelement() * x.element_size()
data_ptr = x.untyped_storage().data_ptr()
if data_ptr not in weight_grads and data_ptr not in self.activ_addrs:
self.activ_addrs.add(data_ptr)
self.activ_memory += x.untyped_storage().size()
return x
return _pack_hook
def unpack_hook(self):
def _unpack_hook(x):
if self._is_activation(x):
weight_grads = self._get_weight_grads_addrs()
data_ptr = x.untyped_storage().data_ptr()
if data_ptr not in weight_grads and data_ptr in self.activ_addrs:
self.activ_addrs.remove(data_ptr)
self.activ_memory -= x.untyped_storage().size()
return x
return _unpack_hook
@contextmanager
def record_activation(self):
with torch.autograd.graph.saved_tensors_hooks(
self.pack_hook(), self.unpack_hook()
):
yield
@staticmethod
def get_weight_memory(model: torch.nn.Module):
weights = [
p.nelement() * p.element_size()
for p in model.parameters()
if p.device != "cpu"
]
return sum(weights)
@staticmethod
def get_gradient_memory(model: torch.nn.Module):
grads = [
p.grad.nelement() * p.grad.element_size()
for p in model.parameters()
if p.grad is not None and p.grad.device != "cpu"
]
return sum(grads)
def _sum_activation_memory(self):
return self.activ_memory
def get_optimizer_state_memory(self):
if isinstance(self.optimizer, torch.optim.AdamW):
params = sum(
[
p.nelement() * p.element_size()
for pg in self.optimizer.param_groups
for p in pg["params"]
if torch.is_tensor(p) and p.device != "cpu"
]
)
for state in self.optimizer.state.values():
params += sum(
[
v.nelement() * v.element_size()
for k, v in state.items()
if torch.is_tensor(v) and v.device != "cpu"
]
)
return params
return 0
def _get_memory_usage(self) -> Tuple[int, int, int, int]:
return (
self.get_weight_memory(self.model),
self.get_gradient_memory(self.model),
self._sum_activation_memory(),
self.get_optimizer_state_memory(),
)
def get_memory_usage_in_gb(self) -> str:
w, g, a, opt = self._get_memory_usage()
return (
f"Total: {(w + g + a + opt) / NUM_BYTES_IN_GB:.3f} GB, "
f"weight: {w / NUM_BYTES_IN_GB:.3f} GB, "
f"gradient: {g / NUM_BYTES_IN_GB:.3f} GB, "
f"activation: {a / NUM_BYTES_IN_GB:.3f} GB, "
f"optimizer states: {opt / NUM_BYTES_IN_GB:.3f} GB"
)
def get_memory_usage_in_mb(self) -> str:
w, g, a, opt = self._get_memory_usage()
return (
f"Total: {(w + g + a + opt) / NUM_BYTES_IN_MB:.2f} MB, "
f"weight: {w / NUM_BYTES_IN_MB:.2f} MB, "
f"gradient: {g / NUM_BYTES_IN_MB:.2f} MB, "
f"activation: {a / NUM_BYTES_IN_MB:.2f} MB, "
f"optimizer states: {opt / NUM_BYTES_IN_MB:.2f} MB"
)
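# Hedged usage sketch (not part of the original file): how this analyzer is meant
# to be driven, assuming `model`, `optimizer`, and `batch` already exist on the
# device; these names are placeholders, not APIs defined above.
#
#   analyzer = MemoryAnalyzer(model, optimizer)
#   with analyzer.record_activation():
#       loss = model(batch).sum()
#       print(analyzer.get_memory_usage_in_gb())
#   loss.backward()
#   optimizer.step()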

View File

@@ -0,0 +1,65 @@
import argparse
import os
from collections import defaultdict
from multiprocessing import Pool
log_tag = "LOG_TRACE:"
tid_names = [
(0, "module"),
(1, "megatron"),
(2, "deepspeed"),
(10, "vacc-odsp"),
(11, "vacc-dlc"),
(12, "vacc-vccl"),
(13, "vacc-cpu"),
(14, "vacc-cpu_fallback"),
(15, "vacc-ddr"),
(20, "lib-vccl"),
]
def parse_files_of_process(args):
pid, in_files = args
out_file = "trace_" + pid + ".json"
with open(out_file, "w", encoding="utf-8") as new_file:
metadata_lines = [
f'{{"name": "thread_name","ph": "M","pid": {pid},"tid": {tid},"args": {{"name": "{name}"}}}},'
for tid, name in tid_names
]
new_file.write("[\n")
new_file.write("\n".join(metadata_lines))
new_file.write("\n")
for file_path in in_files:
with open(file_path, "r", encoding="utf-8") as file:
for line in file:
if log_tag in line:
new_line = line.split(log_tag, 1)[1].strip()
new_file.write(new_line + "\n")
new_file.write("]")
def parse_directory(directory):
pro_files = defaultdict(list)
for dirpath, dirnames, filenames in os.walk(directory):
for filename in filenames:
file_path = os.path.join(dirpath, filename)
if filename.startswith("vacc") and os.path.getsize(file_path) != 0:
pid = filename.rsplit("_", 1)[1].split(".")[0]
pro_files[pid].append(file_path)
args = []
for pid, in_files in pro_files.items():
args.append((pid, in_files))
with Pool() as p:
p.map(parse_files_of_process, args)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="parse vacc log files and generate trace files"
)
parser.add_argument("directory", type=str, help="log directory to parse")
args = parser.parse_args()
parse_directory(args.directory)

329
vacc_tools/trace_logger.py Normal file
View File

@@ -0,0 +1,329 @@
"""
This module provides mechanisms for tracing the execution of torch modules and
functions, and writes the trace into a json file.
The user needs to set the environment variable `LOG_TRAIN_SCHEDULE=1` to enable
tracing. If it is not set, no trace will be applied.
Inside your module, create your module's tracer functions by using `get_trace_api`.
You will get four functions:
* `@trace_time(name)`: decorator to trace the execution of a function.
```python
@trace_time("my_func")
def my_func(x):
...
```
* `@trace_autograd_function()`: decorator to trace the execution of forward
and backward of a user defined `torch.autograd.Function` operator.
```python
@trace_autograd_function()
class MyAutogradFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
...
@staticmethod
def backward(ctx, grad_output):
...
```
* `register_module_trace()`: function to register traces on a model (`nn.Module`);
it applies traces recursively to a torch model by enumerating all nn.Module
instances and registering tracers on their forward and backward functions. Applying
it only to the top-level nn.Module is recommended.
```python
model = Model()
register_module_trace(model)
```
"""
import os
import json
from contextlib import contextmanager
from dataclasses import dataclass, asdict
from datetime import datetime
from functools import partial
import torch
import torch.distributed
MODULE_TID = {"megatron": 1, "deepspeed": 2, "nn.Module": 3, "ram": 100}
# pylint: disable=missing-docstring
@dataclass
class TraceEntry:
name: str
cat: str
pid: int
tid: int
ts: int
ph: str
args: str = None
def to_json_str(self):
d = asdict(self)
if self.args is None:
d.pop("args")
return json.dumps(d, separators=(",", ": "))
class LogFiles:
def __init__(self) -> None:
self.loggers = {}
def get(self, file_prefix, rank, pid):
os.makedirs("log", exist_ok=True)
fpath = f"log/{file_prefix}-rank-{rank}_{pid}.txt"
if not fpath in self.loggers:
self.loggers[fpath] = open(fpath, "w")
return self.loggers[fpath]
def close(self):
for f in self.loggers.values():
f.close()
def __del__(self):
self.close()
def trace_logger_enabled() -> bool:
return (
"LOG_TRAIN_SCHEDULE" in os.environ and os.environ["LOG_TRAIN_SCHEDULE"] == "1"
)
class TraceLogger:
_log_files = LogFiles()
def __init__(self, category, tid=None, file_prefix=None) -> None:
self.enabled = trace_logger_enabled()
if self.enabled:
self.pid = os.getpid()
self.logger = None
self.cat = category
self._traces = {}
self.global_rank = 0
if tid is None:
self.tid = MODULE_TID.get(category, 1000)
else:
self.tid = tid
self.file_prefix = file_prefix if file_prefix is not None else self.cat
self.registered_modules = []
def _create_logger(self) -> None:
# delay creating logger file until first log call,
# since torch.distributed may not be ready yet
if torch.distributed.is_initialized():
self.global_rank = torch.distributed.get_rank()
self.logger = TraceLogger._log_files.get(
self.file_prefix, self.global_rank, self.pid
)
def begin_trace(self, name, memory=False) -> None:
if not self.enabled:
return
if self.logger is None:
self._create_logger()
assert self.logger is not None
name = f"{name}" # convert it to str to ensure json serializable
start_time = int(datetime.now().timestamp() * 1e6) # in us
trace = TraceEntry(name, self.cat, self.pid, self.tid, start_time, "B")
mem_trace = self._get_memory(start_time) if memory else None
if name not in self._traces:
self._traces[name] = [(trace, mem_trace)]
else: # in case call to the function is nested
self._traces[name].append((trace, mem_trace))
def end_trace(self, name, flush=False, memory=False) -> None:
if not self.enabled:
return
name = f"{name}" # convert it to str to ensure json serializable
assert self.logger is not None, "begin_trace should be called before end_trace"
assert name in self._traces, "begin_trace should be called before end_trace"
start_trace, start_mem = self._traces[name].pop()
if start_mem is not None:
self.logger.write(f"LOG_TRACE:{start_mem.to_json_str()},\n")
self.logger.write(f"LOG_TRACE:{start_trace.to_json_str()},\n")
end_time = int(datetime.now().timestamp() * 1e6) # in us
args = {"value(us)": end_time - start_trace.ts}
trace = TraceEntry(name, self.cat, self.pid, self.tid, end_time, "E", args)
self.logger.write(f"LOG_TRACE:{trace.to_json_str()},\n")
if memory:
mem_trace = self._get_memory(end_time)
self.logger.write(f"LOG_TRACE:{mem_trace.to_json_str()},\n")
if flush:
self.flush()
def flush(self) -> None:
if self.logger is not None:
self.logger.flush()
def _get_memory(self, timestamp):
args = {"value": torch.vacc.memory_allocated(self.global_rank)}
mem_trace = TraceEntry(
"memory", "memory", self.pid, MODULE_TID["ram"], timestamp, "C", args
)
return mem_trace
@contextmanager
def _trace_time(name, logger_inst, memory=False, flush=False):
if not logger_inst.enabled:
yield
return
logger_inst.begin_trace(name)
yield
logger_inst.end_trace(name, flush=flush)
SKIPPED_MODULES = []
def _register_module_trace(
module: torch.nn.Module, logger_inst, flush: bool = True, forward_only=False
):
if not logger_inst.enabled:
return
if not isinstance(module, torch.nn.Module):
return
def _register(m):
module_name = f"{type(m).__name__}"
if module_name == "WrapName":
module_name = f"{type(m.forward_func.__self__).__name__}"
if module_name in SKIPPED_MODULES:
return
forward_name = module_name + ".forward"
m.register_forward_pre_hook(
lambda m, inp: logger_inst.begin_trace(forward_name, memory=True)
)
m.register_forward_hook(
lambda m, inp, out: logger_inst.end_trace(forward_name, memory=True)
)
if not forward_only:
backward_name = module_name + ".backward"
m.register_full_backward_pre_hook(
lambda m, grad_out: logger_inst.begin_trace(backward_name, memory=True)
)
m.register_full_backward_hook(
lambda m, grad_in, grad_out: logger_inst.end_trace(
backward_name, memory=True, flush=flush
)
)
for m in module.modules():
if m in logger_inst.registered_modules:
print(
f"module `{m}` already registered, skip applying trace on same module multiple times."
)
continue
_register(m)
def _trace_autograd_function(logger_inst):
def decorator(cls):
if not issubclass(cls, torch.autograd.Function):
return cls
def _apply(name, method):
def wrapper(*args, **kwargs):
with _trace_time(name, logger_inst=logger_inst, memory=True):
result = method(*args, **kwargs)
return result
return wrapper
for attr in ["forward", "backward"]:
setattr(cls, attr, _apply(cls.__name__ + "." + attr, getattr(cls, attr)))
return cls
return decorator
def _register_optimizer_trace(
optimizer: torch.optim.Optimizer, logger_inst, flush: bool = True
):
if not logger_inst.enabled:
return
trace_name = f"{type(optimizer).__name__}.step"
if isinstance(optimizer, torch.optim.Optimizer):
optimizer.register_step_pre_hook(
lambda m, *args, **kwargs: logger_inst.begin_trace(trace_name, memory=True)
)
optimizer.register_step_post_hook(
lambda m, *args, **kwargs: logger_inst.end_trace(
trace_name, memory=True, flush=flush
)
)
elif hasattr(optimizer, "step") and callable(optimizer.step):
# customized optimizers do not have step hooks
original_step = optimizer.step
def traced_step(*args, **kwargs):
logger_inst.begin_trace(trace_name, memory=True)
result = original_step(*args, **kwargs)
logger_inst.end_trace(trace_name, memory=True, flush=flush)
return result
# Replace the step method with the new function
optimizer.step = traced_step
else:
# unknown optimizer or wrong instance pass to this function.
pass
if hasattr(optimizer, "reduce_gradients") and callable(optimizer.reduce_gradients):
trace_name = f"{type(optimizer).__name__}.reduce_gradients"
original_reduce = optimizer.reduce_gradients
def traced_reduce(*args, **kwargs):
logger_inst.begin_trace(trace_name, memory=True)
result = original_reduce(*args, **kwargs)
logger_inst.end_trace(trace_name, memory=True, flush=flush)
return result
# Replace the reduce_gradients method with the new function
optimizer.reduce_gradients = traced_reduce
def get_trace_api(name="nn.Module"):
"""generate module execution trace APIs for a given module name
Args:
name (str): module name
Returns:
tuple: (trace_time, register_module_trace, trace_autograd_function,
register_optimizer_trace)
Usage of these functions is described in the docstring of this module
"""
_trace_logger = TraceLogger(name)
return (
partial(_trace_time, logger_inst=_trace_logger),
partial(_register_module_trace, logger_inst=_trace_logger),
partial(_trace_autograd_function, logger_inst=_trace_logger),
partial(_register_optimizer_trace, logger_inst=_trace_logger),
)
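# Hedged usage sketch (not part of the original file): wiring the four tracer
# functions returned by `get_trace_api` into a training script. `MyModel`,
# `my_optimizer`, and `train_step` are placeholder names.
#
#   trace_time, register_module_trace, trace_autograd_function, \
#       register_optimizer_trace = get_trace_api("nn.Module")
#
#   model = MyModel()
#   register_module_trace(model)
#   register_optimizer_trace(my_optimizer)
#
#   @trace_time("train_step")
#   def train_step(batch):
#       ...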

BIN
vllm/_C.abi3.so Normal file

Binary file not shown.

102
vllm/__init__.py Normal file
View File

@@ -0,0 +1,102 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""vLLM: a high-throughput and memory-efficient inference engine for LLMs"""
# The version.py should be independent library, and we always import the
# version library first. Such assumption is critical for some customization.
from .version import __version__, __version_tuple__ # isort:skip
import typing
# The environment variables override should be imported before any other
# modules to ensure that the environment variables are set before any
# other modules are imported.
import vllm.env_override # noqa: F401
MODULE_ATTRS = {
"bc_linter_skip": "._bc_linter:bc_linter_skip",
"bc_linter_include": "._bc_linter:bc_linter_include",
"AsyncEngineArgs": ".engine.arg_utils:AsyncEngineArgs",
"EngineArgs": ".engine.arg_utils:EngineArgs",
"AsyncLLMEngine": ".engine.async_llm_engine:AsyncLLMEngine",
"LLMEngine": ".engine.llm_engine:LLMEngine",
"LLM": ".entrypoints.llm:LLM",
"initialize_ray_cluster": ".executor.ray_utils:initialize_ray_cluster",
"PromptType": ".inputs:PromptType",
"TextPrompt": ".inputs:TextPrompt",
"TokensPrompt": ".inputs:TokensPrompt",
"ModelRegistry": ".model_executor.models:ModelRegistry",
"SamplingParams": ".sampling_params:SamplingParams",
"PoolingParams": ".pooling_params:PoolingParams",
"ClassificationOutput": ".outputs:ClassificationOutput",
"ClassificationRequestOutput": ".outputs:ClassificationRequestOutput",
"CompletionOutput": ".outputs:CompletionOutput",
"EmbeddingOutput": ".outputs:EmbeddingOutput",
"EmbeddingRequestOutput": ".outputs:EmbeddingRequestOutput",
"PoolingOutput": ".outputs:PoolingOutput",
"PoolingRequestOutput": ".outputs:PoolingRequestOutput",
"RequestOutput": ".outputs:RequestOutput",
"ScoringOutput": ".outputs:ScoringOutput",
"ScoringRequestOutput": ".outputs:ScoringRequestOutput",
}
if typing.TYPE_CHECKING:
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.llm_engine import LLMEngine
from vllm.entrypoints.llm import LLM
from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import PromptType, TextPrompt, TokensPrompt
from vllm.model_executor.models import ModelRegistry
from vllm.outputs import (ClassificationOutput,
ClassificationRequestOutput, CompletionOutput,
EmbeddingOutput, EmbeddingRequestOutput,
PoolingOutput, PoolingRequestOutput,
RequestOutput, ScoringOutput,
ScoringRequestOutput)
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from ._bc_linter import bc_linter_include, bc_linter_skip
else:
def __getattr__(name: str) -> typing.Any:
from importlib import import_module
if name in MODULE_ATTRS:
module_name, attr_name = MODULE_ATTRS[name].split(":")
module = import_module(module_name, __package__)
return getattr(module, attr_name)
else:
raise AttributeError(
f'module {__package__} has no attribute {name}')
__all__ = [
"__version__",
"bc_linter_skip",
"bc_linter_include",
"__version_tuple__",
"LLM",
"ModelRegistry",
"PromptType",
"TextPrompt",
"TokensPrompt",
"SamplingParams",
"RequestOutput",
"CompletionOutput",
"PoolingOutput",
"PoolingRequestOutput",
"EmbeddingOutput",
"EmbeddingRequestOutput",
"ClassificationOutput",
"ClassificationRequestOutput",
"ScoringOutput",
"ScoringRequestOutput",
"LLMEngine",
"EngineArgs",
"AsyncLLMEngine",
"AsyncEngineArgs",
"initialize_ray_cluster",
"PoolingParams",
]
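# Hedged illustration (not part of the original file): attributes listed in
# MODULE_ATTRS are resolved lazily through the module-level __getattr__ above,
# so the heavy submodules are only imported on first access.
#
#   import vllm
#   llm_cls = vllm.LLM                            # imports vllm.entrypoints.llm here
#   params = vllm.SamplingParams(temperature=0.0) # imports vllm.sampling_params here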

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

59
vllm/_bc_linter.py Normal file
View File

@@ -0,0 +1,59 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# vllm/_bc_linter.py
from __future__ import annotations
from typing import Any, Callable, TypeVar, overload
T = TypeVar("T")
@overload
def bc_linter_skip(obj: T) -> T:
...
@overload
def bc_linter_skip(*, reason: str | None = ...) -> Callable[[T], T]:
...
def bc_linter_skip(obj: Any = None, *, reason: str | None = None):
"""
No-op decorator to mark symbols/files for BC-linter suppression.
Usage:
@bc_linter_skip
def legacy_api(...): ...
"""
def _wrap(x: T) -> T:
return x
return _wrap if obj is None else obj
@overload
def bc_linter_include(obj: T) -> T:
...
@overload
def bc_linter_include(*, reason: str | None = ...) -> Callable[[T], T]:
...
def bc_linter_include(obj: Any = None, *, reason: str | None = None):
"""
Usage:
@bc_linter_include
def public_api(...): ...
"""
def _wrap(x: T) -> T:
return x
return _wrap if obj is None else obj
__all__ = ["bc_linter_skip", "bc_linter_include"]

2044
vllm/_custom_ops.py Normal file

File diff suppressed because it is too large Load Diff

393
vllm/_ipex_ops.py Normal file
View File

@@ -0,0 +1,393 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional, Union
import torch
from vllm.logger import init_logger
from vllm.platforms import current_platform
logger = init_logger(__name__)
try:
import intel_extension_for_pytorch as ipex
except ImportError as e:
logger.debug("Import error msg: %s", e.msg)
class ipex_ops:
@staticmethod
def _reshape_activation_tensor(
x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
num = x.size(0)
d = x.size(1) // 2
x = x.reshape(num, 2, d)
x1, x2 = torch.chunk(x, chunks=2, dim=1)
x1 = x1.reshape(num, d)
x2 = x2.reshape(num, d)
return x1, x2
@staticmethod
def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
ipex.llm.functional.silu_and_mul(x, out)
@staticmethod
def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
ipex.llm.functional.gelu_and_mul(x, out)
@staticmethod
def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
ipex.llm.functional.gelu_and_mul(x, out)
@staticmethod
def gelu_fast(x: torch.Tensor) -> torch.Tensor:
return torch.nn.functional.gelu(x)
@staticmethod
def gelu_new(x: torch.Tensor) -> torch.Tensor:
return torch.nn.functional.gelu(x)
@staticmethod
def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None:
ipex.llm.functional.gelu_quick(x, out)
@staticmethod
def paged_attention_v1(
out: torch.Tensor,
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
num_kv_heads: int,
scale: float,
block_tables: torch.Tensor,
context_lens: torch.Tensor,
block_size: int,
max_context_len: int,
alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str,
k_scale: float,
v_scale: float,
tp_rank: int = 0,
blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 0,
blocksparse_block_size: int = 64,
blocksparse_head_sliding_step: int = 0,
) -> None:
assert kv_cache_dtype == "auto"
num_heads = out.size(1)
num_queries_per_tokens = num_heads // num_kv_heads
ipex.llm.modules.PagedAttention.single_query_kv_attention(
out,
query.contiguous(),
key_cache.view_as(value_cache),
value_cache,
num_queries_per_tokens,
scale,
block_tables,
context_lens,
block_size,
max_context_len,
alibi_slopes,
)
@staticmethod
def paged_attention_v2(
out: torch.Tensor,
exp_sum: torch.Tensor,
max_logits: torch.Tensor,
tmp_out: torch.Tensor,
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
num_kv_heads: int,
scale: float,
block_tables: torch.Tensor,
context_lens: torch.Tensor,
block_size: int,
max_context_len: int,
alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str,
k_scale: float,
v_scale: float,
tp_rank: int = 0,
blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 0,
blocksparse_block_size: int = 64,
blocksparse_head_sliding_step: int = 0,
) -> None:
assert kv_cache_dtype == "auto"
num_heads = out.size(1)
num_queries_per_tokens = num_heads // num_kv_heads
ipex.llm.modules.PagedAttention.single_query_kv_attention(
out,
query.contiguous(),
key_cache.view_as(value_cache),
value_cache,
num_queries_per_tokens,
scale,
block_tables,
context_lens,
block_size,
max_context_len,
alibi_slopes,
)
@staticmethod
def rotary_embedding(
positions: torch.Tensor, # [batch_size, seq_len]
query: torch.Tensor, # [batch_size, seq_len, num_heads*head_size]
key: torch.Tensor, # [batch_size, seq_len, num_kv_heads*head_size]
head_size: int,
cos_sin_cache: torch.Tensor, # [cos_sin_dim, rot_dim]
is_neox: bool,
) -> None:
rot_dim = cos_sin_cache.size(1)
ipex.llm.functional.rotary_embedding_batched(positions, query, key,
head_size, cos_sin_cache,
is_neox, rot_dim)
@staticmethod
def rms_norm(input: torch.Tensor, weight: torch.Tensor,
epsilon: float) -> torch.Tensor:
return ipex.llm.functional.rms_norm(input, weight, epsilon)
@staticmethod
def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
weight: torch.Tensor, epsilon: float) -> None:
tmp = ipex.llm.functional.add_rms_norm(residual, input, weight, None,
epsilon, True)
input.copy_(tmp)
@staticmethod
def varlen_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
out: torch.Tensor,
seqlen_q: torch.Tensor,
seqlen_k: torch.Tensor,
alibi_slopes: Optional[torch.Tensor],
max_seqlen_q: int,
max_seqlen_k: int,
pdropout: float,
softmax_scale: float,
zero_tensors: bool,
is_causal: bool,
return_softmax: bool,
gen_: torch.Generator,
window_size_left: float,
window_size_right: float,
logits_soft_cap: float,
) -> None:
if ipex.__version__.endswith("cpu"):
if logits_soft_cap != 0.0:
raise ValueError("IPEX CPU does not support logits_soft_cap")
assert alibi_slopes is None
assert window_size_left < 0 and window_size_right < 0
ipex.llm.functional.varlen_attention(query.contiguous(),
key.contiguous(),
value.contiguous(), out,
seqlen_q.int(),
seqlen_k.int(), max_seqlen_q,
max_seqlen_k, pdropout,
softmax_scale, zero_tensors,
is_causal, return_softmax,
gen_)
else: # XPU build
ipex.llm.functional.varlen_attention(
query.contiguous(), key.contiguous(), value.contiguous(), out,
seqlen_q.int(), seqlen_k.int(), alibi_slopes, max_seqlen_q,
max_seqlen_k, pdropout, softmax_scale, zero_tensors, is_causal,
return_softmax, gen_, window_size_left, window_size_right,
logits_soft_cap)
@staticmethod
def reshape_and_cache(
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slot_mapping: torch.Tensor,
kv_cache_dtype: str,
k_scale: float,
v_scale: float,
) -> None:
assert kv_cache_dtype == "auto"
ipex.llm.modules.PagedAttention.reshape_and_cache(
key, value, key_cache, value_cache, slot_mapping)
@staticmethod
def reshape_and_cache_flash(
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slot_mapping: torch.Tensor,
kv_cache_dtype: str,
k_scale: Optional[torch.Tensor] = None,
v_scale: Optional[torch.Tensor] = None,
k_scale_float: float = 1.0,
v_scale_float: float = 1.0,
) -> None:
ipex.llm.modules.PagedAttention.reshape_and_cache_flash(
key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype,
k_scale_float, v_scale_float)
@staticmethod
def flash_attn_varlen_func(
out: torch.Tensor,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
cu_seqlens_q: torch.Tensor,
seqused_k: torch.Tensor, # we don't support this in ipex kernel
max_seqlen_q: int,
max_seqlen_k: int,
softmax_scale: float,
causal: bool,
block_table: torch.Tensor,
alibi_slopes: Optional[torch.Tensor],
window_size: Optional[list[int]] = None,
softcap: Optional[float] = 0.0,
cu_seqlens_k: Optional[torch.Tensor] = None,
# The following parameters are not used in ipex kernel currently,
# we keep API compatible to CUDA's.
scheduler_metadata=None,
fa_version: int = 2,
q_descale=None,
k_descale=None,
v_descale=None,
num_splits=0,
s_aux: Optional[torch.Tensor] = None,
):
if cu_seqlens_k is None:
# cu_seqlens_k is not used in ipex kernel.
cu_seqlens_k = torch.cumsum(seqused_k, dim=0)
cu_seqlens_k = torch.cat([
torch.tensor([0], device=seqused_k.device, dtype=torch.int32),
cu_seqlens_k
]).to(torch.int32)
real_window_size: tuple[int, int]
if window_size is None:
real_window_size = (-1, -1)
else:
assert len(window_size) == 2
real_window_size = (window_size[0], window_size[1])
return ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
out,
q.contiguous(),
k,
v,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
softmax_scale,
causal,
block_table,
alibi_slopes,
softcap=softcap,
window_size_left=real_window_size[0],
window_size_right=real_window_size[1],
k_scale=1.0,
v_scale=1.0,
)
@staticmethod
def get_scheduler_metadata(
batch_size,
max_seqlen_q,
max_seqlen_k,
num_heads_q,
num_heads_kv,
headdim,
cache_seqlens: torch.Tensor,
qkv_dtype=torch.bfloat16,
headdim_v=None,
cu_seqlens_q: Optional[torch.Tensor] = None,
cu_seqlens_k_new: Optional[torch.Tensor] = None,
cache_leftpad: Optional[torch.Tensor] = None,
page_size: Optional[int] = None,
max_seqlen_k_new=0,
causal=False,
window_size=(-1, -1), # -1 means infinite context window
has_softcap=False,
num_splits=0, # Can be tuned for speed
pack_gqa=None, # Can be tuned for speed
sm_margin=0, # Can be tuned if some SMs are used for communication
) -> None:
logger.warning_once(
"get_scheduler_metadata is not implemented for ipex_ops, "
"returning None.")
return None
@staticmethod
def copy_blocks(key_caches: list[torch.Tensor],
value_caches: list[torch.Tensor],
block_mapping: torch.Tensor) -> None:
torch.xpu.copy_blocks( # type: ignore
key_caches,
value_caches,
block_mapping,
)
@staticmethod
def swap_blocks(src: torch.Tensor, dst: torch.Tensor,
block_mapping: torch.Tensor) -> None:
torch.xpu.swap_blocks(src, dst, block_mapping) # type: ignore
@staticmethod
def scaled_fp8_quant(
input: torch.Tensor,
scale: Optional[torch.Tensor] = None,
num_token_padding: Optional[int] = None,
scale_ub: Optional[torch.Tensor] = None,
use_per_token_if_dynamic: bool = False,
output: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Quantize input tensor to FP8 and return quantized tensor and scale.
This function is designed for both static and dynamic quantization:
If you provide the scale, it will use static scaling and if you omit
it, the scale will be determined dynamically. Currently, XPU platform
only supports dynamic quantization. The function also allows optional
padding of the output tensors for downstream kernels that will benefit
from padding.
Args:
input: The input tensor to be quantized to FP8
scale: Optional scaling factor for the FP8 quantization
scale_ub: Optional upper bound for scaling factor in dynamic
per token case
num_token_padding: If specified, pad the first dimension
of the output to at least this value.
use_per_token_if_dynamic: Whether to do per_tensor or per_token
in the dynamic quantization case.
Returns:
tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
scaling factor.
"""
# This code assumes batch_dim and num_tokens are flattened
assert (input.ndim == 2)
shape: Union[tuple[int, int], torch.Size] = input.shape
out_dtype: torch.dtype = current_platform.fp8_dtype()
if num_token_padding:
shape = (max(num_token_padding, input.shape[0]), shape[1])
if output is None:
output = torch.empty(shape, device=input.device, dtype=out_dtype)
else:
assert num_token_padding is None, \
"padding not supported if output passed in"
assert output.dtype == out_dtype
assert scale is None, "only dynamic fp8 quantization supported on XPU"
assert not use_per_token_if_dynamic, (
"per token dynamic fp8 quantization not supported on XPU")
scale = torch.zeros(1, device=input.device, dtype=torch.float32)
torch.ops.torch_ipex.dynamic_scaled_fp8_quant(output, input, scale)
return output, scale
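# Hedged usage sketch (not part of the original file): dynamic FP8 quantization of
# a flattened 2-D activation tensor, assuming an XPU device and a working IPEX
# install; `x` is a placeholder tensor.
#
#   x = torch.randn(16, 4096, device="xpu", dtype=torch.bfloat16)
#   x_fp8, scale = ipex_ops.scaled_fp8_quant(x)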

34
vllm/_version.py Normal file
View File

@@ -0,0 +1,34 @@
# file generated by setuptools-scm
# don't change, don't track in version control
__all__ = [
"__version__",
"__version_tuple__",
"version",
"version_tuple",
"__commit_id__",
"commit_id",
]
TYPE_CHECKING = False
if TYPE_CHECKING:
from typing import Tuple
from typing import Union
VERSION_TUPLE = Tuple[Union[int, str], ...]
COMMIT_ID = Union[str, None]
else:
VERSION_TUPLE = object
COMMIT_ID = object
version: str
__version__: str
__version_tuple__: VERSION_TUPLE
version_tuple: VERSION_TUPLE
commit_id: COMMIT_ID
__commit_id__: COMMIT_ID
__version__ = version = '0.11.0'
__version_tuple__ = version_tuple = (0, 11, 0)
__commit_id__ = commit_id = None

0
vllm/assets/__init__.py Normal file
View File

Binary file not shown.

Binary file not shown.

45
vllm/assets/audio.py Normal file
View File

@@ -0,0 +1,45 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from pathlib import Path
from typing import Literal
from urllib.parse import urljoin
import numpy.typing as npt
from vllm.utils import PlaceholderModule
from .base import VLLM_S3_BUCKET_URL, get_vllm_public_assets
try:
import librosa
except ImportError:
librosa = PlaceholderModule("librosa") # type: ignore[assignment]
ASSET_DIR = "multimodal_asset"
AudioAssetName = Literal["winning_call", "mary_had_lamb"]
@dataclass(frozen=True)
class AudioAsset:
name: AudioAssetName
@property
def filename(self) -> str:
return f"{self.name}.ogg"
@property
def audio_and_sample_rate(self) -> tuple[npt.NDArray, float]:
audio_path = get_vllm_public_assets(filename=self.filename,
s3_prefix=ASSET_DIR)
return librosa.load(audio_path, sr=None)
def get_local_path(self) -> Path:
return get_vllm_public_assets(filename=self.filename,
s3_prefix=ASSET_DIR)
@property
def url(self) -> str:
return urljoin(VLLM_S3_BUCKET_URL, f"{ASSET_DIR}/{self.name}.ogg")
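# Hedged usage sketch (not part of the original file): loading a bundled audio
# asset; `librosa` must be installed for `audio_and_sample_rate` to work.
#
#   asset = AudioAsset(name="winning_call")
#   audio, sample_rate = asset.audio_and_sample_rate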

41
vllm/assets/base.py Normal file
View File

@@ -0,0 +1,41 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from functools import lru_cache
from pathlib import Path
from typing import Optional
import vllm.envs as envs
from vllm.connections import global_http_connection
VLLM_S3_BUCKET_URL = "https://vllm-public-assets.s3.us-west-2.amazonaws.com"
def get_cache_dir() -> Path:
"""Get the path to the cache for storing downloaded assets."""
path = Path(envs.VLLM_ASSETS_CACHE)
path.mkdir(parents=True, exist_ok=True)
return path
@lru_cache
def get_vllm_public_assets(filename: str,
s3_prefix: Optional[str] = None) -> Path:
"""
Download an asset file from ``s3://vllm-public-assets``
and return the path to the downloaded file.
"""
asset_directory = get_cache_dir() / "vllm_public_assets"
asset_directory.mkdir(parents=True, exist_ok=True)
asset_path = asset_directory / filename
if not asset_path.exists():
if s3_prefix is not None:
filename = s3_prefix + "/" + filename
global_http_connection.download_file(
f"{VLLM_S3_BUCKET_URL}/{filename}",
asset_path,
timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT)
return asset_path

50
vllm/assets/image.py Normal file
View File

@@ -0,0 +1,50 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from pathlib import Path
from typing import Literal
import torch
from PIL import Image
from .base import get_vllm_public_assets
VLM_IMAGES_DIR = "vision_model_images"
ImageAssetName = Literal["stop_sign", "cherry_blossom", "hato",
"2560px-Gfp-wisconsin-madison-the-nature-boardwalk",
"Grayscale_8bits_palette_sample_image",
"1280px-Venn_diagram_rgb", "RGBA_comp", "237-400x300",
"231-200x300", "27-500x500", "17-150x600",
"handelsblatt-preview", "paper-11"]
@dataclass(frozen=True)
class ImageAsset:
name: ImageAssetName
def get_path(self, ext: str) -> Path:
"""
Return s3 path for given image.
"""
return get_vllm_public_assets(filename=f"{self.name}.{ext}",
s3_prefix=VLM_IMAGES_DIR)
@property
def pil_image(self, ext="jpg") -> Image.Image:
image_path = self.get_path(ext)
return Image.open(image_path)
@property
def image_embeds(self) -> torch.Tensor:
"""
Image embeddings, only used for testing purposes with llava 1.5.
"""
image_path = self.get_path('pt')
return torch.load(image_path, map_location="cpu", weights_only=True)
def read_bytes(self, ext: str) -> bytes:
p = Path(self.get_path(ext))
return p.read_bytes()

145
vllm/assets/video.py Normal file
View File

@@ -0,0 +1,145 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from functools import lru_cache
from typing import Any, ClassVar, Literal, Optional
import cv2
import numpy as np
import numpy.typing as npt
from huggingface_hub import hf_hub_download
from PIL import Image
from vllm.utils import PlaceholderModule
from .base import get_cache_dir
try:
import librosa
except ImportError:
librosa = PlaceholderModule("librosa") # type: ignore[assignment]
@lru_cache
def download_video_asset(filename: str) -> str:
"""
Download a video asset from the huggingface
repo: raushan-testing-hf/videos-test
"""
video_directory = get_cache_dir() / "video-example-data"
video_directory.mkdir(parents=True, exist_ok=True)
video_path = video_directory / filename
video_path_str = str(video_path)
if not video_path.exists():
video_path_str = hf_hub_download(
repo_id="raushan-testing-hf/videos-test",
filename=filename,
repo_type="dataset",
cache_dir=video_directory,
)
return video_path_str
def video_to_ndarrays(path: str, num_frames: int = -1) -> npt.NDArray:
cap = cv2.VideoCapture(path)
if not cap.isOpened():
raise ValueError(f"Could not open video file {path}")
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
frames = []
num_frames = num_frames if num_frames > 0 else total_frames
frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
for idx in range(total_frames):
ok = cap.grab() # next img
if not ok:
break
if idx in frame_indices: # only decompress needed
ret, frame = cap.retrieve()
if ret:
# OpenCV uses BGR format, we need to convert it to RGB
# for PIL and transformers compatibility
frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
frames = np.stack(frames)
if len(frames) < num_frames:
raise ValueError(f"Could not read enough frames from video file {path}"
f" (expected {num_frames} frames, got {len(frames)})")
return frames
def video_to_pil_images_list(path: str,
num_frames: int = -1) -> list[Image.Image]:
frames = video_to_ndarrays(path, num_frames)
return [Image.fromarray(frame) for frame in frames]
def video_get_metadata(path: str, num_frames: int = -1) -> dict[str, Any]:
cap = cv2.VideoCapture(path)
if not cap.isOpened():
raise ValueError(f"Could not open video file {path}")
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
fps = cap.get(cv2.CAP_PROP_FPS)
duration = total_frames / fps if fps > 0 else 0
if num_frames == -1 or num_frames > total_frames:
num_frames = total_frames
metadata = {
"total_num_frames": num_frames,
"fps": fps,
"duration": duration,
"video_backend": "opencv",
"frames_indices": list(range(num_frames)),
# extra field used to control hf processor's video
# sampling behavior
"do_sample_frames": num_frames == total_frames,
}
return metadata
VideoAssetName = Literal["baby_reading"]
@dataclass(frozen=True)
class VideoAsset:
name: VideoAssetName
num_frames: int = -1
_NAME_TO_FILE: ClassVar[dict[VideoAssetName, str]] = {
"baby_reading": "sample_demo_1.mp4",
}
@property
def filename(self) -> str:
return self._NAME_TO_FILE[self.name]
@property
def video_path(self) -> str:
return download_video_asset(self.filename)
@property
def pil_images(self) -> list[Image.Image]:
ret = video_to_pil_images_list(self.video_path, self.num_frames)
return ret
@property
def np_ndarrays(self) -> npt.NDArray:
ret = video_to_ndarrays(self.video_path, self.num_frames)
return ret
@property
def metadata(self) -> dict[str, Any]:
ret = video_get_metadata(self.video_path, self.num_frames)
return ret
def get_audio(self, sampling_rate: Optional[float] = None) -> npt.NDArray:
"""
Read audio data from the video asset, used in Qwen2.5-Omni examples.
See also: examples/offline_inference/qwen2_5_omni/only_thinker.py
"""
return librosa.load(self.video_path, sr=sampling_rate)[0]
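# Hedged usage sketch (not part of the original file): downloading the bundled
# test video and sampling 16 frames as PIL images.
#
#   asset = VideoAsset(name="baby_reading", num_frames=16)
#   frames = asset.pil_images
#   meta = asset.metadata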

View File

@@ -0,0 +1,15 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadata, AttentionType)
from vllm.attention.layer import Attention
from vllm.attention.selector import get_attn_backend
__all__ = [
"Attention",
"AttentionBackend",
"AttentionMetadata",
"AttentionType",
"get_attn_backend",
]

View File

View File

@@ -0,0 +1,204 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from typing import Generic, List, Optional, Protocol, Tuple, Type, TypeVar
import torch
from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
class AttentionType:
"""
Attention type.
Use string to be compatible with `torch.compile`.
"""
DECODER = "decoder"
"""Decoder attention between previous layer Q/K/V."""
ENCODER = "encoder"
"""Encoder attention between previous layer Q/K/V for encoder-decoder."""
ENCODER_ONLY = "encoder_only"
"""Encoder attention between previous layer Q/K/V."""
ENCODER_DECODER = "encoder_decoder"
"""Attention between dec. Q and enc. K/V for encoder-decoder."""
class AttentionBackend(ABC):
"""Abstract class for attention backends."""
# For some attention backends, we allocate an output tensor before
# calling the custom op. When piecewise cudagraph is enabled, this
# makes sure the output tensor is allocated inside the cudagraph.
accept_output_buffer: bool = False
# Whether this backend supports receiving pre-quantized query input.
# If True, the attention layer will handle query quantization instead
# of the backend, allowing torch.compile to fuse quantization with
# previous operations.
# Needs to be worked through for all backends
# https://github.com/vllm-project/vllm/issues/25584
supports_quant_query_input: bool = False
@staticmethod
@abstractmethod
def get_name() -> str:
raise NotImplementedError
@staticmethod
@abstractmethod
def get_impl_cls() -> Type["AttentionImpl"]:
raise NotImplementedError
@staticmethod
@abstractmethod
def get_metadata_cls() -> Type["AttentionMetadata"]:
raise NotImplementedError
@classmethod
def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata":
return cls.get_metadata_cls()(*args, **kwargs)
@staticmethod
@abstractmethod
def get_builder_cls(): # -> Type["AttentionMetadataBuilder"]:
raise NotImplementedError
@staticmethod
@abstractmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
cache_dtype_str: str = "auto",
) -> Tuple[int, ...]:
raise NotImplementedError
@staticmethod
def get_kv_cache_stride_order() -> Tuple[int, ...]:
raise NotImplementedError
@classmethod
def full_cls_name(cls) -> tuple[str, str]:
return (cls.__module__, cls.__qualname__)
class AttentionMetadata:
pass
T = TypeVar("T", bound=AttentionMetadata)
class AttentionLayer(Protocol):
_q_scale: torch.Tensor
_k_scale: torch.Tensor
_v_scale: torch.Tensor
_q_scale_float: float
_k_scale_float: float
_v_scale_float: float
_prob_scale: torch.Tensor
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
...
class AttentionImpl(ABC, Generic[T]):
# Whether the attention impl can return the softmax lse for decode.
# Some features like decode context parallelism require the softmax lse.
can_return_lse_for_decode: bool = False
# some attention backends might not always want to return lse
# even if they can return lse (for efficiency reasons)
need_to_return_lse_for_decode: bool = False
dcp_world_size: int
dcp_rank: int
def __new__(cls, *args, **kwargs):
# use __new__ so that all subclasses will call this
self = super().__new__(cls)
try:
from vllm.distributed.parallel_state import get_dcp_group
self.dcp_world_size = get_dcp_group().world_size
self.dcp_rank = get_dcp_group().rank_in_group
except AssertionError:
# DCP might not be initialized in testing
self.dcp_world_size = 1
self.dcp_rank = 0
self.need_to_return_lse_for_decode = self.dcp_world_size > 1 \
and self.can_return_lse_for_decode
return self
@abstractmethod
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: Optional[int] = None,
alibi_slopes: Optional[List[float]] = None,
sliding_window: Optional[int] = None,
kv_cache_dtype: str = "auto",
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[str] = None,
) -> None:
raise NotImplementedError
@abstractmethod
def forward(
self,
layer: AttentionLayer,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: T,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
raise NotImplementedError
def fused_output_quant_supported(self, quant_key: QuantKey):
"""
Does this attention implementation support fused output quantization.
This is used by the AttnFusionPass to only fuse output quantization
onto implementations that support it.
:param quant_key: QuantKey object that describes the quantization op
:return: is fusion supported for this type of quantization
"""
return False
class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
@abstractmethod
def forward(
self,
layer: AttentionLayer,
hidden_states_or_cq: torch.Tensor,
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: T,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
raise NotImplementedError
def is_quantized_kv_cache(kv_cache_dtype: str) -> bool:
return kv_cache_dtype != "auto"
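# Hedged sketch (not part of the original file): the minimal surface a concrete
# backend subclass is expected to provide; `MyBackend`, `MyImpl`, `MyMetadata`,
# and `MyMetadataBuilder` are placeholder names.
#
#   class MyBackend(AttentionBackend):
#       @staticmethod
#       def get_name() -> str:
#           return "MY_BACKEND"
#       @staticmethod
#       def get_impl_cls() -> Type["AttentionImpl"]:
#           return MyImpl
#       @staticmethod
#       def get_metadata_cls() -> Type["AttentionMetadata"]:
#           return MyMetadata
#       @staticmethod
#       def get_builder_cls():
#           return MyMetadataBuilder
#       @staticmethod
#       def get_kv_cache_shape(num_blocks, block_size, num_kv_heads, head_size,
#                              cache_dtype_str="auto"):
#           return (2, num_blocks, block_size, num_kv_heads, head_size)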

View File

@@ -0,0 +1,33 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention backend utils"""
from dataclasses import dataclass
from typing import Optional
from vllm.config import ModelConfig
from vllm.logger import init_logger
logger = init_logger(__name__)
PAD_SLOT_ID = -1
@dataclass
class MLADims:
q_lora_rank: Optional[int]
kv_lora_rank: int
qk_nope_head_dim: int
qk_rope_head_dim: int
v_head_dim: int
def get_mla_dims(model_config: ModelConfig) -> MLADims:
hf_text_config = model_config.hf_text_config
return MLADims(
q_lora_rank=getattr(hf_text_config, "q_lora_rank", None),
kv_lora_rank=hf_text_config.kv_lora_rank,
qk_nope_head_dim=hf_text_config.qk_nope_head_dim,
qk_rope_head_dim=hf_text_config.qk_rope_head_dim,
v_head_dim=hf_text_config.v_head_dim,
)
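# Hedged usage sketch (not part of the original file): `model_config` is a
# placeholder for a vllm ModelConfig built from an MLA model.
#
#   dims = get_mla_dims(model_config)
#   qk_head_dim = dims.qk_nope_head_dim + dims.qk_rope_head_dim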

645
vllm/attention/layer.py Normal file
View File

@@ -0,0 +1,645 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer."""
from typing import List, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
import vllm.envs as envs
from vllm.attention import AttentionType
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.selector import backend_name_to_enum, get_attn_backend
from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
from vllm.config import CacheConfig, get_current_vllm_config
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
has_kv_transfer_group,
is_v1_kv_transfer_group)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape)
from vllm.model_executor.models.vision import get_vit_attn_backend
from vllm.platforms import _Backend, current_platform
from vllm.utils import GiB_bytes, direct_register_custom_op
logger = init_logger(__name__)
USE_XFORMERS_OPS = None
try:
tag_cudagraph_unsafe = (torch._C.Tag.cudagraph_unsafe, )
except AttributeError:
tag_cudagraph_unsafe = () # type: ignore[assignment]
def check_xformers_availability():
global USE_XFORMERS_OPS
if USE_XFORMERS_OPS is not None:
return USE_XFORMERS_OPS
if current_platform.is_cuda() and current_platform.has_device_capability(
100):
# Xformers FA is not compatible with B200
USE_XFORMERS_OPS = False
else:
try:
from importlib.util import find_spec
find_spec("xformers.ops")
USE_XFORMERS_OPS = True
except ImportError:
USE_XFORMERS_OPS = False
# the warning only needs to be shown once
if not USE_XFORMERS_OPS:
logger.warning("Xformers is not available, falling back.")
return USE_XFORMERS_OPS
def check_upstream_fa_availability(dtype: torch.dtype):
if dtype in (torch.float16, torch.bfloat16) and current_platform.is_cuda(
) and current_platform.has_device_capability(80):
from transformers.utils import is_flash_attn_2_available
return is_flash_attn_2_available()
return False
class Attention(nn.Module, AttentionLayerBase):
"""Attention layer.
This class takes query, key, and value tensors as input. The input tensors
can either contain prompt tokens or generation tokens.
The class does the following:
1. Store the input key and value tensors in the KV cache.
2. Perform (multi-head/multi-query/grouped-query) attention.
3. Return the output tensor.
"""
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: Optional[int] = None,
alibi_slopes: Optional[List[float]] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
logits_soft_cap: Optional[float] = None,
per_layer_sliding_window: Optional[int] = None,
use_mla: bool = False,
use_sparse: bool = False,
prefix: str = "",
attn_type: str = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[str] = None,
attn_backend: Optional[type[AttentionBackend]] = None,
**extra_impl_args,
) -> None:
"""
The KV cache is stored inside this class and is accessed via
`self.kv_cache`.
"""
super().__init__()
if per_layer_sliding_window is not None:
# per-layer sliding window
sliding_window = per_layer_sliding_window
elif cache_config is not None:
# model-level sliding window
sliding_window = cache_config.sliding_window
else:
sliding_window = None
if cache_config is not None:
kv_cache_dtype = cache_config.cache_dtype
block_size = cache_config.block_size
calculate_kv_scales = cache_config.calculate_kv_scales
else:
kv_cache_dtype = "auto"
block_size = 16
calculate_kv_scales = False
if num_kv_heads is None:
num_kv_heads = num_heads
assert num_heads % num_kv_heads == 0, \
f"num_heads ({num_heads}) is not " \
f"divisible by num_kv_heads ({num_kv_heads})"
# The default k/v_scale is set to 1.0. This is ignored
# when kv-cache is not fp8, and should be used with
# kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we
# expect the pre-quantized k/v_scale to be loaded along
# with the model weights.
self.kv_cache_dtype = kv_cache_dtype
self.calculate_kv_scales = calculate_kv_scales
self._k_scale = torch.tensor(1.0, dtype=torch.float32)
self._v_scale = torch.tensor(1.0, dtype=torch.float32)
# FlashAttn doesn't support quantizing the kv-cache only
# but requires q to be quantized as well.
self._q_scale = torch.tensor(1.0, dtype=torch.float32)
self._prob_scale = torch.tensor(1.0, dtype=torch.float32)
# We also keep q/k/v_scale on host (cpu) memory for attention
# backends that require the scales to be on host instead of on device.
# e.g. Flashinfer
self._q_scale_float = 1.0
self._k_scale_float = 1.0
self._v_scale_float = 1.0
# The output scale on host memory. This should be the input scale of
# the quant op after this attention layer.
self._o_scale_float: Optional[float] = None
self.use_mla = use_mla
self.use_sparse = use_sparse
self.num_heads = num_heads
self.head_size = head_size
self.num_kv_heads = num_kv_heads
self.sliding_window = sliding_window
self.has_sink = extra_impl_args.get("sinks") is not None
quant_method = quant_config.get_quant_method(
self, prefix=prefix) if quant_config else None
if quant_method is not None and not isinstance(
quant_method, UnquantizedLinearMethod):
assert isinstance(quant_method, BaseKVCacheMethod)
# TODO (mgoin): kv cache dtype should be specified in the FP8
# checkpoint config and become the "auto" behavior
if self.kv_cache_dtype == "fp8_e5m2":
raise ValueError("fp8_e5m2 kv-cache is not supported with "
"fp8 checkpoints.")
# If quantization is enabled, we make "k_scale" and "v_scale"
# parameters so that it can be loaded from the model checkpoint.
# The k/v_scale will then be converted back to native float32
# values after weight loading.
self.quant_method = quant_method
self.quant_method.create_weights(self)
# During model initialization, the default dtype is set as the model
# weight and activation dtype.
dtype = torch.get_default_dtype()
if attn_backend is None:
self.attn_backend = get_attn_backend(head_size,
dtype,
kv_cache_dtype,
block_size,
use_mla=use_mla,
has_sink=self.has_sink,
use_sparse=use_sparse)
else:
self.attn_backend = attn_backend
impl_cls = self.attn_backend.get_impl_cls()
self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
alibi_slopes, sliding_window, kv_cache_dtype,
logits_soft_cap, attn_type,
kv_sharing_target_layer_name, **extra_impl_args)
self.backend = backend_name_to_enum(self.attn_backend.get_name())
self.dtype = dtype
# For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
# torch.compile works by registering the attention as one giant
# opaque custom op. For other platforms, we directly call them
# and let torch.compile handle them.
self.use_direct_call = not current_platform.opaque_attention_op()
self.use_output = self.attn_backend.accept_output_buffer
compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self
self.layer_name = prefix
self.attn_type = attn_type
if kv_sharing_target_layer_name is not None:
validate_kv_sharing_target(
prefix,
kv_sharing_target_layer_name,
compilation_config.static_forward_context,
)
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
# use a placeholder kv cache tensor during init, which will be replaced
# by bind_kv_cache
# this variable will not be accessed if use_direct_call is True
self.kv_cache = [
torch.tensor([]) for _ in range(get_current_vllm_config(
).parallel_config.pipeline_parallel_size)
]
try:
self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT,
dtype=torch.float32)
self.k_range = torch.tensor(envs.K_SCALE_CONSTANT,
dtype=torch.float32)
self.v_range = torch.tensor(envs.V_SCALE_CONSTANT,
dtype=torch.float32)
except torch.cuda.OutOfMemoryError as e:
logger.error(
"Failed to initialize attention q/k/v range constants: %s", e)
if torch.cuda.is_available():
logger.debug("CUDA device: %s", torch.cuda.current_device())
logger.debug("Allocated: %.2f GiB",
torch.cuda.memory_allocated() / GiB_bytes)
logger.debug("Reserved: %.2f GiB",
torch.cuda.memory_reserved() / GiB_bytes)
raise RuntimeError(
"Failed to initialize q/k/v range constants. "
"This may be caused by insufficient memory to allocate "
"kv cache.") from e
# for attn backends supporting query quantization
self.query_quant = None
if self.kv_cache_dtype.startswith(
"fp8") and self.attn_backend.supports_quant_query_input:
self.query_quant = QuantFP8(static=True,
group_shape=GroupShape.PER_TENSOR)
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
# For some alternate attention backends like MLA the attention output
# shape does not match the query shape, so we optionally let the model
# definition specify the output tensor shape.
output_shape: Optional[torch.Size] = None,
) -> torch.Tensor:
"""
The KV cache is stored inside this class and is accessed via
`self.kv_cache`.
Attention metadata (`attn_metadata`) is set using a context manager in
the model runner's `execute_model` method. It is accessed via forward
context using
`vllm.forward_context.get_forward_context().attn_metadata`.
"""
if self.calculate_kv_scales:
attn_metadata = get_forward_context().attn_metadata
if attn_metadata.enable_kv_scales_calculation:
self.calc_kv_scales(query, key, value)
output_dtype = query.dtype
if self.query_quant is not None:
# quantizing with a simple torch operation enables
# torch.compile to fuse this into previous ops
# which reduces overheads during decoding.
# Otherwise queries are quantized using custom ops
# which causes decoding overheads
assert self.kv_cache_dtype in {"fp8", "fp8_e4m3"}
query, _ = self.query_quant(query, self._q_scale)
if self.use_output:
output_shape = (output_shape
if output_shape is not None else query.shape)
output = torch.zeros(output_shape,
dtype=output_dtype,
device=query.device)
hidden_size = output_shape[-1]
# We skip reshaping query, key and value tensors for the MLA
# backend since these tensors have different semantics and are
# processed differently.
if not self.use_mla:
# Reshape the query, key, and value tensors.
# NOTE(woosuk): We do this outside the custom op to minimize the
# CPU overheads from the non-CUDA-graph regions.
query = query.view(-1, self.num_heads, self.head_size)
output = output.view(-1, self.num_heads, self.head_size)
if key is not None:
key = key.view(-1, self.num_kv_heads, self.head_size)
if value is not None:
value = value.view(-1, self.num_kv_heads, self.head_size)
if self.use_direct_call:
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[self.layer_name]
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
self.impl.forward(self,
query,
key,
value,
self_kv_cache,
attn_metadata,
output=output)
else:
torch.ops.vllm.unified_attention_with_output(
query, key, value, output, self.layer_name)
return output.view(-1, hidden_size)
else:
if self.use_direct_call:
forward_context = get_forward_context()
attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[self.layer_name]
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
return self.impl.forward(self, query, key, value,
self_kv_cache, attn_metadata)
else:
return torch.ops.vllm.unified_attention(
query, key, value, self.layer_name)
def calc_kv_scales(self, query, key, value):
self._q_scale.copy_(torch.abs(query).max() / self.q_range)
self._k_scale.copy_(torch.abs(key).max() / self.k_range)
self._v_scale.copy_(torch.abs(value).max() / self.v_range)
self._q_scale_float = self._q_scale.item()
self._k_scale_float = self._k_scale.item()
self._v_scale_float = self._v_scale.item()
# We only calculate the scales once
self.calculate_kv_scales = False
def extra_repr(self) -> str:
s = f"head_size={self.impl.head_size}" # type: ignore
s += f", num_heads={self.impl.num_heads}" # type: ignore
s += f", num_kv_heads={self.impl.num_kv_heads}" # type: ignore
s += f", scale={self.impl.scale}" # type: ignore
s += f", backend={self.impl.__class__.__name__}"
return s
def process_weights_after_loading(self, act_dtype: torch.dtype):
if hasattr(self.impl, "process_weights_after_loading"):
self.impl.process_weights_after_loading(act_dtype)
# FlashInfer requires attention sinks to be float32
if (self.backend == _Backend.FLASHINFER
and hasattr(self.impl, 'sinks')):
from vllm.v1.attention.backends.flashinfer import FlashInferImpl
assert isinstance(self.impl, FlashInferImpl)
if (self.impl.sinks is not None
and self.impl.sinks.dtype != torch.float32):
self.impl.sinks = self.impl.sinks.to(torch.float32)
def get_attn_backend(self) -> type[AttentionBackend]:
return self.attn_backend


class MultiHeadAttention(nn.Module):
"""Multi-headed attention without any cache, used for ViT."""

    def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: Optional[int] = None,
):
super().__init__()
self.num_heads = num_heads
self.head_size = head_size
self.scale = scale
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
assert self.num_heads % self.num_kv_heads == 0, \
f"num_heads ({self.num_heads}) is not " \
f"divisible by num_kv_heads ({self.num_kv_heads})"
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        # During model initialization, the default dtype is set to the
        # model's weight and activation dtype.
dtype = torch.get_default_dtype()
# Determine the attention backend
backend = get_vit_attn_backend(head_size=head_size, dtype=dtype)
        # Some auto-selected backends can be upgraded to upstream flash
        # attention if it is available; if vLLM's native flash attention is
        # selected, it is used directly.
use_upstream_fa = False
if backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
dtype):
backend = _Backend.FLASH_ATTN
use_upstream_fa = True
if current_platform.is_rocm() or current_platform.is_xpu():
# currently, only torch_sdpa is supported on rocm/xpu
self.attn_backend = _Backend.TORCH_SDPA
else:
self.attn_backend = backend if backend in {
_Backend.TORCH_SDPA,
_Backend.XFORMERS,
_Backend.PALLAS,
_Backend.ROCM_AITER_FA,
_Backend.FLASH_ATTN,
} else _Backend.TORCH_SDPA
if (self.attn_backend == _Backend.XFORMERS
and not check_xformers_availability()):
self.attn_backend = _Backend.TORCH_SDPA
if self.attn_backend == _Backend.FLASH_ATTN:
if use_upstream_fa:
from flash_attn import flash_attn_varlen_func
self._flash_attn_varlen_func = flash_attn_varlen_func
else:
from vllm.vllm_flash_attn import flash_attn_varlen_func
self._flash_attn_varlen_func = flash_attn_varlen_func
logger.info_once(
f"MultiHeadAttention attn_backend: {self.attn_backend}, "
f"use_upstream_fa: {use_upstream_fa}")

    def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
) -> torch.Tensor:
"""Input shape:
(batch_size x seq_len x hidden_size) or
(batch_size x seq_len x num_heads x head_size)
"""
bsz, q_len = query.size()[:2]
kv_len = key.size(1)
query = query.view(bsz, q_len, self.num_heads, self.head_size)
key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size)
value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size)
if (num_repeat := self.num_queries_per_kv) > 1:
# Handle MQA and GQA
key = torch.repeat_interleave(key, num_repeat, dim=2)
value = torch.repeat_interleave(value, num_repeat, dim=2)
if self.attn_backend == _Backend.FLASH_ATTN:
cu_seqlens_q = torch.arange(0, (bsz + 1) * q_len,
step=q_len,
dtype=torch.int32,
device=query.device)
cu_seqlens_k = torch.arange(0, (bsz + 1) * kv_len,
step=kv_len,
dtype=torch.int32,
device=key.device)
out = self._flash_attn_varlen_func(
query.flatten(0, 1),
key.flatten(0, 1),
value.flatten(0, 1),
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=q_len,
max_seqlen_k=kv_len,
softmax_scale=self.scale,
)
elif self.attn_backend == _Backend.XFORMERS:
from xformers import ops as xops
out = xops.memory_efficient_attention_forward(query,
key,
value,
scale=self.scale)
elif self.attn_backend == _Backend.TORCH_SDPA:
query, key, value = (x.transpose(1, 2)
for x in (query, key, value))
out = F.scaled_dot_product_attention(query,
key,
value,
scale=self.scale)
out = out.transpose(1, 2)
elif self.attn_backend == _Backend.PALLAS:
query, key, value = (x.transpose(1, 2)
for x in (query, key, value))
from torch_xla.experimental.custom_kernel import flash_attention
out = flash_attention(query, key, value, sm_scale=self.scale)
out = out.transpose(1, 2)
elif self.attn_backend == _Backend.ROCM_AITER_FA:
from aiter import flash_attn_varlen_func
# ROCm Flash Attention expects (batch, seq, heads, head_dim)
out = flash_attn_varlen_func(query,
key,
value,
softmax_scale=self.scale)
        else:
            # ViT attention does not support this backend yet
            raise NotImplementedError(
                f"ViT attention does not support the {self.attn_backend} "
                f"backend yet.")
return out.reshape(bsz, q_len, -1)
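

# Example usage of MultiHeadAttention (a minimal sketch; the head count,
# head size and shapes below are illustrative assumptions, not taken from
# any particular model):
#
#   attn = MultiHeadAttention(num_heads=12, head_size=64, scale=64 ** -0.5)
#   q = k = v = torch.randn(2, 197, 12 * 64)   # (batch, seq_len, hidden)
#   out = attn(q, k, v)                        # -> (2, 197, 768)
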
def wait_for_kv_layer_from_connector(layer_name: str):
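    """Block until the KV-transfer connector has finished loading the KV
    cache for ``layer_name``; a no-op when no (v1) KV transfer group is
    configured or no attention metadata is present."""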
if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
return
connector = get_kv_transfer_group()
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
if attn_metadata is None:
return
assert isinstance(attn_metadata, dict)
connector.wait_for_layer_load(layer_name)


def maybe_save_kv_layer_to_connector(
layer_name: str,
kv_cache_layer: List[torch.Tensor],
):
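    """Hand ``layer_name``'s KV cache to the KV-transfer connector for
    saving; a no-op when no (v1) KV transfer group is configured or no
    attention metadata is present."""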
if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
return
connector = get_kv_transfer_group()
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
if attn_metadata is None:
return
assert isinstance(attn_metadata, dict)
connector.save_kv_layer(layer_name, kv_cache_layer,
attn_metadata[layer_name])


def unified_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
layer_name: str,
) -> torch.Tensor:
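    """Custom-op entry point for attention that returns a new tensor.

    The attention layer is looked up by ``layer_name`` in the forward
    context, its per-virtual-engine KV cache is selected, and the backend
    implementation is invoked; the KV cache is synchronized with a
    KV-transfer connector before and after the call when one is configured.
    """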
wait_for_kv_layer_from_connector(layer_name)
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[layer_name]
self = forward_context.no_compile_layers[layer_name]
kv_cache = self.kv_cache[forward_context.virtual_engine]
output = self.impl.forward(self, query, key, value, kv_cache,
attn_metadata)
maybe_save_kv_layer_to_connector(layer_name, kv_cache)
return output


def unified_attention_fake(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
layer_name: str,
) -> torch.Tensor:
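    """Fake (meta) implementation used while tracing/compiling: only the
    shape, dtype and device of the result matter, not its values."""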
return torch.empty_like(query).contiguous()
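

# Register unified_attention as a torch custom op: the fake implementation
# provides output metadata (shape/dtype) during tracing, and
# tag_cudagraph_unsafe marks the op as unsafe to capture in CUDA graphs.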
direct_register_custom_op(
op_name="unified_attention",
op_func=unified_attention,
fake_impl=unified_attention_fake,
tags=tag_cudagraph_unsafe,
)


def unified_attention_with_output(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
output: torch.Tensor,
layer_name: str,
output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> None:
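    """Custom-op entry point for attention that writes into a caller-provided
    ``output`` buffer (optionally with ``output_scale`` /
    ``output_block_scale``) instead of returning a new tensor."""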
wait_for_kv_layer_from_connector(layer_name)
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[layer_name]
self = forward_context.no_compile_layers[layer_name]
kv_cache = self.kv_cache[forward_context.virtual_engine]
self.impl.forward(self,
query,
key,
value,
kv_cache,
attn_metadata,
output=output,
output_scale=output_scale,
output_block_scale=output_block_scale)
maybe_save_kv_layer_to_connector(layer_name, kv_cache)


def unified_attention_with_output_fake(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
output: torch.Tensor,
layer_name: str,
output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> None:
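    """Fake implementation: the preallocated ``output`` buffer already has
    the right metadata, so there is nothing to compute during tracing."""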
return
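

# Register the in-place variant; "output" and "output_block_scale" are
# listed in mutates_args because the op writes its results into them
# instead of returning a tensor.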
direct_register_custom_op(
op_name="unified_attention_with_output",
op_func=unified_attention_with_output,
mutates_args=["output", "output_block_scale"],
fake_impl=unified_attention_with_output_fake,
tags=tag_cudagraph_unsafe,
)
