fix: fix one more bug from merging mm_inputs (#5718)
Co-authored-by: Xinyuan Tong <justinning0323@outlook.com> Co-authored-by: XinyuanTong <115166877+JustinTong0323@users.noreply.github.com>
This commit is contained in:
@@ -1040,15 +1040,18 @@ class MRotaryEmbedding(RotaryEmbedding):
|
||||
mrope_position_delta: int,
|
||||
context_len: int,
|
||||
seq_len: int,
|
||||
) -> List[List[int]]:
|
||||
return [
|
||||
list(
|
||||
range(
|
||||
context_len + mrope_position_delta, seq_len + mrope_position_delta
|
||||
) -> torch.Tensor:
|
||||
return torch.tensor(
|
||||
[
|
||||
list(
|
||||
range(
|
||||
context_len + mrope_position_delta,
|
||||
seq_len + mrope_position_delta,
|
||||
)
|
||||
)
|
||||
)
|
||||
for _ in range(3)
|
||||
]
|
||||
for _ in range(3)
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
_ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}
|
||||
|
||||
@@ -351,7 +351,6 @@ class MultimodalInputs:
|
||||
optional_args = [
|
||||
"mm_items",
|
||||
"image_pad_len",
|
||||
"mrope_position_delta",
|
||||
]
|
||||
for arg in optional_args:
|
||||
self_arg = getattr(self, arg, None)
|
||||
@@ -367,6 +366,14 @@ class MultimodalInputs:
|
||||
[self.mrope_positions, other.mrope_positions], dim=1
|
||||
)
|
||||
|
||||
mrope_position_delta = self.mrope_position_delta
|
||||
if mrope_position_delta is not None:
|
||||
if other.mrope_position_delta is None:
|
||||
self.mrope_position_delta = mrope_position_delta
|
||||
else:
|
||||
self.mrope_position_delta = torch.cat(
|
||||
[self.mrope_position_delta, other.mrope_position_delta], dim=0
|
||||
)
|
||||
# other args would be kept intact
|
||||
|
||||
|
||||
@@ -1455,7 +1462,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
if self.model_config.is_encoder_decoder:
|
||||
self.encoder_lens = torch.cat([self.encoder_lens, other.encoder_lens])
|
||||
self.encoder_lens_cpu.extend(other.encoder_lens_cpu)
|
||||
|
||||
self.req_pool_indices = torch.cat(
|
||||
[self.req_pool_indices, other.req_pool_indices]
|
||||
)
|
||||
|
||||
@@ -38,7 +38,7 @@ import triton
|
||||
import triton.language as tl
|
||||
|
||||
from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
|
||||
from sglang.srt.utils import get_compiler_backend
|
||||
from sglang.srt.utils import flatten_nested_list, get_compiler_backend
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
||||
@@ -364,23 +364,23 @@ class ForwardBatch:
|
||||
|
||||
def merge_mm_inputs(self) -> Optional[MultimodalInputs]:
|
||||
"""
|
||||
Merge all image inputs in the batch into a single MultiModalInputs object.
|
||||
Merge all multimodal inputs in the batch into a single MultiModalInputs object.
|
||||
|
||||
Returns:
|
||||
if none, current batch contains no image input
|
||||
if none, current batch contains no multimodal input
|
||||
|
||||
"""
|
||||
if not self.mm_inputs or all(x is None for x in self.mm_inputs):
|
||||
return None
|
||||
|
||||
# Filter out None values
|
||||
valid_inputs = [x for x in self.mm_inputs if x is not None]
|
||||
|
||||
# Start with the first valid image input
|
||||
merged = valid_inputs[0]
|
||||
# TODO: is it expensive?
|
||||
# a workaround to avoid importing `MultimodalInputs`
|
||||
merged = valid_inputs[0].__class__(mm_items=[])
|
||||
|
||||
# Merge remaining inputs
|
||||
for mm_input in valid_inputs[1:]:
|
||||
for mm_input in valid_inputs:
|
||||
merged.merge(mm_input)
|
||||
|
||||
return merged
|
||||
@@ -407,26 +407,34 @@ class ForwardBatch:
|
||||
def _compute_mrope_positions(
|
||||
self, model_runner: ModelRunner, batch: ModelWorkerBatch
|
||||
):
|
||||
mrope_positions_list = [None] * self.seq_lens.shape[0]
|
||||
if self.forward_mode.is_decode():
|
||||
for i, _ in enumerate(mrope_positions_list):
|
||||
mrope_position_delta = (
|
||||
0
|
||||
if batch.multimodal_inputs[i] is None
|
||||
else batch.multimodal_inputs[i].mrope_position_delta
|
||||
# batch_size * [3 * seq_len]
|
||||
batch_size = self.seq_lens.shape[0]
|
||||
mrope_positions_list = [[]] * batch_size
|
||||
for batch_idx in range(batch_size):
|
||||
mm_input = batch.multimodal_inputs[batch_idx]
|
||||
if self.forward_mode.is_decode():
|
||||
mrope_position_deltas = (
|
||||
[0]
|
||||
if mm_input is None
|
||||
else flatten_nested_list(mm_input.mrope_position_delta.tolist())
|
||||
)
|
||||
mrope_positions_list[i] = torch.tensor(
|
||||
MRotaryEmbedding.get_next_input_positions(
|
||||
mrope_position_delta,
|
||||
int(self.seq_lens[i]) - 1,
|
||||
int(self.seq_lens[i]),
|
||||
)
|
||||
)
|
||||
elif self.forward_mode.is_extend():
|
||||
for i, mm_input in enumerate(batch.multimodal_inputs):
|
||||
next_input_positions = []
|
||||
for mrope_position_delta in mrope_position_deltas:
|
||||
# batched deltas needs to be processed separately
|
||||
# Convert list of lists to tensor with shape [3, seq_len]
|
||||
next_input_positions += [
|
||||
MRotaryEmbedding.get_next_input_positions(
|
||||
mrope_position_delta,
|
||||
int(self.seq_lens[batch_idx]) - 1,
|
||||
int(self.seq_lens[batch_idx]),
|
||||
)
|
||||
]
|
||||
# 3 * N
|
||||
mrope_positions_list[batch_idx] = torch.cat(next_input_positions, dim=1)
|
||||
elif self.forward_mode.is_extend():
|
||||
extend_seq_len, extend_prefix_len = (
|
||||
batch.extend_seq_lens[i],
|
||||
batch.extend_prefix_lens[i],
|
||||
batch.extend_seq_lens[batch_idx],
|
||||
batch.extend_prefix_lens[batch_idx],
|
||||
)
|
||||
if mm_input is None:
|
||||
# text only
|
||||
@@ -447,13 +455,12 @@ class ForwardBatch:
|
||||
:,
|
||||
extend_prefix_len : extend_prefix_len + extend_seq_len,
|
||||
]
|
||||
mrope_positions_list[i] = mrope_positions
|
||||
mrope_positions_list[batch_idx] = mrope_positions
|
||||
|
||||
self.mrope_positions = torch.cat(
|
||||
[pos.to(device=model_runner.device) for pos in mrope_positions_list],
|
||||
dim=1,
|
||||
).to(device=model_runner.device)
|
||||
self.mrope_positions = self.mrope_positions.to(torch.int64)
|
||||
).to(dtype=torch.int64, device=model_runner.device)
|
||||
|
||||
def get_max_chunk_capacity(self):
|
||||
# Maximum number of tokens in each chunk
|
||||
|
||||
Reference in New Issue
Block a user