fix: fix one more bug from merging mm_inputs (#5718)
Co-authored-by: Xinyuan Tong <justinning0323@outlook.com>
Co-authored-by: XinyuanTong <115166877+JustinTong0323@users.noreply.github.com>
This commit is contained in:
@@ -1040,15 +1040,18 @@ class MRotaryEmbedding(RotaryEmbedding):
|
|||||||
mrope_position_delta: int,
|
mrope_position_delta: int,
|
||||||
context_len: int,
|
context_len: int,
|
||||||
seq_len: int,
|
seq_len: int,
|
||||||
) -> List[List[int]]:
|
) -> torch.Tensor:
|
||||||
return [
|
return torch.tensor(
|
||||||
|
[
|
||||||
list(
|
list(
|
||||||
range(
|
range(
|
||||||
context_len + mrope_position_delta, seq_len + mrope_position_delta
|
context_len + mrope_position_delta,
|
||||||
|
seq_len + mrope_position_delta,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
for _ in range(3)
|
for _ in range(3)
|
||||||
]
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
_ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}
|
_ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}
|
||||||
|
|||||||
@@ -351,7 +351,6 @@ class MultimodalInputs:
|
|||||||
optional_args = [
|
optional_args = [
|
||||||
"mm_items",
|
"mm_items",
|
||||||
"image_pad_len",
|
"image_pad_len",
|
||||||
"mrope_position_delta",
|
|
||||||
]
|
]
|
||||||
for arg in optional_args:
|
for arg in optional_args:
|
||||||
self_arg = getattr(self, arg, None)
|
self_arg = getattr(self, arg, None)
|
||||||
@@ -367,6 +366,14 @@ class MultimodalInputs:
|
|||||||
[self.mrope_positions, other.mrope_positions], dim=1
|
[self.mrope_positions, other.mrope_positions], dim=1
|
||||||
)
|
)
|
||||||
|
|
||||||
|
mrope_position_delta = self.mrope_position_delta
|
||||||
|
if mrope_position_delta is not None:
|
||||||
|
if other.mrope_position_delta is None:
|
||||||
|
self.mrope_position_delta = mrope_position_delta
|
||||||
|
else:
|
||||||
|
self.mrope_position_delta = torch.cat(
|
||||||
|
[self.mrope_position_delta, other.mrope_position_delta], dim=0
|
||||||
|
)
|
||||||
# other args would be kept intact
|
# other args would be kept intact
|
||||||
|
|
||||||
|
|
||||||
@@ -1455,7 +1462,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|||||||
if self.model_config.is_encoder_decoder:
|
if self.model_config.is_encoder_decoder:
|
||||||
self.encoder_lens = torch.cat([self.encoder_lens, other.encoder_lens])
|
self.encoder_lens = torch.cat([self.encoder_lens, other.encoder_lens])
|
||||||
self.encoder_lens_cpu.extend(other.encoder_lens_cpu)
|
self.encoder_lens_cpu.extend(other.encoder_lens_cpu)
|
||||||
|
|
||||||
self.req_pool_indices = torch.cat(
|
self.req_pool_indices = torch.cat(
|
||||||
[self.req_pool_indices, other.req_pool_indices]
|
[self.req_pool_indices, other.req_pool_indices]
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ import triton
|
|||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
|
|
||||||
from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
|
from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
|
||||||
from sglang.srt.utils import get_compiler_backend
|
from sglang.srt.utils import flatten_nested_list, get_compiler_backend
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
||||||
@@ -364,23 +364,23 @@ class ForwardBatch:
|
|||||||
|
|
||||||
def merge_mm_inputs(self) -> Optional[MultimodalInputs]:
|
def merge_mm_inputs(self) -> Optional[MultimodalInputs]:
|
||||||
"""
|
"""
|
||||||
Merge all image inputs in the batch into a single MultiModalInputs object.
|
Merge all multimodal inputs in the batch into a single MultiModalInputs object.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
if none, current batch contains no image input
|
if none, current batch contains no multimodal input
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if not self.mm_inputs or all(x is None for x in self.mm_inputs):
|
if not self.mm_inputs or all(x is None for x in self.mm_inputs):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Filter out None values
|
# Filter out None values
|
||||||
valid_inputs = [x for x in self.mm_inputs if x is not None]
|
valid_inputs = [x for x in self.mm_inputs if x is not None]
|
||||||
|
|
||||||
# Start with the first valid image input
|
# TODO: is it expensive?
|
||||||
merged = valid_inputs[0]
|
# a workaround to avoid importing `MultimodalInputs`
|
||||||
|
merged = valid_inputs[0].__class__(mm_items=[])
|
||||||
|
|
||||||
# Merge remaining inputs
|
# Merge remaining inputs
|
||||||
for mm_input in valid_inputs[1:]:
|
for mm_input in valid_inputs:
|
||||||
merged.merge(mm_input)
|
merged.merge(mm_input)
|
||||||
|
|
||||||
return merged
|
return merged
|
||||||
@@ -407,26 +407,34 @@ class ForwardBatch:
|
|||||||
def _compute_mrope_positions(
|
def _compute_mrope_positions(
|
||||||
self, model_runner: ModelRunner, batch: ModelWorkerBatch
|
self, model_runner: ModelRunner, batch: ModelWorkerBatch
|
||||||
):
|
):
|
||||||
mrope_positions_list = [None] * self.seq_lens.shape[0]
|
# batch_size * [3 * seq_len]
|
||||||
|
batch_size = self.seq_lens.shape[0]
|
||||||
|
mrope_positions_list = [[]] * batch_size
|
||||||
|
for batch_idx in range(batch_size):
|
||||||
|
mm_input = batch.multimodal_inputs[batch_idx]
|
||||||
if self.forward_mode.is_decode():
|
if self.forward_mode.is_decode():
|
||||||
for i, _ in enumerate(mrope_positions_list):
|
mrope_position_deltas = (
|
||||||
mrope_position_delta = (
|
[0]
|
||||||
0
|
if mm_input is None
|
||||||
if batch.multimodal_inputs[i] is None
|
else flatten_nested_list(mm_input.mrope_position_delta.tolist())
|
||||||
else batch.multimodal_inputs[i].mrope_position_delta
|
|
||||||
)
|
)
|
||||||
mrope_positions_list[i] = torch.tensor(
|
next_input_positions = []
|
||||||
|
for mrope_position_delta in mrope_position_deltas:
|
||||||
|
# batched deltas need to be processed separately
|
||||||
|
# Convert list of lists to tensor with shape [3, seq_len]
|
||||||
|
next_input_positions += [
|
||||||
MRotaryEmbedding.get_next_input_positions(
|
MRotaryEmbedding.get_next_input_positions(
|
||||||
mrope_position_delta,
|
mrope_position_delta,
|
||||||
int(self.seq_lens[i]) - 1,
|
int(self.seq_lens[batch_idx]) - 1,
|
||||||
int(self.seq_lens[i]),
|
int(self.seq_lens[batch_idx]),
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
]
|
||||||
|
# 3 * N
|
||||||
|
mrope_positions_list[batch_idx] = torch.cat(next_input_positions, dim=1)
|
||||||
elif self.forward_mode.is_extend():
|
elif self.forward_mode.is_extend():
|
||||||
for i, mm_input in enumerate(batch.multimodal_inputs):
|
|
||||||
extend_seq_len, extend_prefix_len = (
|
extend_seq_len, extend_prefix_len = (
|
||||||
batch.extend_seq_lens[i],
|
batch.extend_seq_lens[batch_idx],
|
||||||
batch.extend_prefix_lens[i],
|
batch.extend_prefix_lens[batch_idx],
|
||||||
)
|
)
|
||||||
if mm_input is None:
|
if mm_input is None:
|
||||||
# text only
|
# text only
|
||||||
@@ -447,13 +455,12 @@ class ForwardBatch:
|
|||||||
:,
|
:,
|
||||||
extend_prefix_len : extend_prefix_len + extend_seq_len,
|
extend_prefix_len : extend_prefix_len + extend_seq_len,
|
||||||
]
|
]
|
||||||
mrope_positions_list[i] = mrope_positions
|
mrope_positions_list[batch_idx] = mrope_positions
|
||||||
|
|
||||||
self.mrope_positions = torch.cat(
|
self.mrope_positions = torch.cat(
|
||||||
[pos.to(device=model_runner.device) for pos in mrope_positions_list],
|
[pos.to(device=model_runner.device) for pos in mrope_positions_list],
|
||||||
dim=1,
|
dim=1,
|
||||||
).to(device=model_runner.device)
|
).to(dtype=torch.int64, device=model_runner.device)
|
||||||
self.mrope_positions = self.mrope_positions.to(torch.int64)
|
|
||||||
|
|
||||||
def get_max_chunk_capacity(self):
|
def get_max_chunk_capacity(self):
|
||||||
# Maximum number of tokens in each chunk
|
# Maximum number of tokens in each chunk
|
||||||
|
|||||||
@@ -307,7 +307,6 @@ class TestOpenAIVisionServer(CustomTestCase):
|
|||||||
self.assertGreater(len(video_response), 0)
|
self.assertGreater(len(video_response), 0)
|
||||||
|
|
||||||
def test_regex(self):
|
def test_regex(self):
|
||||||
return
|
|
||||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||||
|
|
||||||
regex = (
|
regex = (
|
||||||
|
|||||||
Reference in New Issue
Block a user