From cec98f1034dcc77d1c0e2fd9e2363a6594499029 Mon Sep 17 00:00:00 2001 From: yhyang201 <47235274+yhyang201@users.noreply.github.com> Date: Fri, 9 May 2025 08:52:26 +0800 Subject: [PATCH] [Fix] Incorrect Memory Allocation on CUDA:0 by Non-Zero CUDA Processes in TP/DP (#5745) --- python/sglang/srt/configs/deepseekvl2.py | 3 +++ python/sglang/srt/configs/janus_pro.py | 3 +++ .../multimodal_processors/base_processor.py | 5 +++++ python/sglang/srt/managers/schedule_batch.py | 20 ++++++++++++++++++- 4 files changed, 30 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/configs/deepseekvl2.py b/python/sglang/srt/configs/deepseekvl2.py index 9d2f3de59..961995410 100644 --- a/python/sglang/srt/configs/deepseekvl2.py +++ b/python/sglang/srt/configs/deepseekvl2.py @@ -48,6 +48,9 @@ class DictOutput(object): def __getitem__(self, item): return self.__dict__[item] + def __contains__(self, key): + return key in self.__dict__ + def __setitem__(self, key, value): self.__dict__[key] = value diff --git a/python/sglang/srt/configs/janus_pro.py b/python/sglang/srt/configs/janus_pro.py index ad254edc4..b7bbfd71e 100644 --- a/python/sglang/srt/configs/janus_pro.py +++ b/python/sglang/srt/configs/janus_pro.py @@ -290,6 +290,9 @@ class DictOutput(object): def __getitem__(self, item): return self.__dict__[item] + def __contains__(self, key): + return key in self.__dict__ + def __setitem__(self, key, value): self.__dict__[key] = value diff --git a/python/sglang/srt/managers/multimodal_processors/base_processor.py b/python/sglang/srt/managers/multimodal_processors/base_processor.py index 72960d310..aafa63c3b 100644 --- a/python/sglang/srt/managers/multimodal_processors/base_processor.py +++ b/python/sglang/srt/managers/multimodal_processors/base_processor.py @@ -8,6 +8,7 @@ from typing import List, Optional import numpy as np import PIL +import torch from PIL import Image from transformers import BaseImageProcessorFast @@ -89,6 +90,10 @@ class BaseMultimodalProcessor(ABC): return_tensors="pt", **kwargs, ) + if "pixel_values" in result and isinstance( + result["pixel_values"], torch.Tensor + ): + result["pixel_values"] = result["pixel_values"].to("cpu") return result @abstractmethod diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 58d5637dd..ee9f40719 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -745,6 +745,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): out_cache_loc: torch.Tensor = None # shape: [b], int64 output_ids: torch.Tensor = None # shape: [b], int64 + # For multimodal inputs + multimodal_inputs: Optional[List] = None + # The sum of all sequence lengths seq_lens_sum: int = None @@ -1050,6 +1053,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): # Copy prefix and do some basic check input_embeds = [] extend_input_logprob_token_ids = [] + multimodal_inputs = [] for i, (req, seq_len, pre_len) in enumerate(zip(reqs, seq_lens, prefix_lens)): req.req_pool_idx = req_pool_indices[i] @@ -1065,6 +1069,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): # If req.input_embeds is already a list, append its content directly input_embeds.extend(req.input_embeds) # Use extend to avoid nesting + multimodal_inputs.append(req.multimodal_inputs) + req.cached_tokens += pre_len - req.already_computed req.already_computed = seq_len req.is_retracted = False @@ -1147,6 +1153,16 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): if input_embeds else None ) + for mm_input in multimodal_inputs: + if mm_input is None: + continue + for mm_item in mm_input.mm_items: + pixel_values = getattr(mm_item, "pixel_values", None) + if isinstance(pixel_values, torch.Tensor): + mm_item.pixel_values = pixel_values.to( + self.device, non_blocking=True + ) + self.multimodal_inputs = multimodal_inputs self.seq_lens_sum = sum(seq_lens) if self.return_logprob: @@ -1452,6 +1468,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices] self.reqs = [self.reqs[i] for i in keep_indices] + self.multimodal_inputs = [self.multimodal_inputs[i] for i in keep_indices] self.req_pool_indices = self.req_pool_indices[keep_indices_device] self.seq_lens = self.seq_lens[keep_indices_device] self.out_cache_loc = None @@ -1500,6 +1517,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums self.token_ids_logprobs = [None] * len(self.reqs) + other.token_ids_logprobs self.reqs.extend(other.reqs) + self.multimodal_inputs.extend(other.multimodal_inputs) self.return_logprob |= other.return_logprob self.has_stream |= other.has_stream @@ -1558,7 +1576,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): extend_seq_lens=extend_seq_lens, extend_prefix_lens=extend_prefix_lens, extend_logprob_start_lens=extend_logprob_start_lens, - multimodal_inputs=[r.multimodal_inputs for r in self.reqs], + multimodal_inputs=self.multimodal_inputs, encoder_cached=self.encoder_cached, encoder_lens=self.encoder_lens, encoder_lens_cpu=self.encoder_lens_cpu,