vlm: tensor hash kernel (#5974)

This commit is contained in:
Mick
2025-05-19 06:38:16 +08:00
committed by GitHub
parent 72bfb0baf0
commit 626ccb7d3f
2 changed files with 73 additions and 1 deletions

View File

@@ -49,6 +49,7 @@ from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
from sglang.srt.disaggregation.base import BaseKVSender
from sglang.srt.disaggregation.decode import ScheduleBatchDisaggregationDecodeMixin
from sglang.srt.layers.multimodal import gpu_tensor_hash
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
@@ -222,7 +223,8 @@ class MultimodalDataItem:
for x in tensor_list
]
tensor = torch.concat(tensor_list)
if tensor.is_cuda:
return gpu_tensor_hash(tensor)
tensor = tensor.detach().contiguous()
if tensor.dtype == torch.bfloat16: