[Ascend] optimize Qwen-vl on Ascend (#10556)
Co-authored-by: wangqihui01 <wangqh10@163.com>
This commit is contained in:
@@ -16,14 +16,19 @@ from sglang.srt.utils import (
|
||||
get_device_capability,
|
||||
is_blackwell,
|
||||
is_cuda,
|
||||
is_npu,
|
||||
print_info_once,
|
||||
)
|
||||
|
||||
_is_cuda = is_cuda()
|
||||
_is_npu = is_npu()
|
||||
|
||||
if _is_cuda:
|
||||
from sgl_kernel.flash_attn import flash_attn_varlen_func
|
||||
|
||||
if _is_npu:
|
||||
import torch_npu
|
||||
|
||||
from sglang.srt.distributed import (
|
||||
split_tensor_along_last_dim,
|
||||
tensor_model_parallel_all_gather,
|
||||
@@ -331,10 +336,63 @@ class VisionFlash3Attention(nn.Module):
|
||||
return output
|
||||
|
||||
|
||||
class VisionAscendAttention(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
**kwargs,
|
||||
):
|
||||
if not _is_npu:
|
||||
raise Exception("VisionAscendAttention is only available for ascend npu")
|
||||
super().__init__()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
cu_seqlens: Optional[Union[SingletonCache, torch.Tensor]],
|
||||
bsz: int,
|
||||
seq_len: int,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
r"""
|
||||
Args:
|
||||
cu_seqlens: [b]
|
||||
Returns:
|
||||
[b * s, h, head_size]
|
||||
"""
|
||||
if cu_seqlens is None:
|
||||
cu_seqlens = _get_cu_seqlens_for_shape(bsz, seq_len, device=q.device)
|
||||
|
||||
seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
|
||||
if seq_lens.is_npu:
|
||||
# cu_seqlens must be on cpu because of operator restriction
|
||||
seq_lens = seq_lens.to("cpu")
|
||||
_, num_heads, head_size = q.shape
|
||||
num_kv_heads = k.shape[1]
|
||||
output = torch.empty_like(q)
|
||||
|
||||
# operator requires pta version >= 2.5.1
|
||||
torch_npu._npu_flash_attention_unpad(
|
||||
query=q,
|
||||
key=k,
|
||||
value=v,
|
||||
seq_len=seq_lens.to(torch.int32),
|
||||
scale_value=head_size**-0.5,
|
||||
num_heads=num_heads,
|
||||
num_kv_heads=num_kv_heads,
|
||||
out=output,
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
QKV_BACKEND_IMPL = {
|
||||
"triton_attn": VisionTritonAttention,
|
||||
"sdpa": VisionSdpaAttention,
|
||||
"fa3": VisionFlash3Attention,
|
||||
"ascend_attn": VisionAscendAttention,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ from sglang.srt.custom_op import CustomOp
|
||||
from sglang.srt.utils import (
|
||||
cpu_has_amx_support,
|
||||
get_bool_env_var,
|
||||
get_compiler_backend,
|
||||
is_cpu,
|
||||
is_cuda,
|
||||
is_hip,
|
||||
@@ -33,6 +34,9 @@ if _use_aiter:
|
||||
if is_npu():
|
||||
import torch_npu
|
||||
|
||||
NPU_ROTARY_MUL_MAX_NUM_HEADS = 1000
|
||||
NPU_ROTARY_MUL_MAX_HEAD_SIZE = 896
|
||||
|
||||
|
||||
def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
@@ -1035,7 +1039,7 @@ class MRotaryEmbedding(RotaryEmbedding):
|
||||
f"Corrected mrope_section: {self.mrope_section} (sum={sum(self.mrope_section)})"
|
||||
)
|
||||
|
||||
@torch.compile(dynamic=True)
|
||||
@torch.compile(dynamic=True, backend=get_compiler_backend())
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
@@ -1894,17 +1898,30 @@ def apply_rotary_pos_emb_npu(
|
||||
sin: torch.Tensor,
|
||||
unsqueeze_dim=1,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if q.shape[1] != 128:
|
||||
"""Ascend implementation equivalent to apply_rotary_pos_emb_native.
|
||||
|
||||
Args:
|
||||
q: [num_tokens, num_heads, head_size]
|
||||
k: [num_tokens, num_kv_heads, head_size]
|
||||
cos: [num_tokens, head_size]
|
||||
sin: [num_tokens, head_size]
|
||||
"""
|
||||
if (
|
||||
cos.dim() != 2
|
||||
or q.dim() != 3
|
||||
or q.shape[1] >= NPU_ROTARY_MUL_MAX_NUM_HEADS
|
||||
or q.shape[2] >= NPU_ROTARY_MUL_MAX_HEAD_SIZE
|
||||
):
|
||||
# Note: num_heads and head_size of q must be less than 1000 and 896, respectively
|
||||
return apply_rotary_pos_emb_native(q, k, cos, sin, unsqueeze_dim)
|
||||
cos = cos.unsqueeze(unsqueeze_dim)
|
||||
cos = torch.transpose(cos, 1, 2)
|
||||
sin = sin.unsqueeze(unsqueeze_dim)
|
||||
sin = torch.transpose(sin, 1, 2)
|
||||
q = torch.transpose(q, 1, 2)
|
||||
k = torch.transpose(k, 1, 2)
|
||||
q_embed, k_embed = torch_npu.npu_apply_rotary_pos_emb(q, k, cos, sin)
|
||||
q_embed = torch.transpose(q_embed, 1, 2)
|
||||
k_embed = torch.transpose(k_embed, 1, 2)
|
||||
cos = cos.unsqueeze(unsqueeze_dim).unsqueeze(0)
|
||||
sin = sin.unsqueeze(unsqueeze_dim).unsqueeze(0)
|
||||
q = q.unsqueeze(0)
|
||||
k = k.unsqueeze(0)
|
||||
q_embed = torch_npu.npu_rotary_mul(q, cos, sin)
|
||||
k_embed = torch_npu.npu_rotary_mul(k, cos, sin)
|
||||
q_embed = q_embed.squeeze(0)
|
||||
k_embed = k_embed.squeeze(0)
|
||||
return q_embed, k_embed
|
||||
|
||||
|
||||
|
||||
@@ -206,7 +206,10 @@ def _initialize_model(
|
||||
if _is_npu:
|
||||
packed_modules_mapping.update(
|
||||
{
|
||||
"visual": {"qkv_proj": ["qkv"]},
|
||||
"visual": {
|
||||
"qkv_proj": ["qkv"],
|
||||
"gate_up_proj": ["gate_proj", "up_proj"],
|
||||
},
|
||||
"vision_model": {
|
||||
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
|
||||
"proj": ["out_proj"],
|
||||
|
||||
@@ -234,7 +234,14 @@ class BaseMultimodalProcessor(ABC):
|
||||
and isinstance(processor.image_processor, BaseImageProcessorFast)
|
||||
and not self.server_args.disable_fast_image_processor
|
||||
):
|
||||
kwargs["device"] = "cuda" if not _is_npu else "npu"
|
||||
if not _is_npu:
|
||||
kwargs["device"] = "cuda"
|
||||
elif processor.__class__.__name__ not in {
|
||||
"Qwen2_5_VLProcessor",
|
||||
"Qwen3VLProcessor",
|
||||
}:
|
||||
# Note: for qwen-vl, processor has some reshape issue because of dims restriction on Ascend.
|
||||
kwargs["device"] = "npu"
|
||||
result = processor.__call__(
|
||||
text=[input_text],
|
||||
padding=True,
|
||||
|
||||
@@ -1840,7 +1840,7 @@ class ServerArgs:
|
||||
parser.add_argument(
|
||||
"--mm-attention-backend",
|
||||
type=str,
|
||||
choices=["sdpa", "fa3", "triton_attn"],
|
||||
choices=["sdpa", "fa3", "triton_attn", "ascend_attn"],
|
||||
default=ServerArgs.mm_attention_backend,
|
||||
help="Set multimodal attention backend.",
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user