diff --git a/python/sglang/srt/layers/multimodal.py b/python/sglang/srt/layers/multimodal.py
new file mode 100644
index 000000000..ad313a31f
--- /dev/null
+++ b/python/sglang/srt/layers/multimodal.py
@@ -0,0 +1,70 @@
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""GPU tensor hashing (Triton) for fingerprinting multimodal feature tensors."""
+
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def hash_kernel(
+    input_ptr,
+    output_ptr,
+    n_elements,
+    BLOCK_SIZE: tl.constexpr,
+    PRIME: tl.constexpr,
+    XCONST: tl.constexpr,
+):
+    pid = tl.program_id(axis=0)  # one program per BLOCK_SIZE-wide slice
+    block_start = pid * BLOCK_SIZE
+    offsets = block_start + tl.arange(0, BLOCK_SIZE)
+    mask = offsets < n_elements
+
+    data = tl.load(input_ptr + offsets, mask=mask, other=0)
+    mixed = data ^ (offsets + XCONST)  # position-dependent mix so order matters
+    hash_val = mixed * PRIME  # xxHash-style multiply / xor-shift avalanche
+    hash_val = hash_val ^ (hash_val >> 16)
+    hash_val = hash_val * (PRIME ^ XCONST)
+    hash_val = hash_val ^ (hash_val >> 13)
+
+    tl.store(output_ptr + offsets, hash_val, mask=mask)
+
+
+PRIME_1 = -(11400714785074694791 ^ 0xFFFFFFFFFFFFFFFF) - 1  # xxHash64 prime 0x9E3779B185EBCA87 as signed int64
+PRIME_2 = -(14029467366897019727 ^ 0xFFFFFFFFFFFFFFFF) - 1  # xxHash64 prime 0xC2B2AE3D27D4EB4F as signed int64
+
+
+def gpu_tensor_hash(tensor: torch.Tensor) -> int:
+    assert tensor.is_cuda
+    tensor = tensor.contiguous().view(torch.int32)  # reinterpret raw bytes as int32 words; byte size must be a multiple of 4
+    n = tensor.numel()
+    BLOCK_SIZE = 1024
+    grid = (triton.cdiv(n, BLOCK_SIZE),)
+
+    intermediate_hashes = torch.empty(n, dtype=torch.int64, device=tensor.device)  # int64: kernel mixes with 64-bit primes; int32 would truncate
+
+    hash_kernel[grid](
+        tensor,
+        intermediate_hashes,
+        n,
+        BLOCK_SIZE=BLOCK_SIZE,
+        PRIME=PRIME_1,
+        XCONST=PRIME_2,
+    )
+
+    # Triton offers no grid-wide sync, so combine the per-element hashes on the host.
+    final_hash = intermediate_hashes.sum().item()
+
+    return final_hash
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index 836335136..10f91ed20 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -49,6 +49,7 @@ from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
 from sglang.srt.disaggregation.base import BaseKVSender
 from sglang.srt.disaggregation.decode import ScheduleBatchDisaggregationDecodeMixin
+from sglang.srt.layers.multimodal import gpu_tensor_hash
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
@@ -222,7 +223,8 @@ class MultimodalDataItem:
                 for x in tensor_list
             ]
             tensor = torch.concat(tensor_list)
-
+        if tensor.is_cuda:
+            return gpu_tensor_hash(tensor)
         tensor = tensor.detach().contiguous()
 
         if tensor.dtype == torch.bfloat16: