bugfix for qwen2_vl (#301)
### What this PR does / why we need it? This PR fixes an error encountered while running inference with Qwen2-VL. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? We tested it on a benchmark; the results meet our expectations and match the GPU output. --------- Signed-off-by: zouyida <zouyida@huawei.com>
This commit is contained in:
@@ -40,6 +40,23 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
|
|||||||
|
|
||||||
class CustomQwen2VisionAttention(Qwen2VisionAttention):
|
class CustomQwen2VisionAttention(Qwen2VisionAttention):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
embed_dim: int,
|
||||||
|
num_heads: int,
|
||||||
|
projection_size: int,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
|
) -> None:
|
||||||
|
super().__init__(
|
||||||
|
embed_dim,
|
||||||
|
num_heads,
|
||||||
|
projection_size,
|
||||||
|
quant_config,
|
||||||
|
prefix,
|
||||||
|
)
|
||||||
|
self.cu_seqlens = None
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
@@ -47,6 +64,8 @@ class CustomQwen2VisionAttention(Qwen2VisionAttention):
|
|||||||
rotary_pos_emb: torch.Tensor,
|
rotary_pos_emb: torch.Tensor,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
|
||||||
|
self.cu_seqlens = cu_seqlens
|
||||||
|
|
||||||
# [s, b, c] --> [s, b, 3 * head * head_dim]
|
# [s, b, c] --> [s, b, 3 * head * head_dim]
|
||||||
x, _ = self.qkv(x)
|
x, _ = self.qkv(x)
|
||||||
|
|
||||||
@@ -72,7 +91,7 @@ class CustomQwen2VisionAttention(Qwen2VisionAttention):
|
|||||||
query=q,
|
query=q,
|
||||||
key=k,
|
key=k,
|
||||||
value=v,
|
value=v,
|
||||||
seq_len=cu_seqlens,
|
seq_len=self.cu_seqlens,
|
||||||
scale_value=self.hidden_size_per_attention_head**-0.5,
|
scale_value=self.hidden_size_per_attention_head**-0.5,
|
||||||
num_heads=self.num_attention_heads_per_partition,
|
num_heads=self.num_attention_heads_per_partition,
|
||||||
num_kv_heads=self.num_attention_heads_per_partition,
|
num_kv_heads=self.num_attention_heads_per_partition,
|
||||||
|
|||||||
Reference in New Issue
Block a user