[Bugfix] qwen2vl forward_extend (#1727)
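In extend mode, the text-only M-RoPE path computed positions as [[i for i in range(self.seq_lens[i])]] * 3, which shadows the loop index i and always counts from 0, ignoring any cached prefix. This change hoists the per-request extend offsets out of the image branch, derives text-only positions from range(extend_prefix_len, extend_prefix_len + extend_seq_len), and builds both mrope_positions and positions with torch directly instead of round-tripping through numpy (the numpy import is dropped).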
@@ -35,7 +35,6 @@ from dataclasses import dataclass
 from enum import IntEnum, auto
 from typing import TYPE_CHECKING, List, Optional

-import numpy as np
 import torch

 from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
@@ -134,16 +133,23 @@ class ForwardBatch:
             )
         elif self.forward_mode.is_extend():
             for i, image_inputs in enumerate(batch.image_inputs):
+                extend_start_loc, extend_seq_len, extend_prefix_len = (
+                    self.extend_start_loc[i],
+                    self.extend_seq_lens[i],
+                    self.extend_prefix_lens[i],
+                )
                 if image_inputs is None:
                     # text only
-                    mrope_positions = [[i for i in range(self.seq_lens[i])]] * 3
+                    mrope_positions = [
+                        [
+                            pos
+                            for pos in range(
+                                extend_prefix_len, extend_prefix_len + extend_seq_len
+                            )
+                        ]
+                    ] * 3
                     mrope_position_delta = 0
                 else:
-                    extend_start_loc, extend_seq_len, extend_prefix_len = (
-                        self.extend_start_loc[i],
-                        self.extend_seq_lens[i],
-                        self.extend_prefix_lens[i],
-                    )
                     mrope_positions, mrope_position_delta = (
                         MRotaryEmbedding.get_input_positions(
                             input_tokens=self.input_ids[
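Why the text-only branch changed: the old one-liner reused the loop index i inside its own comprehension and always counted from 0, so an extend pass over a partially cached sequence assigned positions that overlapped the prefix. A minimal standalone sketch of the fixed computation, with illustrative lengths (the variable names mirror the fields above; none of the surrounding class is needed):

# Sketch of the fixed text-only M-RoPE position computation.
# A request with 4 tokens already in the KV cache and 3 new tokens
# to process in this extend pass; values are illustrative, not from the PR.
extend_prefix_len = 4   # tokens already cached
extend_seq_len = 3      # new tokens in this extend pass

# Text-only tokens share one position across all three M-RoPE channels
# (temporal, height, width), offset by the cached prefix length.
mrope_positions = [
    list(range(extend_prefix_len, extend_prefix_len + extend_seq_len))
] * 3
print(mrope_positions)  # [[4, 5, 6], [4, 5, 6], [4, 5, 6]]

The old version would have produced [[0, 1, 2], ...] here regardless of how much of the sequence was cached.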
@@ -163,12 +169,9 @@ class ForwardBatch:
                 mrope_positions_list[i] = mrope_positions
                 batch.mrope_positions_delta[i].append(mrope_position_delta)

-        self.mrope_positions = torch.tensor(
-            np.concatenate(
-                [np.array(pos) for pos in mrope_positions_list],
-                axis=1,
-            ),
-            device=device,
+        self.mrope_positions = torch.concat(
+            [torch.tensor(pos, device=device) for pos in mrope_positions_list],
+            axis=1,
         )
         self.mrope_positions = self.mrope_positions.to(torch.int64)

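After the loop, the per-request 3×L lists are stitched into one (3, total_tokens) tensor. The rewrite drops the numpy round-trip and allocates on the target device up front. A standalone sketch, assuming two hypothetical requests and using device = "cpu" as a stand-in for the model runner's device:

import torch

device = "cpu"  # stand-in for the model runner's device

# Two requests: one with 3 new tokens at positions 4..6,
# one with 2 new tokens at positions 0..1.
mrope_positions_list = [
    [[4, 5, 6], [4, 5, 6], [4, 5, 6]],
    [[0, 1], [0, 1], [0, 1]],
]

# Concatenate along the token axis (axis=1) into a (3, total_tokens)
# tensor, the per-channel layout M-RoPE expects.
mrope_positions = torch.concat(
    [torch.tensor(pos, device=device) for pos in mrope_positions_list],
    axis=1,
).to(torch.int64)
print(mrope_positions.shape)  # torch.Size([3, 5])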
@@ -177,18 +180,15 @@ class ForwardBatch:
         if self.forward_mode.is_decode():
             self.positions = (self.seq_lens - 1).to(torch.int64)
         else:
-            self.positions = torch.tensor(
-                np.concatenate(
-                    [
-                        np.arange(prefix_len, prefix_len + extend_len)
-                        for prefix_len, extend_len in zip(
-                            batch.extend_prefix_lens, batch.extend_seq_lens
-                        )
-                    ],
-                    axis=0,
-                ),
-                device=device,
-            ).to(torch.int64)
+            self.positions = torch.concat(
+                [
+                    torch.arange(prefix_len, prefix_len + extend_len, device=device)
+                    for prefix_len, extend_len in zip(
+                        batch.extend_prefix_lens, batch.extend_seq_lens
+                    )
+                ],
+                axis=0,
+            )

     @classmethod
     def init_new(
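The positions computation gets the same numpy-to-torch treatment: each extend segment's positions are produced by torch.arange directly on the device and concatenated across the batch, instead of being assembled host-side with np.arange/np.concatenate and copied over (torch.arange over Python ints already yields int64, so the trailing .to(torch.int64) is no longer needed). A standalone sketch with illustrative lengths:

import torch

device = "cpu"  # stand-in for the model runner's device

extend_prefix_lens = [4, 0]  # cached prefix per request (illustrative)
extend_seq_lens = [3, 2]     # new tokens per request (illustrative)

# One contiguous position range per request, concatenated over the batch.
positions = torch.concat(
    [
        torch.arange(prefix_len, prefix_len + extend_len, device=device)
        for prefix_len, extend_len in zip(extend_prefix_lens, extend_seq_lens)
    ],
    axis=0,
)
print(positions)  # tensor([4, 5, 6, 0, 1])

With this method owning the logic, the duplicate copy in init_new becomes dead code and is removed in the next hunk.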
@@ -213,15 +213,6 @@ class ForwardBatch:

         # Init position information
         if not ret.forward_mode.is_decode():
-            ret.positions = torch.concat(
-                [
-                    torch.arange(prefix_len, prefix_len + extend_len, device=device)
-                    for prefix_len, extend_len in zip(
-                        batch.extend_prefix_lens, batch.extend_seq_lens
-                    )
-                ],
-                axis=0,
-            )
             ret.image_inputs = batch.image_inputs
             ret.extend_seq_lens = torch.tensor(
                 batch.extend_seq_lens, dtype=torch.int32