71 lines
2.0 KiB
Python
71 lines
2.0 KiB
Python
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
|
|
"""Position-sensitive hashing of CUDA tensors using a Triton kernel."""
|
|
|
|
import torch
|
|
import triton
|
|
import triton.language as tl
|
|
|
|
|
|
@triton.jit
def hash_kernel(
    input_ptr,
    output_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
    PRIME: tl.constexpr,
    XCONST: tl.constexpr,
):
    """Write one mixed hash value per input element.

    Each program instance covers ``BLOCK_SIZE`` consecutive indices starting
    at ``program_id(0) * BLOCK_SIZE``. Indices at or beyond ``n_elements`` are
    masked out of both the load and the store.

    The element's index is folded into the hash, so the same value stored at
    two different positions produces two different outputs.

    Args:
        input_ptr: Pointer to the values to hash.
        output_ptr: Pointer that receives one hash per element.
        n_elements: Number of valid elements behind ``input_ptr``.
        BLOCK_SIZE: Compile-time number of elements per program instance.
        PRIME: Compile-time multiplier used in the first mixing step.
        XCONST: Compile-time constant added to the index and XOR-ed with
            ``PRIME`` to form the second multiplier.

    NOTE(review): the host wrapper in this file passes int32 buffers for both
    pointers; behaviour with other element types has not been checked here.
    """
    # Global element indices owned by this program instance.
    idx = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements

    # Masked-off lanes read 0; their results are never stored.
    word = tl.load(input_ptr + idx, mask=in_bounds, other=0)

    # Mix the value with its position, then alternate multiply / xor-shift.
    h = (word ^ (idx + XCONST)) * PRIME
    h ^= h >> 16
    h *= PRIME ^ XCONST
    h ^= h >> 13

    tl.store(output_ptr + idx, h, mask=in_bounds)
|
|
|
|
|
|
# Mixing constants for hash_kernel. Each is written as an unsigned 64-bit
# literal and shifted down by 2**64, which yields the negative Python int with
# the same 64-bit two's-complement bit pattern.
# NOTE(review): the unsigned literals look like the xxHash64 primes - confirm.
PRIME_1 = 11400714785074694791 - (1 << 64)
PRIME_2 = 14029467366897019727 - (1 << 64)
|
|
|
|
|
|
def gpu_tensor_hash(tensor: torch.Tensor) -> int:
    """Hash the raw contents of a CUDA tensor into a single Python number.

    The tensor is made contiguous and reinterpreted as 32-bit integers,
    ``hash_kernel`` writes one int32 hash per reinterpreted element, and those
    per-element hashes are then summed.

    Args:
        tensor: Tensor to hash. It must be a CUDA tensor, and its dtype and
            shape must be accepted by ``Tensor.view(torch.int32)``.

    Returns:
        The sum of all per-element hashes, as produced by ``.item()``.

    Raises:
        AssertionError: If ``tensor`` is not a CUDA tensor.
    """
    assert tensor.is_cuda

    words = tensor.contiguous().view(torch.int32)
    num_words = words.numel()
    block = 1024

    # One output slot per 32-bit input word.
    per_word_hashes = torch.empty(
        num_words, dtype=torch.int32, device=words.device
    )

    launch_grid = (triton.cdiv(num_words, block),)
    hash_kernel[launch_grid](
        words,
        per_word_hashes,
        num_words,
        BLOCK_SIZE=block,
        PRIME=PRIME_1,
        XCONST=PRIME_2,
    )

    # TODO: threads can't be synced on triton kernel, so the reduction over
    # the per-word hashes is done by a separate torch sum.
    return per_word_hashes.sum().item()
|