[Fix] Incorrect Memory Allocation on CUDA:0 by Non-Zero CUDA Processes in TP/DP (#5745)
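Background (a sketch, not part of the patch): in PyTorch, moving a tensor to "cuda" without an explicit index targets the current device, which is cuda:0 unless the process has called torch.cuda.set_device(). If multimodal processing hands GPU tensors to every TP/DP rank, each non-zero rank ends up allocating memory on cuda:0. The patch below keeps pixel_values on CPU during processing and lets each rank move them onto its own device when the batch is built. A minimal illustration of the default-device behavior:

import torch

# "cuda" without an index resolves to the current device, which is cuda:0
# unless this process has pinned another device first.
if torch.cuda.is_available():
    x = torch.randn(4)
    print(x.to("cuda").device)  # cuda:0 by default

    # A rank-local process should pin its own device, e.g. on TP rank 1:
    # torch.cuda.set_device(1)
    # x.to("cuda") would then land on cuda:1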
@@ -48,6 +48,9 @@ class DictOutput(object):
     def __getitem__(self, item):
         return self.__dict__[item]
 
+    def __contains__(self, key):
+        return key in self.__dict__
+
     def __setitem__(self, key, value):
         self.__dict__[key] = value
 
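The __contains__ added above is what makes membership checks like "pixel_values" in result work on this attribute-backed wrapper; without it, Python's in operator falls back to indexed __getitem__ calls and raises KeyError instead of answering the check. A small sketch of the difference, using a throwaway copy of the class:

class DictOutput(object):
    def __getitem__(self, item):
        return self.__dict__[item]

    def __setitem__(self, key, value):
        self.__dict__[key] = value


out = DictOutput()
out["pixel_values"] = object()

# Without __contains__, `in` probes out[0], out[1], ... and raises KeyError: 0.
try:
    "pixel_values" in out
except KeyError as e:
    print("membership check failed:", e)

# Adding __contains__ (as in the patch) makes it behave like a dict lookup.
DictOutput.__contains__ = lambda self, key: key in self.__dict__
print("pixel_values" in out)  # True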
@@ -290,6 +290,9 @@ class DictOutput(object):
     def __getitem__(self, item):
         return self.__dict__[item]
 
+    def __contains__(self, key):
+        return key in self.__dict__
+
     def __setitem__(self, key, value):
         self.__dict__[key] = value
 
@@ -8,6 +8,7 @@ from typing import List, Optional
 
 import numpy as np
 import PIL
+import torch
 from PIL import Image
 from transformers import BaseImageProcessorFast
 
@@ -89,6 +90,10 @@ class BaseMultimodalProcessor(ABC):
             return_tensors="pt",
             **kwargs,
         )
+        if "pixel_values" in result and isinstance(
+            result["pixel_values"], torch.Tensor
+        ):
+            result["pixel_values"] = result["pixel_values"].to("cpu")
         return result
 
     @abstractmethod
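The guard above pins pixel_values to CPU right after the HuggingFace processor runs, so image tensors never reach a GPU on the processing path; only the scheduler rank that owns the request moves them later (see the ScheduleBatch changes below). A standalone sketch of the guard, with the result dict as a hypothetical stand-in for a processor output:

import torch

# Hypothetical processor output, for illustration only.
result = {
    "input_ids": torch.tensor([[1, 2, 3]]),
    "pixel_values": torch.randn(1, 3, 224, 224),
}

# Same shape as the patch: only torch.Tensor values are touched; any other
# representation (e.g. a list of per-image tensors) is left as-is.
if "pixel_values" in result and isinstance(result["pixel_values"], torch.Tensor):
    result["pixel_values"] = result["pixel_values"].to("cpu")

assert result["pixel_values"].device.type == "cpu"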
@@ -745,6 +745,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     out_cache_loc: torch.Tensor = None  # shape: [b], int64
     output_ids: torch.Tensor = None  # shape: [b], int64
 
+    # For multimodal inputs
+    multimodal_inputs: Optional[List] = None
+
     # The sum of all sequence lengths
     seq_lens_sum: int = None
 
@@ -1050,6 +1053,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         # Copy prefix and do some basic check
         input_embeds = []
         extend_input_logprob_token_ids = []
+        multimodal_inputs = []
 
         for i, (req, seq_len, pre_len) in enumerate(zip(reqs, seq_lens, prefix_lens)):
             req.req_pool_idx = req_pool_indices[i]
@@ -1065,6 +1069,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
                 # If req.input_embeds is already a list, append its content directly
                 input_embeds.extend(req.input_embeds)  # Use extend to avoid nesting
 
+            multimodal_inputs.append(req.multimodal_inputs)
+
             req.cached_tokens += pre_len - req.already_computed
             req.already_computed = seq_len
             req.is_retracted = False
@@ -1147,6 +1153,16 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             if input_embeds
             else None
         )
+        for mm_input in multimodal_inputs:
+            if mm_input is None:
+                continue
+            for mm_item in mm_input.mm_items:
+                pixel_values = getattr(mm_item, "pixel_values", None)
+                if isinstance(pixel_values, torch.Tensor):
+                    mm_item.pixel_values = pixel_values.to(
+                        self.device, non_blocking=True
+                    )
+        self.multimodal_inputs = multimodal_inputs
         self.seq_lens_sum = sum(seq_lens)
 
         if self.return_logprob:
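With the batch now owning multimodal_inputs, each rank moves pixel_values onto its own self.device (for example cuda:1 on TP rank 1) instead of whatever device the processor happened to use. A self-contained sketch of that loop, with MMItem and MMInput as hypothetical stand-ins for the real multimodal input classes:

from dataclasses import dataclass, field
from typing import List, Optional

import torch


@dataclass
class MMItem:  # hypothetical stand-in for a single multimodal item
    pixel_values: Optional[torch.Tensor] = None


@dataclass
class MMInput:  # hypothetical stand-in for a request's multimodal inputs
    mm_items: List[MMItem] = field(default_factory=list)


def move_mm_inputs_to_device(multimodal_inputs, device):
    # Move every tensor-typed pixel_values onto the rank-local device.
    for mm_input in multimodal_inputs:
        if mm_input is None:  # text-only requests carry no multimodal input
            continue
        for mm_item in mm_input.mm_items:
            pixel_values = getattr(mm_item, "pixel_values", None)
            if isinstance(pixel_values, torch.Tensor):
                mm_item.pixel_values = pixel_values.to(device, non_blocking=True)


# Usage: a real TP rank passes its own device ("cuda:1" on rank 1, etc.).
inputs = [MMInput([MMItem(torch.randn(1, 3, 224, 224))]), None]
move_mm_inputs_to_device(inputs, "cpu")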
@@ -1452,6 +1468,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices]
 
         self.reqs = [self.reqs[i] for i in keep_indices]
+        self.multimodal_inputs = [self.multimodal_inputs[i] for i in keep_indices]
         self.req_pool_indices = self.req_pool_indices[keep_indices_device]
         self.seq_lens = self.seq_lens[keep_indices_device]
         self.out_cache_loc = None
@@ -1500,6 +1517,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums
         self.token_ids_logprobs = [None] * len(self.reqs) + other.token_ids_logprobs
         self.reqs.extend(other.reqs)
+        self.multimodal_inputs.extend(other.multimodal_inputs)
 
         self.return_logprob |= other.return_logprob
         self.has_stream |= other.has_stream
@@ -1558,7 +1576,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             extend_seq_lens=extend_seq_lens,
             extend_prefix_lens=extend_prefix_lens,
             extend_logprob_start_lens=extend_logprob_start_lens,
-            multimodal_inputs=[r.multimodal_inputs for r in self.reqs],
+            multimodal_inputs=self.multimodal_inputs,
             encoder_cached=self.encoder_cached,
             encoder_lens=self.encoder_lens,
             encoder_lens_cpu=self.encoder_lens_cpu,