[VLM] Improving multimodal tensor hash kernel (#9008)
This commit is contained in:
committed by
GitHub
parent
c1c7dc4534
commit
0b1e04f083
@@ -17,57 +17,173 @@ import torch
|
|||||||
import triton
|
import triton
|
||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
|
|
||||||
|
FMIX32_C1 = 0x85EBCA6B
|
||||||
|
FMIX32_C2 = 0xC2B2AE35
|
||||||
|
POS_C1 = 0x27D4EB2D
|
||||||
|
POS_C2 = 0x165667B1
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def hash_kernel(
|
def _rotl32(x, r: tl.constexpr):
|
||||||
input_ptr,
|
return (x << r) | (x >> (32 - r))
|
||||||
output_ptr,
|
|
||||||
n_elements,
|
|
||||||
BLOCK_SIZE: tl.constexpr,
|
@triton.jit
|
||||||
PRIME: tl.constexpr,
|
def _fmix32(x, C1: tl.constexpr, C2: tl.constexpr):
|
||||||
XCONST: tl.constexpr,
|
c1 = tl.full((), C1, tl.uint32)
|
||||||
|
c2 = tl.full((), C2, tl.uint32)
|
||||||
|
x ^= x >> 16
|
||||||
|
x = x * c1
|
||||||
|
x ^= x >> 13
|
||||||
|
x = x * c2
|
||||||
|
x ^= x >> 16
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def hash_tiles32_kernel_blocked(
|
||||||
|
in_ptr,
|
||||||
|
out_ptr,
|
||||||
|
n_u32,
|
||||||
|
seed1,
|
||||||
|
seed2,
|
||||||
|
FM_C1: tl.constexpr,
|
||||||
|
FM_C2: tl.constexpr,
|
||||||
|
POS_A: tl.constexpr,
|
||||||
|
POS_B: tl.constexpr,
|
||||||
|
TILE: tl.constexpr,
|
||||||
|
BLOCK: tl.constexpr,
|
||||||
|
USE_CG: tl.constexpr,
|
||||||
):
|
):
|
||||||
pid = tl.program_id(axis=0)
|
pid = tl.program_id(axis=0)
|
||||||
block_start = pid * BLOCK_SIZE
|
base = pid * TILE
|
||||||
offsets = block_start + tl.arange(0, BLOCK_SIZE)
|
|
||||||
mask = offsets < n_elements
|
|
||||||
|
|
||||||
data = tl.load(input_ptr + offsets, mask=mask, other=0).to(tl.int64)
|
s1 = tl.full((), seed1, tl.uint32)
|
||||||
mixed = data ^ (offsets.to(tl.int64) + XCONST)
|
s2 = tl.full((), seed2, tl.uint32)
|
||||||
hash_val = mixed * PRIME
|
posA = tl.full((), POS_A, tl.uint32)
|
||||||
hash_val = hash_val ^ (hash_val >> 16)
|
posB = tl.full((), POS_B, tl.uint32)
|
||||||
hash_val = hash_val * (PRIME ^ XCONST)
|
|
||||||
hash_val = hash_val ^ (hash_val >> 13)
|
|
||||||
|
|
||||||
tl.store(output_ptr + offsets, hash_val, mask=mask)
|
h1 = tl.zeros((), dtype=tl.uint32)
|
||||||
|
h2 = tl.zeros((), dtype=tl.uint32)
|
||||||
|
|
||||||
|
for off in tl.static_range(0, TILE, BLOCK):
|
||||||
|
idx = base + off + tl.arange(0, BLOCK)
|
||||||
|
m = idx < n_u32
|
||||||
|
|
||||||
PRIME_1 = -(11400714785074694791 ^ 0xFFFFFFFFFFFFFFFF) - 1
|
if USE_CG:
|
||||||
PRIME_2 = -(14029467366897019727 ^ 0xFFFFFFFFFFFFFFFF) - 1
|
v = tl.load(in_ptr + idx, mask=m, other=0, cache_modifier=".cg")
|
||||||
|
else:
|
||||||
|
v = tl.load(in_ptr + idx, mask=m, other=0)
|
||||||
|
v = v.to(tl.uint32)
|
||||||
|
|
||||||
|
iu = idx.to(tl.uint32)
|
||||||
|
p1 = (iu * posA + s1) ^ _rotl32(iu, 15)
|
||||||
|
p2 = (iu * posB + s2) ^ _rotl32(iu, 13)
|
||||||
|
|
||||||
def gpu_tensor_hash(tensor: torch.Tensor) -> int:
|
k1 = _fmix32(v ^ p1, C1=FM_C1, C2=FM_C2)
|
||||||
assert tensor.is_cuda
|
k2 = _fmix32(v ^ p2, C1=FM_C1, C2=FM_C2)
|
||||||
tensor = tensor.contiguous().view(torch.int32)
|
|
||||||
n = tensor.numel()
|
|
||||||
BLOCK_SIZE = 1024
|
|
||||||
grid = (triton.cdiv(n, BLOCK_SIZE),)
|
|
||||||
|
|
||||||
intermediate_hashes = torch.empty(n, dtype=torch.int64, device=tensor.device)
|
zero32 = tl.zeros_like(k1)
|
||||||
|
k1 = tl.where(m, k1, zero32)
|
||||||
|
k2 = tl.where(m, k2, zero32)
|
||||||
|
|
||||||
# Set cuda device to prevent ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
|
h1 += tl.sum(k1, axis=0).to(tl.uint32)
|
||||||
# Solution from Tri: https://github.com/Dao-AILab/flash-attention/issues/523#issuecomment-1707611579
|
h2 += tl.sum(k2, axis=0).to(tl.uint32)
|
||||||
with torch.cuda.device(tensor.device):
|
|
||||||
hash_kernel[grid](
|
nbytes = tl.full((), n_u32 * 4, tl.uint32)
|
||||||
tensor,
|
h1 ^= nbytes
|
||||||
intermediate_hashes,
|
h2 ^= nbytes
|
||||||
n,
|
h1 = _fmix32(h1, C1=FM_C1, C2=FM_C2)
|
||||||
BLOCK_SIZE=BLOCK_SIZE,
|
h2 = (
|
||||||
PRIME=PRIME_1,
|
_fmix32(h2, C1=FMIX32_C1, C2=FMIX32_C2)
|
||||||
XCONST=PRIME_2,
|
if False
|
||||||
|
else _fmix32(h2, C1=FM_C1, C2=FM_C2)
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO: threads can't be synced on triton kernel
|
out = (h1.to(tl.uint64) << 32) | h2.to(tl.uint64)
|
||||||
final_hash = intermediate_hashes.sum().item()
|
tl.store(out_ptr + pid, out)
|
||||||
|
|
||||||
return final_hash
|
|
||||||
|
@triton.jit
|
||||||
|
def add_tree_reduce_u64_kernel(in_ptr, out_ptr, n_elems, CHUNK: tl.constexpr):
|
||||||
|
pid = tl.program_id(axis=0)
|
||||||
|
start = pid * CHUNK
|
||||||
|
h = tl.zeros((), dtype=tl.uint64)
|
||||||
|
for i in tl.static_range(0, CHUNK):
|
||||||
|
idx = start + i
|
||||||
|
m = idx < n_elems
|
||||||
|
v = tl.load(in_ptr + idx, mask=m, other=0).to(tl.uint64)
|
||||||
|
h += v
|
||||||
|
tl.store(out_ptr + pid, h)
|
||||||
|
|
||||||
|
|
||||||
|
def _as_uint32_words(t: torch.Tensor) -> torch.Tensor:
|
||||||
|
assert t.is_cuda, "Use .cuda() first"
|
||||||
|
tb = t.contiguous().view(torch.uint8)
|
||||||
|
nbytes = tb.numel()
|
||||||
|
pad = (4 - (nbytes & 3)) & 3
|
||||||
|
if pad:
|
||||||
|
tb_p = torch.empty(nbytes + pad, dtype=torch.uint8, device=tb.device)
|
||||||
|
tb_p[:nbytes].copy_(tb)
|
||||||
|
tb_p[nbytes:].zero_()
|
||||||
|
tb = tb_p
|
||||||
|
return tb.view(torch.uint32)
|
||||||
|
|
||||||
|
|
||||||
|
def _final_splitmix64(x: int) -> int:
|
||||||
|
mask = (1 << 64) - 1
|
||||||
|
x &= mask
|
||||||
|
x ^= x >> 30
|
||||||
|
x = (x * 0xBF58476D1CE4E5B9) & mask
|
||||||
|
x ^= x >> 27
|
||||||
|
x = (x * 0x94D049BB133111EB) & mask
|
||||||
|
x ^= x >> 31
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def gpu_tensor_hash(
|
||||||
|
tensor: torch.Tensor,
|
||||||
|
*,
|
||||||
|
seed: int = 0x243F6A88,
|
||||||
|
tile_words: int = 8192,
|
||||||
|
block_words: int = 256,
|
||||||
|
reduce_chunk: int = 1024,
|
||||||
|
num_warps: int = 4,
|
||||||
|
num_stages: int = 4,
|
||||||
|
use_cg: bool = True,
|
||||||
|
) -> int:
|
||||||
|
assert tensor.is_cuda, "Use .cuda() first"
|
||||||
|
u32 = _as_uint32_words(tensor)
|
||||||
|
n = u32.numel()
|
||||||
|
if n == 0:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
grid1 = (triton.cdiv(n, tile_words),)
|
||||||
|
partials = torch.empty(grid1[0], dtype=torch.uint64, device=u32.device)
|
||||||
|
hash_tiles32_kernel_blocked[grid1](
|
||||||
|
u32,
|
||||||
|
partials,
|
||||||
|
n,
|
||||||
|
seed1=seed & 0xFFFFFFFF,
|
||||||
|
seed2=((seed * 0x9E3779B1) ^ 0xDEADBEEF) & 0xFFFFFFFF,
|
||||||
|
FM_C1=FMIX32_C1,
|
||||||
|
FM_C2=FMIX32_C2,
|
||||||
|
POS_A=POS_C1,
|
||||||
|
POS_B=POS_C2,
|
||||||
|
TILE=tile_words,
|
||||||
|
BLOCK=block_words,
|
||||||
|
USE_CG=use_cg,
|
||||||
|
num_warps=num_warps,
|
||||||
|
num_stages=num_stages,
|
||||||
|
)
|
||||||
|
|
||||||
|
cur = partials
|
||||||
|
while cur.numel() > 1:
|
||||||
|
n_elems = cur.numel()
|
||||||
|
grid2 = (triton.cdiv(n_elems, reduce_chunk),)
|
||||||
|
nxt = torch.empty(grid2[0], dtype=torch.uint64, device=cur.device)
|
||||||
|
add_tree_reduce_u64_kernel[grid2](cur, nxt, n_elems, CHUNK=reduce_chunk)
|
||||||
|
cur = nxt
|
||||||
|
|
||||||
|
return _final_splitmix64(int(cur.item()))
|
||||||
|
|||||||
Reference in New Issue
Block a user