vlm: tensor hash kernel (#5974)
This commit is contained in:
@@ -49,6 +49,7 @@ from sglang.srt.configs.model_config import ModelConfig
|
||||
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
|
||||
from sglang.srt.disaggregation.base import BaseKVSender
|
||||
from sglang.srt.disaggregation.decode import ScheduleBatchDisaggregationDecodeMixin
|
||||
from sglang.srt.layers.multimodal import gpu_tensor_hash
|
||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||
from sglang.srt.mem_cache.chunk_cache import ChunkCache
|
||||
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
|
||||
@@ -222,7 +223,8 @@ class MultimodalDataItem:
|
||||
for x in tensor_list
|
||||
]
|
||||
tensor = torch.concat(tensor_list)
|
||||
|
||||
if tensor.is_cuda:
|
||||
return gpu_tensor_hash(tensor)
|
||||
tensor = tensor.detach().contiguous()
|
||||
|
||||
if tensor.dtype == torch.bfloat16:
|
||||
|
||||
Reference in New Issue
Block a user