diff --git a/python/sglang/srt/multimodal/processors/base_processor.py b/python/sglang/srt/multimodal/processors/base_processor.py index 3f62a14d1..c7df9265d 100644 --- a/python/sglang/srt/multimodal/processors/base_processor.py +++ b/python/sglang/srt/multimodal/processors/base_processor.py @@ -221,6 +221,13 @@ class BaseMultimodalProcessor(ABC): return_tensors="pt", **kwargs, ) + # move feature tensors to cpu + for feature_name in self.FEATURE_NAMES: + if feature_name in result and isinstance( + result[feature_name], torch.Tensor + ): + result[feature_name] = result[feature_name].to("cpu") + return result @abstractmethod @@ -623,19 +630,4 @@ class BaseMultimodalProcessor(ABC): mm_token_id=mm_token_id, ) - # post-process - for item in all_collected_items: - # replace the feature tensor with a proxy - if isinstance(item.feature, torch.Tensor) and item.feature.is_cuda: - item.feature = TransportProxyTensor( - transport_mode=self.transport_mode, data=item.feature - ) - elif ( - isinstance(item.precomputed_embeddings, torch.Tensor) - and item.precomputed_embeddings.is_cuda - ): - item.precomputed_embeddings = TransportProxyTensor( - transport_mode=self.transport_mode, data=item.precomputed_embeddings - ) - return all_collected_items, input_ids, ret