[Fix] Incorrect Memory Allocation on CUDA:0 by Non-Zero CUDA Processes in TP/DP (#5745)

This commit is contained in:
yhyang201
2025-05-09 08:52:26 +08:00
committed by GitHub
parent 8dc4efd0ab
commit cec98f1034
4 changed files with 30 additions and 1 deletions

View File

@@ -48,6 +48,9 @@ class DictOutput(object):
def __getitem__(self, item):
    return self.__dict__[item]
def __contains__(self, key):
return key in self.__dict__
def __setitem__(self, key, value):
    self.__dict__[key] = value

View File

@@ -290,6 +290,9 @@ class DictOutput(object):
def __getitem__(self, item):
    return self.__dict__[item]
def __contains__(self, key):
return key in self.__dict__
def __setitem__(self, key, value):
    self.__dict__[key] = value

View File

@@ -8,6 +8,7 @@ from typing import List, Optional
import numpy as np
import PIL
import torch
from PIL import Image
from transformers import BaseImageProcessorFast
@@ -89,6 +90,10 @@ class BaseMultimodalProcessor(ABC):
return_tensors="pt",
**kwargs,
)
if "pixel_values" in result and isinstance(
result["pixel_values"], torch.Tensor
):
result["pixel_values"] = result["pixel_values"].to("cpu")
return result
@abstractmethod

View File

@@ -745,6 +745,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
out_cache_loc: torch.Tensor = None # shape: [b], int64
output_ids: torch.Tensor = None # shape: [b], int64
# For multimodal inputs
multimodal_inputs: Optional[List] = None
# The sum of all sequence lengths
seq_lens_sum: int = None
@@ -1050,6 +1053,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
# Copy prefix and do some basic check
input_embeds = []
extend_input_logprob_token_ids = []
multimodal_inputs = []
for i, (req, seq_len, pre_len) in enumerate(zip(reqs, seq_lens, prefix_lens)):
req.req_pool_idx = req_pool_indices[i]
@@ -1065,6 +1069,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
# If req.input_embeds is already a list, append its content directly
input_embeds.extend(req.input_embeds) # Use extend to avoid nesting
multimodal_inputs.append(req.multimodal_inputs)
req.cached_tokens += pre_len - req.already_computed
req.already_computed = seq_len
req.is_retracted = False
@@ -1147,6 +1153,16 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
if input_embeds
else None
)
for mm_input in multimodal_inputs:
if mm_input is None:
continue
for mm_item in mm_input.mm_items:
pixel_values = getattr(mm_item, "pixel_values", None)
if isinstance(pixel_values, torch.Tensor):
mm_item.pixel_values = pixel_values.to(
self.device, non_blocking=True
)
self.multimodal_inputs = multimodal_inputs
self.seq_lens_sum = sum(seq_lens)
if self.return_logprob:
@@ -1452,6 +1468,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices]
self.reqs = [self.reqs[i] for i in keep_indices]
self.multimodal_inputs = [self.multimodal_inputs[i] for i in keep_indices]
self.req_pool_indices = self.req_pool_indices[keep_indices_device]
self.seq_lens = self.seq_lens[keep_indices_device]
self.out_cache_loc = None
@@ -1500,6 +1517,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums
self.token_ids_logprobs = [None] * len(self.reqs) + other.token_ids_logprobs
self.reqs.extend(other.reqs)
self.multimodal_inputs.extend(other.multimodal_inputs)
self.return_logprob |= other.return_logprob
self.has_stream |= other.has_stream
@@ -1558,7 +1576,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
extend_seq_lens=extend_seq_lens,
extend_prefix_lens=extend_prefix_lens,
extend_logprob_start_lens=extend_logprob_start_lens,
-multimodal_inputs=[r.multimodal_inputs for r in self.reqs],
+multimodal_inputs=self.multimodal_inputs,
encoder_cached=self.encoder_cached,
encoder_lens=self.encoder_lens,
encoder_lens_cpu=self.encoder_lens_cpu,