import dataclasses
from typing import Optional

from triton.debugger import torch_wrapper

torch = torch_wrapper.torch


@dataclasses.dataclass
class RegisteredStorage:
    """Bookkeeping for a single tensor storage registered with the memory map."""

    storage: torch.Storage
    dtype: torch.dtype
    size: int
    ptr: int

    @property
    def end_ptr(self) -> int:
        # One past the last address covered by this storage: [ptr, end_ptr).
        return self.ptr + self.size

    @property
    def access_tensor(self) -> torch.Tensor:
        # Materialize the storage as a typed tensor so it can be indexed element-wise.
        return torch.tensor(self.storage, dtype=self.dtype, device=self.storage.device)

    def ensure_immutable(self):
        # Guard against the storage having been moved or resized since registration.
        assert self.storage.data_ptr() == self.ptr and self.storage.size() == self.size


class MemoryMap:
    """Resolves integer device pointers back to the registered storages that contain them."""

    storages: list[RegisteredStorage]

    def __init__(self):
        self.storages = []

    def _get_registered_storage(self, pointer: torch.Tensor):
        # All pointers in the block must land inside one registered storage.
        max_pointer = torch.max(pointer).item()
        min_pointer = torch.min(pointer).item()

        registered_storage = next(
            filter(
                lambda registered: min_pointer >= registered.ptr and max_pointer < registered.end_ptr,
                self.storages,
            ),
            None,
        )
        if registered_storage is None:
            raise Exception("Storage not found or pointers spanning multiple tensors")
        registered_storage.ensure_immutable()
        return registered_storage

    def add_tensor(self, t: torch.Tensor):
        storage = t.untyped_storage()
        self.storages.append(RegisteredStorage(storage, t.dtype, storage.size(), storage.data_ptr()))
        return t.data_ptr()

    def load(
        self,
        pointer: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        other=0.0,
    ):
        assert pointer.is_cuda
        assert 0 < pointer.dim() < 3
        assert pointer.dtype == torch.int64

        if mask is None:
            mask = torch.ones_like(pointer).bool()
        assert mask.is_cuda
        assert 0 < mask.dim() < 3
        assert mask.dtype == torch.bool
        mask = mask.expand(pointer.size())

        if torch.all(~mask):
            # TODO: The dtype is wrong here; with every lane masked off there is
            # no storage to infer the correct element type from.
            return torch.full_like(pointer, fill_value=other, dtype=torch.float16, device="cuda")

        registered_storage = self._get_registered_storage(pointer[mask])
        access_tensor = registered_storage.access_tensor

        # Translate absolute pointers into offsets relative to the storage base.
        index_tensor = pointer - registered_storage.ptr

        # Masked-off lanes keep the `other` fill value.
        block = torch.full_like(pointer, fill_value=other, dtype=access_tensor.dtype, device="cuda")
        block[mask] = access_tensor[index_tensor[mask]]
        return block

    def store(self, pointer: torch.Tensor, value: torch.Tensor, mask: Optional[torch.Tensor] = None):
        assert 0 < pointer.dim() < 3
        assert pointer.dtype == torch.int64

        if mask is None:
            mask = torch.ones_like(pointer).bool()
        assert 0 < mask.dim() < 3
        assert mask.dtype == torch.bool
        mask = mask.expand(pointer.size())

        if torch.all(~mask):
            # Nothing to write when every lane is masked off.
            return

        registered_storage = self._get_registered_storage(pointer[mask])
        access_tensor = registered_storage.access_tensor

        index_tensor = pointer - registered_storage.ptr
        # Write only the unmasked lanes, casting to the storage's element type.
        access_tensor[index_tensor[mask]] = value[mask].to(access_tensor.dtype)
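

# A minimal usage sketch, not part of the original file. It assumes a CUDA
# device is available and that pointers are formed as base address plus
# per-lane offsets, matching how `load`/`store` above index into the storage;
# the offsets, mask, and values below are purely illustrative.
if __name__ == "__main__":
    memory = MemoryMap()

    t = torch.arange(8, dtype=torch.float16, device="cuda")
    base_ptr = memory.add_tensor(t)

    # Build a block of per-lane pointers and a mask covering the first four lanes.
    offsets = torch.arange(8, dtype=torch.int64, device="cuda")
    pointers = base_ptr + offsets
    mask = offsets < 4

    block = memory.load(pointers, mask=mask, other=0.0)  # masked-off lanes read as 0.0
    memory.store(pointers, block * 2, mask=mask)  # write doubled values back to the same lanes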