[Fix] Incorrect Memory Allocation on CUDA:0 by Non-Zero CUDA Processes in TP/DP (#5745)
This commit is contained in:
@@ -48,6 +48,9 @@ class DictOutput(object):
|
||||
def __getitem__(self, item):
|
||||
return self.__dict__[item]
|
||||
|
||||
def __contains__(self, key):
|
||||
return key in self.__dict__
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
self.__dict__[key] = value
|
||||
|
||||
|
||||
@@ -290,6 +290,9 @@ class DictOutput(object):
|
||||
def __getitem__(self, item):
|
||||
return self.__dict__[item]
|
||||
|
||||
def __contains__(self, key):
|
||||
return key in self.__dict__
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
self.__dict__[key] = value
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
import PIL
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import BaseImageProcessorFast
|
||||
|
||||
@@ -89,6 +90,10 @@ class BaseMultimodalProcessor(ABC):
|
||||
return_tensors="pt",
|
||||
**kwargs,
|
||||
)
|
||||
if "pixel_values" in result and isinstance(
|
||||
result["pixel_values"], torch.Tensor
|
||||
):
|
||||
result["pixel_values"] = result["pixel_values"].to("cpu")
|
||||
return result
|
||||
|
||||
@abstractmethod
|
||||
|
||||
@@ -745,6 +745,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
out_cache_loc: torch.Tensor = None # shape: [b], int64
|
||||
output_ids: torch.Tensor = None # shape: [b], int64
|
||||
|
||||
# For multimodal inputs
|
||||
multimodal_inputs: Optional[List] = None
|
||||
|
||||
# The sum of all sequence lengths
|
||||
seq_lens_sum: int = None
|
||||
|
||||
@@ -1050,6 +1053,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
# Copy prefix and do some basic check
|
||||
input_embeds = []
|
||||
extend_input_logprob_token_ids = []
|
||||
multimodal_inputs = []
|
||||
|
||||
for i, (req, seq_len, pre_len) in enumerate(zip(reqs, seq_lens, prefix_lens)):
|
||||
req.req_pool_idx = req_pool_indices[i]
|
||||
@@ -1065,6 +1069,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
# If req.input_embeds is already a list, append its content directly
|
||||
input_embeds.extend(req.input_embeds) # Use extend to avoid nesting
|
||||
|
||||
multimodal_inputs.append(req.multimodal_inputs)
|
||||
|
||||
req.cached_tokens += pre_len - req.already_computed
|
||||
req.already_computed = seq_len
|
||||
req.is_retracted = False
|
||||
@@ -1147,6 +1153,16 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
if input_embeds
|
||||
else None
|
||||
)
|
||||
for mm_input in multimodal_inputs:
|
||||
if mm_input is None:
|
||||
continue
|
||||
for mm_item in mm_input.mm_items:
|
||||
pixel_values = getattr(mm_item, "pixel_values", None)
|
||||
if isinstance(pixel_values, torch.Tensor):
|
||||
mm_item.pixel_values = pixel_values.to(
|
||||
self.device, non_blocking=True
|
||||
)
|
||||
self.multimodal_inputs = multimodal_inputs
|
||||
self.seq_lens_sum = sum(seq_lens)
|
||||
|
||||
if self.return_logprob:
|
||||
@@ -1452,6 +1468,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices]
|
||||
|
||||
self.reqs = [self.reqs[i] for i in keep_indices]
|
||||
self.multimodal_inputs = [self.multimodal_inputs[i] for i in keep_indices]
|
||||
self.req_pool_indices = self.req_pool_indices[keep_indices_device]
|
||||
self.seq_lens = self.seq_lens[keep_indices_device]
|
||||
self.out_cache_loc = None
|
||||
@@ -1500,6 +1517,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums
|
||||
self.token_ids_logprobs = [None] * len(self.reqs) + other.token_ids_logprobs
|
||||
self.reqs.extend(other.reqs)
|
||||
self.multimodal_inputs.extend(other.multimodal_inputs)
|
||||
|
||||
self.return_logprob |= other.return_logprob
|
||||
self.has_stream |= other.has_stream
|
||||
@@ -1558,7 +1576,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
extend_seq_lens=extend_seq_lens,
|
||||
extend_prefix_lens=extend_prefix_lens,
|
||||
extend_logprob_start_lens=extend_logprob_start_lens,
|
||||
multimodal_inputs=[r.multimodal_inputs for r in self.reqs],
|
||||
multimodal_inputs=self.multimodal_inputs,
|
||||
encoder_cached=self.encoder_cached,
|
||||
encoder_lens=self.encoder_lens,
|
||||
encoder_lens_cpu=self.encoder_lens_cpu,
|
||||
|
||||
Reference in New Issue
Block a user