vlm: tensor hash kernel (#5974)
This commit is contained in:
70
python/sglang/srt/layers/multimodal.py
Normal file
70
python/sglang/srt/layers/multimodal.py
Normal file
@@ -0,0 +1,70 @@
|
||||
# Copyright 2023-2024 SGLang Team
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""GPU tensor hashing utilities for multimodal inputs."""
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
@triton.jit
def hash_kernel(
    input_ptr,
    output_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
    PRIME: tl.constexpr,
    XCONST: tl.constexpr,
):
    """Write one position-dependent hash per input element.

    Each program instance handles ``BLOCK_SIZE`` consecutive elements. For the
    element at index ``i`` the value is XOR-ed with ``i + XCONST``, multiplied
    by ``PRIME``, folded with a right shift by 16, multiplied by
    ``PRIME ^ XCONST`` and folded again with a right shift by 13. The result is
    stored at index ``i`` of ``output_ptr``.

    Args:
        input_ptr: Pointer to the elements to hash.
        output_ptr: Pointer receiving one hash per input element.
        n_elements: Number of valid elements; indices at or beyond it are
            masked out of both the load and the store.
        BLOCK_SIZE: Compile-time number of elements per program instance.
        PRIME: Compile-time multiplier used in the first mixing round.
        XCONST: Compile-time constant added to the element index and XOR-ed
            into the second-round multiplier.
    """
    # Global element indices covered by this program instance.
    idx = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # The last block may run past the end of the buffer; mask the tail.
    in_bounds = idx < n_elements

    values = tl.load(input_ptr + idx, mask=in_bounds, other=0)

    # Round 1: bind each value to its position, then multiply and fold.
    acc = (values ^ (idx + XCONST)) * PRIME
    acc = acc ^ (acc >> 16)
    # Round 2: multiply by a second constant derived from both inputs, fold.
    acc = acc * (PRIME ^ XCONST)
    acc = acc ^ (acc >> 13)

    tl.store(output_ptr + idx, acc, mask=in_bounds)
|
||||
|
||||
|
||||
# 64-bit unsigned multipliers re-expressed as the signed (two's-complement)
# int64 values with the same bit pattern, i.e. ``value - 2**64``.
# NOTE(review): these look like the xxHash64 primes -- confirm.
PRIME_1 = 11400714785074694791 - (1 << 64)
PRIME_2 = 14029467366897019727 - (1 << 64)
|
||||
|
||||
|
||||
def gpu_tensor_hash(tensor: torch.Tensor) -> int:
    """Hash the raw bytes of a CUDA tensor on the GPU.

    The tensor's memory is reinterpreted as 32-bit words, ``hash_kernel``
    computes one position-dependent hash per word, and the per-word hashes
    are summed into a single integer.

    Args:
        tensor: A CUDA tensor. Any shape is accepted, including 0-dim tensors
            and tensors whose byte size is not a multiple of four.

    Returns:
        The sum of the per-word hashes as a Python int; 0 for an empty tensor.

    Raises:
        AssertionError: If ``tensor`` is not on a CUDA device (the check is
            skipped when Python runs with ``-O``).
    """
    assert tensor.is_cuda, "gpu_tensor_hash requires a CUDA tensor"

    # Flatten to a 1-D byte view before widening to int32. Viewing the
    # original shape directly as int32 fails for 0-dim tensors and whenever
    # the last dimension is not a whole number of 4-byte words.
    raw = tensor.contiguous().reshape(-1).view(torch.uint8)
    n_bytes = raw.numel()
    if n_bytes == 0:
        # Nothing to hash; the sum over zero words is 0, so skip the launch.
        return 0

    pad = -n_bytes % 4
    if pad or raw.storage_offset() % 4:
        # Widening to int32 needs a byte count and a storage offset that are
        # multiples of four. Copy into a fresh zero-padded buffer only when
        # that does not hold, so already-valid inputs keep their exact hash.
        aligned = torch.zeros(n_bytes + pad, dtype=torch.uint8, device=raw.device)
        aligned[:n_bytes] = raw
        raw = aligned

    words = raw.view(torch.int32)
    n = words.numel()
    BLOCK_SIZE = 1024
    grid = (triton.cdiv(n, BLOCK_SIZE),)

    intermediate_hashes = torch.empty(n, dtype=torch.int32, device=words.device)

    hash_kernel[grid](
        words,
        intermediate_hashes,
        n,
        BLOCK_SIZE=BLOCK_SIZE,
        PRIME=PRIME_1,
        XCONST=PRIME_2,
    )

    # TODO: threads can't be synced across the Triton kernel, so the per-word
    # hashes are combined with a separate torch reduction.
    return intermediate_hashes.sum().item()
|
||||
@@ -49,6 +49,7 @@ from sglang.srt.configs.model_config import ModelConfig
|
||||
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
|
||||
from sglang.srt.disaggregation.base import BaseKVSender
|
||||
from sglang.srt.disaggregation.decode import ScheduleBatchDisaggregationDecodeMixin
|
||||
from sglang.srt.layers.multimodal import gpu_tensor_hash
|
||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||
from sglang.srt.mem_cache.chunk_cache import ChunkCache
|
||||
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
|
||||
@@ -222,7 +223,8 @@ class MultimodalDataItem:
|
||||
for x in tensor_list
|
||||
]
|
||||
tensor = torch.concat(tensor_list)
|
||||
|
||||
if tensor.is_cuda:
|
||||
return gpu_tensor_hash(tensor)
|
||||
tensor = tensor.detach().contiguous()
|
||||
|
||||
if tensor.dtype == torch.bfloat16:
|
||||
|
||||
Reference in New Issue
Block a user