Upgrade to vllm 0.17.0 corex v4.1 overlay
@@ -83,15 +83,16 @@ from .vision import (
     resolve_visual_encoder_outputs,
 )
 
+import ixformer.inference.functions as ixf
 try:
     # Note: vLLM does not install xformers by default.
     from xformers import ops as xops
 
-    if current_platform.is_cuda() and current_platform.has_device_capability(100):
+    if current_platform.is_cuda():
         # Xformers FA is not compatible with B200
         USE_XFORMERS_OPS = False
     else:
-        USE_XFORMERS_OPS = True
+        USE_XFORMERS_OPS = False
 except ImportError:
     USE_XFORMERS_OPS = False
 
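Context for the dropped condition: vLLM encodes CUDA compute capability as major * 10 + minor, so has_device_capability(100) is true from sm_100 (Blackwell, e.g. B200) upward, where the xformers flash-attention path is disabled upstream. A rough stand-in for that check in plain PyTorch (the free function below is illustrative, not vLLM's API):

import torch

def has_device_capability(threshold: int = 100) -> bool:
    # vLLM-style encoding: sm_{major}.{minor} -> major * 10 + minor,
    # so threshold 100 matches Blackwell-class (B200) devices and newer.
    if not torch.cuda.is_available():
        return False
    major, minor = torch.cuda.get_device_capability()
    return major * 10 + minor >= threshold

The overlay removes the capability test entirely, since routing on the CoreX build is decided by the ixformer path added in the next hunk rather than by the GPU generation.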
@@ -698,23 +699,21 @@ class Attention(nn.Module):
         q, k, v = self.wq(x), self.wk(x), self.wv(x)
         q = q.reshape(batch, patches, self.n_heads, self.head_dim)
         k = k.reshape(batch, patches, self.n_heads, self.head_dim)
         v = v.reshape(batch, patches, self.n_heads, self.head_dim)
 
-        q, k = apply_rotary_emb_vit(q, k, freqs_cis=freqs_cis)
-
         if USE_XFORMERS_OPS:
-            out = xops.memory_efficient_attention(q, k, v, attn_bias=mask)
+            v = v.reshape(batch * patches, self.n_heads, self.head_dim)
+
+            q, k = apply_rotary_emb_vit(q, k, freqs_cis=freqs_cis)
+            q = q.view(batch * patches, self.n_heads, self.head_dim)
+            k = k.view(batch * patches, self.n_heads, self.head_dim)
+            out = ixf.ixinfer_flash_attn_unpad(q, k, v, mask.q_seqinfo.seqstart.to(q.device), mask.k_seqinfo.seqstart.to(q.device), mask.q_seqinfo.max_seqlen, mask.k_seqinfo.max_seqlen)
+            # out = memory_efficient_attention(q, k, v, attn_bias=mask)
         else:
-            q = q.transpose(1, 2)
-            k = k.transpose(1, 2)
-            v = v.transpose(1, 2)
-            out = nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask)
-            out = out.transpose(1, 2)
-
+            assert False, "xformers failed !"
         out = out.reshape(batch, patches, self.n_heads * self.head_dim)
         return self.wo(out)
 
 
 class TransformerBlock(nn.Module):
 
     def __init__(self, args: VisionEncoderArgs):
         super().__init__()
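Note on the new attention call: mask is an xformers block-diagonal mask, so mask.q_seqinfo.seqstart / mask.k_seqinfo.seqstart are cumulative sequence-start offsets (flash-attn's cu_seqlens convention) and max_seqlen is the longest packed sequence. ixinfer_flash_attn_unpad is CoreX-specific and its signature is known here only from this call site, so the following is a hedged pure-PyTorch reference of what an unpadded ("varlen") attention with those arguments computes, not the vendor kernel:

import math
import torch

def varlen_attention_reference(q, k, v, q_seqstart, k_seqstart):
    # q: (total_q, n_heads, head_dim); k, v: (total_k, n_heads, head_dim).
    # q_seqstart / k_seqstart: (num_seqs + 1,) cumulative offsets, e.g.
    # [0, len0, len0 + len1, ...] -- each sequence attends only to itself.
    out = torch.empty_like(q)
    scale = 1.0 / math.sqrt(q.shape[-1])
    for i in range(q_seqstart.numel() - 1):
        qs, qe = q_seqstart[i].item(), q_seqstart[i + 1].item()
        ks, ke = k_seqstart[i].item(), k_seqstart[i + 1].item()
        qi = q[qs:qe].transpose(0, 1)             # (n_heads, lq, head_dim)
        ki = k[ks:ke].transpose(0, 1)             # (n_heads, lk, head_dim)
        vi = v[ks:ke].transpose(0, 1)             # (n_heads, lk, head_dim)
        attn = torch.softmax((qi @ ki.transpose(-1, -2)) * scale, dim=-1)
        out[qs:qe] = (attn @ vi).transpose(0, 1)  # back to (lq, n_heads, dim)
    return out

The per-sequence loop reproduces the block-diagonal structure that the commented-out memory_efficient_attention(q, k, v, attn_bias=mask) expressed through the mask itself; packing the (batch, patches) grid into one unpadded token stream avoids padding every image to the longest patch sequence.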