vlm: tensor hash kernel (#5974)
This commit is contained in:
70
python/sglang/srt/layers/multimodal.py
Normal file
70
python/sglang/srt/layers/multimodal.py
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
# Copyright 2023-2024 SGLang Team
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Logits processing."""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import triton
|
||||||
|
import triton.language as tl
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
def hash_kernel(
    input_ptr,
    output_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
    PRIME: tl.constexpr,
    XCONST: tl.constexpr,
):
    """Per-element mixing hash over a flat integer buffer.

    Each program instance handles BLOCK_SIZE consecutive elements: every
    element is XOR-ed with its global offset (shifted by XCONST), then run
    through two multiply + xorshift rounds, and the result is stored at the
    same offset in ``output_ptr``.  No cross-element combining happens here;
    the caller reduces the per-element hashes afterwards.
    """
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Tail guard: lanes past n_elements load 0 and store nothing.
    mask = offsets < n_elements

    data = tl.load(input_ptr + offsets, mask=mask, other=0)
    # Fold the element's position into its value so that equal values at
    # different offsets produce different hashes.
    mixed = data ^ (offsets + XCONST)
    hash_val = mixed * PRIME
    hash_val = hash_val ^ (hash_val >> 16)
    hash_val = hash_val * (PRIME ^ XCONST)
    hash_val = hash_val ^ (hash_val >> 13)

    tl.store(output_ptr + offsets, hash_val, mask=mask)
|
||||||
|
|
||||||
|
|
||||||
|
# xxHash64 prime constants, reinterpreted as *signed* 64-bit integers
# (value - 2**64 is the two's-complement reading of the unsigned prime's
# bit pattern), so the kernel receives the canonical bits as a constexpr.
PRIME_1 = 11400714785074694791 - (1 << 64)
PRIME_2 = 14029467366897019727 - (1 << 64)
|
||||||
|
|
||||||
|
|
||||||
|
def gpu_tensor_hash(tensor: torch.Tensor) -> int:
    """Hash a CUDA tensor's contents into a single Python int.

    The tensor's raw storage is reinterpreted as int32 words, each word is
    hashed independently by ``hash_kernel``, and the per-word hashes are
    reduced with a sum on-device.  The sum is commutative, which sidesteps
    the fact that blocks of a single Triton launch cannot synchronize with
    each other.

    Args:
        tensor: A tensor on a CUDA device.  Its total byte size must be a
            multiple of 4 for the ``view(torch.int32)`` reinterpretation.
            NOTE(review): an odd-length int8/bfloat16 payload would fail
            here — confirm callers only pass 4-byte-divisible tensors.

    Returns:
        The combined hash value as a Python integer.

    Raises:
        ValueError: If ``tensor`` is not on a CUDA device.
    """
    # Raise instead of assert: input validation must survive `python -O`.
    if not tensor.is_cuda:
        raise ValueError("gpu_tensor_hash requires a CUDA tensor")

    # Reinterpret the underlying bytes as int32 words for the kernel.
    tensor = tensor.contiguous().view(torch.int32)
    n = tensor.numel()
    BLOCK_SIZE = 1024
    grid = (triton.cdiv(n, BLOCK_SIZE),)

    # One hash slot per int32 word; reduced below.
    intermediate_hashes = torch.empty(n, dtype=torch.int32, device=tensor.device)

    hash_kernel[grid](
        tensor,
        intermediate_hashes,
        n,
        BLOCK_SIZE=BLOCK_SIZE,
        PRIME=PRIME_1,
        XCONST=PRIME_2,
    )

    # Cross-block combining cannot happen inside the kernel (no global
    # sync in a Triton launch), so do the order-insensitive reduction here.
    return intermediate_hashes.sum().item()
|
||||||
@@ -49,6 +49,7 @@ from sglang.srt.configs.model_config import ModelConfig
|
|||||||
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
|
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
|
||||||
from sglang.srt.disaggregation.base import BaseKVSender
|
from sglang.srt.disaggregation.base import BaseKVSender
|
||||||
from sglang.srt.disaggregation.decode import ScheduleBatchDisaggregationDecodeMixin
|
from sglang.srt.disaggregation.decode import ScheduleBatchDisaggregationDecodeMixin
|
||||||
|
from sglang.srt.layers.multimodal import gpu_tensor_hash
|
||||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||||
from sglang.srt.mem_cache.chunk_cache import ChunkCache
|
from sglang.srt.mem_cache.chunk_cache import ChunkCache
|
||||||
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
|
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
|
||||||
@@ -222,7 +223,8 @@ class MultimodalDataItem:
|
|||||||
for x in tensor_list
|
for x in tensor_list
|
||||||
]
|
]
|
||||||
tensor = torch.concat(tensor_list)
|
tensor = torch.concat(tensor_list)
|
||||||
|
if tensor.is_cuda:
|
||||||
|
return gpu_tensor_hash(tensor)
|
||||||
tensor = tensor.detach().contiguous()
|
tensor = tensor.detach().contiguous()
|
||||||
|
|
||||||
if tensor.dtype == torch.bfloat16:
|
if tensor.dtype == torch.bfloat16:
|
||||||
|
|||||||
Reference in New Issue
Block a user