bugfix for qwen2_vl (#301)
### What this PR does / why we need it? this pr fixes the error while inferring Qwen2_VL. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? We've tested it on benchmark, it meets our satisfaction and is equal to gpu. --------- Signed-off-by: zouyida <zouyida@huawei.com>
This commit is contained in:
@@ -40,6 +40,23 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
|
||||
class CustomQwen2VisionAttention(Qwen2VisionAttention):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim: int,
|
||||
num_heads: int,
|
||||
projection_size: int,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__(
|
||||
embed_dim,
|
||||
num_heads,
|
||||
projection_size,
|
||||
quant_config,
|
||||
prefix,
|
||||
)
|
||||
self.cu_seqlens = None
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
@@ -47,6 +64,8 @@ class CustomQwen2VisionAttention(Qwen2VisionAttention):
|
||||
rotary_pos_emb: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
|
||||
self.cu_seqlens = cu_seqlens
|
||||
|
||||
# [s, b, c] --> [s, b, 3 * head * head_dim]
|
||||
x, _ = self.qkv(x)
|
||||
|
||||
@@ -72,7 +91,7 @@ class CustomQwen2VisionAttention(Qwen2VisionAttention):
|
||||
query=q,
|
||||
key=k,
|
||||
value=v,
|
||||
seq_len=cu_seqlens,
|
||||
seq_len=self.cu_seqlens,
|
||||
scale_value=self.hidden_size_per_attention_head**-0.5,
|
||||
num_heads=self.num_attention_heads_per_partition,
|
||||
num_kv_heads=self.num_attention_heads_per_partition,
|
||||
|
||||
Reference in New Issue
Block a user