[Bugfix] qwen2vl forward_extend (#1727)
This commit is contained in:
@@ -35,7 +35,6 @@ from dataclasses import dataclass
|
|||||||
from enum import IntEnum, auto
|
from enum import IntEnum, auto
|
||||||
from typing import TYPE_CHECKING, List, Optional
|
from typing import TYPE_CHECKING, List, Optional
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
|
from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
|
||||||
@@ -134,16 +133,23 @@ class ForwardBatch:
|
|||||||
)
|
)
|
||||||
elif self.forward_mode.is_extend():
|
elif self.forward_mode.is_extend():
|
||||||
for i, image_inputs in enumerate(batch.image_inputs):
|
for i, image_inputs in enumerate(batch.image_inputs):
|
||||||
|
extend_start_loc, extend_seq_len, extend_prefix_len = (
|
||||||
|
self.extend_start_loc[i],
|
||||||
|
self.extend_seq_lens[i],
|
||||||
|
self.extend_prefix_lens[i],
|
||||||
|
)
|
||||||
if image_inputs is None:
|
if image_inputs is None:
|
||||||
# text only
|
# text only
|
||||||
mrope_positions = [[i for i in range(self.seq_lens[i])]] * 3
|
mrope_positions = [
|
||||||
|
[
|
||||||
|
pos
|
||||||
|
for pos in range(
|
||||||
|
extend_prefix_len, extend_prefix_len + extend_seq_len
|
||||||
|
)
|
||||||
|
]
|
||||||
|
] * 3
|
||||||
mrope_position_delta = 0
|
mrope_position_delta = 0
|
||||||
else:
|
else:
|
||||||
extend_start_loc, extend_seq_len, extend_prefix_len = (
|
|
||||||
self.extend_start_loc[i],
|
|
||||||
self.extend_seq_lens[i],
|
|
||||||
self.extend_prefix_lens[i],
|
|
||||||
)
|
|
||||||
mrope_positions, mrope_position_delta = (
|
mrope_positions, mrope_position_delta = (
|
||||||
MRotaryEmbedding.get_input_positions(
|
MRotaryEmbedding.get_input_positions(
|
||||||
input_tokens=self.input_ids[
|
input_tokens=self.input_ids[
|
||||||
@@ -163,12 +169,9 @@ class ForwardBatch:
|
|||||||
mrope_positions_list[i] = mrope_positions
|
mrope_positions_list[i] = mrope_positions
|
||||||
batch.mrope_positions_delta[i].append(mrope_position_delta)
|
batch.mrope_positions_delta[i].append(mrope_position_delta)
|
||||||
|
|
||||||
self.mrope_positions = torch.tensor(
|
self.mrope_positions = torch.concat(
|
||||||
np.concatenate(
|
[torch.tensor(pos, device=device) for pos in mrope_positions_list],
|
||||||
[np.array(pos) for pos in mrope_positions_list],
|
axis=1,
|
||||||
axis=1,
|
|
||||||
),
|
|
||||||
device=device,
|
|
||||||
)
|
)
|
||||||
self.mrope_positions = self.mrope_positions.to(torch.int64)
|
self.mrope_positions = self.mrope_positions.to(torch.int64)
|
||||||
|
|
||||||
@@ -177,18 +180,15 @@ class ForwardBatch:
|
|||||||
if self.forward_mode.is_decode():
|
if self.forward_mode.is_decode():
|
||||||
self.positions = (self.seq_lens - 1).to(torch.int64)
|
self.positions = (self.seq_lens - 1).to(torch.int64)
|
||||||
else:
|
else:
|
||||||
self.positions = torch.tensor(
|
self.positions = torch.concat(
|
||||||
np.concatenate(
|
[
|
||||||
[
|
torch.arange(prefix_len, prefix_len + extend_len, device=device)
|
||||||
np.arange(prefix_len, prefix_len + extend_len)
|
for prefix_len, extend_len in zip(
|
||||||
for prefix_len, extend_len in zip(
|
batch.extend_prefix_lens, batch.extend_seq_lens
|
||||||
batch.extend_prefix_lens, batch.extend_seq_lens
|
)
|
||||||
)
|
],
|
||||||
],
|
axis=0,
|
||||||
axis=0,
|
)
|
||||||
),
|
|
||||||
device=device,
|
|
||||||
).to(torch.int64)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def init_new(
|
def init_new(
|
||||||
@@ -213,15 +213,6 @@ class ForwardBatch:
|
|||||||
|
|
||||||
# Init position information
|
# Init position information
|
||||||
if not ret.forward_mode.is_decode():
|
if not ret.forward_mode.is_decode():
|
||||||
ret.positions = torch.concat(
|
|
||||||
[
|
|
||||||
torch.arange(prefix_len, prefix_len + extend_len, device=device)
|
|
||||||
for prefix_len, extend_len in zip(
|
|
||||||
batch.extend_prefix_lens, batch.extend_seq_lens
|
|
||||||
)
|
|
||||||
],
|
|
||||||
axis=0,
|
|
||||||
)
|
|
||||||
ret.image_inputs = batch.image_inputs
|
ret.image_inputs = batch.image_inputs
|
||||||
ret.extend_seq_lens = torch.tensor(
|
ret.extend_seq_lens = torch.tensor(
|
||||||
batch.extend_seq_lens, dtype=torch.int32
|
batch.extend_seq_lens, dtype=torch.int32
|
||||||
|
|||||||
@@ -362,10 +362,6 @@ class TestQWen2VLServer(TestOpenAIVisionServer):
|
|||||||
)
|
)
|
||||||
cls.base_url += "/v1"
|
cls.base_url += "/v1"
|
||||||
|
|
||||||
def test_mixed_batch(self):
|
|
||||||
# FIXME: Temporarily skip this test.
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user