45 lines
1.1 KiB
Python
45 lines
1.1 KiB
Python
from __future__ import annotations
|
|
|
|
import os
|
|
from functools import lru_cache
|
|
from typing import Any, Iterable
|
|
|
|
|
|
@lru_cache()
|
|
def _prepare_for_load() -> str:
|
|
import os
|
|
import warnings
|
|
|
|
warnings.filterwarnings(
|
|
"ignore", category=UserWarning, module="torch.utils.cpp_extension"
|
|
)
|
|
return os.path.dirname(os.path.abspath(__file__))
|
|
|
|
|
|
@lru_cache()
|
|
def load_kernel_module(
|
|
path: str | Iterable[str],
|
|
name: str,
|
|
*,
|
|
build: str = "build",
|
|
cflags: Iterable[str] | None = None,
|
|
cuda_flags: Iterable[str] | None = None,
|
|
ldflags: Iterable[str] | None = None,
|
|
) -> Any:
|
|
from torch.utils.cpp_extension import load
|
|
|
|
if isinstance(path, str):
|
|
path = (path,)
|
|
|
|
abs_path = _prepare_for_load()
|
|
build_dir = f"{abs_path}/{build}"
|
|
os.makedirs(build_dir, exist_ok=True)
|
|
return load(
|
|
name=name,
|
|
sources=[f"{abs_path}/csrc/{p}" for p in path],
|
|
extra_cflags=list(cflags or []) or ["-O3", "-std=c++17"],
|
|
extra_cuda_cflags=list(cuda_flags or []) or ["-O3", "-std=c++17"],
|
|
extra_ldflags=list(ldflags or []) or None,
|
|
build_directory=build_dir,
|
|
)
|