diff --git a/python/sglang/srt/layers/multimodal.py b/python/sglang/srt/layers/multimodal.py
index 0ddd567c0..738c65830 100644
--- a/python/sglang/srt/layers/multimodal.py
+++ b/python/sglang/srt/layers/multimodal.py
@@ -17,57 +17,169 @@
 import torch
 import triton
 import triton.language as tl
 
+FMIX32_C1 = 0x85EBCA6B
+FMIX32_C2 = 0xC2B2AE35
+POS_C1 = 0x27D4EB2D
+POS_C2 = 0x165667B1
+
 
 @triton.jit
-def hash_kernel(
-    input_ptr,
-    output_ptr,
-    n_elements,
-    BLOCK_SIZE: tl.constexpr,
-    PRIME: tl.constexpr,
-    XCONST: tl.constexpr,
+def _rotl32(x, r: tl.constexpr):
+    return (x << r) | (x >> (32 - r))
+
+
+@triton.jit
+def _fmix32(x, C1: tl.constexpr, C2: tl.constexpr):
+    c1 = tl.full((), C1, tl.uint32)
+    c2 = tl.full((), C2, tl.uint32)
+    x ^= x >> 16
+    x = x * c1
+    x ^= x >> 13
+    x = x * c2
+    x ^= x >> 16
+    return x
+
+
+@triton.jit
+def hash_tiles32_kernel_blocked(
+    in_ptr,
+    out_ptr,
+    n_u32,
+    seed1,
+    seed2,
+    FM_C1: tl.constexpr,
+    FM_C2: tl.constexpr,
+    POS_A: tl.constexpr,
+    POS_B: tl.constexpr,
+    TILE: tl.constexpr,
+    BLOCK: tl.constexpr,
+    USE_CG: tl.constexpr,
 ):
     pid = tl.program_id(axis=0)
-    block_start = pid * BLOCK_SIZE
-    offsets = block_start + tl.arange(0, BLOCK_SIZE)
-    mask = offsets < n_elements
+    base = pid * TILE
 
-    data = tl.load(input_ptr + offsets, mask=mask, other=0).to(tl.int64)
-    mixed = data ^ (offsets.to(tl.int64) + XCONST)
-    hash_val = mixed * PRIME
-    hash_val = hash_val ^ (hash_val >> 16)
-    hash_val = hash_val * (PRIME ^ XCONST)
-    hash_val = hash_val ^ (hash_val >> 13)
+    s1 = tl.full((), seed1, tl.uint32)
+    s2 = tl.full((), seed2, tl.uint32)
+    posA = tl.full((), POS_A, tl.uint32)
+    posB = tl.full((), POS_B, tl.uint32)
 
-    tl.store(output_ptr + offsets, hash_val, mask=mask)
+    h1 = tl.zeros((), dtype=tl.uint32)
+    h2 = tl.zeros((), dtype=tl.uint32)
+
+    for off in tl.static_range(0, TILE, BLOCK):
+        idx = base + off + tl.arange(0, BLOCK)
+        m = idx < n_u32
+
+        if USE_CG:
+            v = tl.load(in_ptr + idx, mask=m, other=0, cache_modifier=".cg")
+        else:
+            v = tl.load(in_ptr + idx, mask=m, other=0)
+        v = v.to(tl.uint32)
+
+        iu = idx.to(tl.uint32)
+        p1 = (iu * posA + s1) ^ _rotl32(iu, 15)
+        p2 = (iu * posB + s2) ^ _rotl32(iu, 13)
+
+        k1 = _fmix32(v ^ p1, C1=FM_C1, C2=FM_C2)
+        k2 = _fmix32(v ^ p2, C1=FM_C1, C2=FM_C2)
+
+        zero32 = tl.zeros_like(k1)
+        k1 = tl.where(m, k1, zero32)
+        k2 = tl.where(m, k2, zero32)
+
+        h1 += tl.sum(k1, axis=0).to(tl.uint32)
+        h2 += tl.sum(k2, axis=0).to(tl.uint32)
+
+    nbytes = tl.full((), n_u32 * 4, tl.uint32)
+    h1 ^= nbytes
+    h2 ^= nbytes
+    h1 = _fmix32(h1, C1=FM_C1, C2=FM_C2)
+    h2 = _fmix32(h2, C1=FM_C1, C2=FM_C2)
+
+    out = (h1.to(tl.uint64) << 32) | h2.to(tl.uint64)
+    tl.store(out_ptr + pid, out)
 
 
-PRIME_1 = -(11400714785074694791 ^ 0xFFFFFFFFFFFFFFFF) - 1
-PRIME_2 = -(14029467366897019727 ^ 0xFFFFFFFFFFFFFFFF) - 1
+@triton.jit
+def add_tree_reduce_u64_kernel(in_ptr, out_ptr, n_elems, CHUNK: tl.constexpr):
+    pid = tl.program_id(axis=0)
+    start = pid * CHUNK
+    h = tl.zeros((), dtype=tl.uint64)
+    for i in tl.static_range(0, CHUNK):
+        idx = start + i
+        m = idx < n_elems
+        v = tl.load(in_ptr + idx, mask=m, other=0).to(tl.uint64)
+        h += v
+    tl.store(out_ptr + pid, h)
 
 
-def gpu_tensor_hash(tensor: torch.Tensor) -> int:
-    assert tensor.is_cuda
-    tensor = tensor.contiguous().view(torch.int32)
-    n = tensor.numel()
-    BLOCK_SIZE = 1024
-    grid = (triton.cdiv(n, BLOCK_SIZE),)
+def _as_uint32_words(t: torch.Tensor) -> torch.Tensor:
+    assert t.is_cuda, "Use .cuda() first"
+    tb = t.contiguous().view(torch.uint8)
+    nbytes = tb.numel()
+    pad = (4 - (nbytes & 3)) & 3
+    if pad:
+        tb_p = torch.empty(nbytes + pad, dtype=torch.uint8, device=tb.device)
+        tb_p[:nbytes].copy_(tb)
+        tb_p[nbytes:].zero_()
+        tb = tb_p
+    return tb.view(torch.uint32)
 
-    intermediate_hashes = torch.empty(n, dtype=torch.int64, device=tensor.device)
-    # Set cuda device to prevent ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
-    # Solution from Tri: https://github.com/Dao-AILab/flash-attention/issues/523#issuecomment-1707611579
-    with torch.cuda.device(tensor.device):
-        hash_kernel[grid](
-            tensor,
-            intermediate_hashes,
-            n,
-            BLOCK_SIZE=BLOCK_SIZE,
-            PRIME=PRIME_1,
-            XCONST=PRIME_2,
-        )
+def _final_splitmix64(x: int) -> int:
+    mask = (1 << 64) - 1
+    x &= mask
+    x ^= x >> 30
+    x = (x * 0xBF58476D1CE4E5B9) & mask
+    x ^= x >> 27
+    x = (x * 0x94D049BB133111EB) & mask
+    x ^= x >> 31
+    return x
 
-    # TODO: threads can't be synced on triton kernel
-    final_hash = intermediate_hashes.sum().item()
-    return final_hash
+
+@torch.inference_mode()
+def gpu_tensor_hash(
+    tensor: torch.Tensor,
+    *,
+    seed: int = 0x243F6A88,
+    tile_words: int = 8192,
+    block_words: int = 256,
+    reduce_chunk: int = 1024,
+    num_warps: int = 4,
+    num_stages: int = 4,
+    use_cg: bool = True,
+) -> int:
+    assert tensor.is_cuda, "Use .cuda() first"
+    u32 = _as_uint32_words(tensor)
+    n = u32.numel()
+    if n == 0:
+        return 0
+
+    grid1 = (triton.cdiv(n, tile_words),)
+    partials = torch.empty(grid1[0], dtype=torch.uint64, device=u32.device)
+    hash_tiles32_kernel_blocked[grid1](
+        u32,
+        partials,
+        n,
+        seed1=seed & 0xFFFFFFFF,
+        seed2=((seed * 0x9E3779B1) ^ 0xDEADBEEF) & 0xFFFFFFFF,
+        FM_C1=FMIX32_C1,
+        FM_C2=FMIX32_C2,
+        POS_A=POS_C1,
+        POS_B=POS_C2,
+        TILE=tile_words,
+        BLOCK=block_words,
+        USE_CG=use_cg,
+        num_warps=num_warps,
+        num_stages=num_stages,
+    )
+
+    cur = partials
+    while cur.numel() > 1:
+        n_elems = cur.numel()
+        grid2 = (triton.cdiv(n_elems, reduce_chunk),)
+        nxt = torch.empty(grid2[0], dtype=torch.uint64, device=cur.device)
+        add_tree_reduce_u64_kernel[grid2](cur, nxt, n_elems, CHUNK=reduce_chunk)
+        cur = nxt
+
+    return _final_splitmix64(int(cur.item()))