diff --git a/python/sglang/srt/layers/attention/vision.py b/python/sglang/srt/layers/attention/vision.py index 2be3e450b..489b8248b 100644 --- a/python/sglang/srt/layers/attention/vision.py +++ b/python/sglang/srt/layers/attention/vision.py @@ -16,14 +16,19 @@ from sglang.srt.utils import ( get_device_capability, is_blackwell, is_cuda, + is_npu, print_info_once, ) _is_cuda = is_cuda() +_is_npu = is_npu() if _is_cuda: from sgl_kernel.flash_attn import flash_attn_varlen_func +if _is_npu: + import torch_npu + from sglang.srt.distributed import ( split_tensor_along_last_dim, tensor_model_parallel_all_gather, @@ -331,10 +336,63 @@ class VisionFlash3Attention(nn.Module): return output +class VisionAscendAttention(nn.Module): + + def __init__( + self, + **kwargs, + ): + if not _is_npu: + raise Exception("VisionAscendAttention is only available for ascend npu") + super().__init__() + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens: Optional[Union[SingletonCache, torch.Tensor]], + bsz: int, + seq_len: int, + **kwargs, + ) -> torch.Tensor: + r""" + Args: + cu_seqlens: [b] + Returns: + [b * s, h, head_size] + """ + if cu_seqlens is None: + cu_seqlens = _get_cu_seqlens_for_shape(bsz, seq_len, device=q.device) + + seq_lens = cu_seqlens[1:] - cu_seqlens[:-1] + if seq_lens.is_npu: + # cu_seqlens must be on cpu because of operator restriction + seq_lens = seq_lens.to("cpu") + _, num_heads, head_size = q.shape + num_kv_heads = k.shape[1] + output = torch.empty_like(q) + + # operator requires pta version >= 2.5.1 + torch_npu._npu_flash_attention_unpad( + query=q, + key=k, + value=v, + seq_len=seq_lens.to(torch.int32), + scale_value=head_size**-0.5, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + out=output, + ) + + return output + + QKV_BACKEND_IMPL = { "triton_attn": VisionTritonAttention, "sdpa": VisionSdpaAttention, "fa3": VisionFlash3Attention, + "ascend_attn": VisionAscendAttention, } diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py index f0e9e5a7b..eacf84c8a 100644 --- a/python/sglang/srt/layers/rotary_embedding.py +++ b/python/sglang/srt/layers/rotary_embedding.py @@ -12,6 +12,7 @@ from sglang.srt.custom_op import CustomOp from sglang.srt.utils import ( cpu_has_amx_support, get_bool_env_var, + get_compiler_backend, is_cpu, is_cuda, is_hip, @@ -33,6 +34,9 @@ if _use_aiter: if is_npu(): import torch_npu + NPU_ROTARY_MUL_MAX_NUM_HEADS = 1000 + NPU_ROTARY_MUL_MAX_HEAD_SIZE = 896 + def _rotate_neox(x: torch.Tensor) -> torch.Tensor: x1 = x[..., : x.shape[-1] // 2] @@ -1035,7 +1039,7 @@ class MRotaryEmbedding(RotaryEmbedding): f"Corrected mrope_section: {self.mrope_section} (sum={sum(self.mrope_section)})" ) - @torch.compile(dynamic=True) + @torch.compile(dynamic=True, backend=get_compiler_backend()) def forward( self, positions: torch.Tensor, @@ -1894,17 +1898,30 @@ def apply_rotary_pos_emb_npu( sin: torch.Tensor, unsqueeze_dim=1, ) -> Tuple[torch.Tensor, torch.Tensor]: - if q.shape[1] != 128: + """Ascend implementation equivalent to apply_rotary_pos_emb_native. + + Args: + q: [num_tokens, num_heads, head_size] + k: [num_tokens, num_kv_heads, head_size] + cos: [num_tokens, head_size] + sin: [num_tokens, head_size] + """ + if ( + cos.dim() != 2 + or q.dim() != 3 + or q.shape[1] >= NPU_ROTARY_MUL_MAX_NUM_HEADS + or q.shape[2] >= NPU_ROTARY_MUL_MAX_HEAD_SIZE + ): + # Note: num_heads and head_size of q must be less than 1000 and 896, respectively return apply_rotary_pos_emb_native(q, k, cos, sin, unsqueeze_dim) - cos = cos.unsqueeze(unsqueeze_dim) - cos = torch.transpose(cos, 1, 2) - sin = sin.unsqueeze(unsqueeze_dim) - sin = torch.transpose(sin, 1, 2) - q = torch.transpose(q, 1, 2) - k = torch.transpose(k, 1, 2) - q_embed, k_embed = torch_npu.npu_apply_rotary_pos_emb(q, k, cos, sin) - q_embed = torch.transpose(q_embed, 1, 2) - k_embed = torch.transpose(k_embed, 1, 2) + cos = cos.unsqueeze(unsqueeze_dim).unsqueeze(0) + sin = sin.unsqueeze(unsqueeze_dim).unsqueeze(0) + q = q.unsqueeze(0) + k = k.unsqueeze(0) + q_embed = torch_npu.npu_rotary_mul(q, cos, sin) + k_embed = torch_npu.npu_rotary_mul(k, cos, sin) + q_embed = q_embed.squeeze(0) + k_embed = k_embed.squeeze(0) return q_embed, k_embed diff --git a/python/sglang/srt/model_loader/loader.py b/python/sglang/srt/model_loader/loader.py index ab9c69fc2..e5bf320be 100644 --- a/python/sglang/srt/model_loader/loader.py +++ b/python/sglang/srt/model_loader/loader.py @@ -206,7 +206,10 @@ def _initialize_model( if _is_npu: packed_modules_mapping.update( { - "visual": {"qkv_proj": ["qkv"]}, + "visual": { + "qkv_proj": ["qkv"], + "gate_up_proj": ["gate_proj", "up_proj"], + }, "vision_model": { "qkv_proj": ["q_proj", "k_proj", "v_proj"], "proj": ["out_proj"], diff --git a/python/sglang/srt/multimodal/processors/base_processor.py b/python/sglang/srt/multimodal/processors/base_processor.py index e5da78368..ef076ae09 100644 --- a/python/sglang/srt/multimodal/processors/base_processor.py +++ b/python/sglang/srt/multimodal/processors/base_processor.py @@ -234,7 +234,14 @@ class BaseMultimodalProcessor(ABC): and isinstance(processor.image_processor, BaseImageProcessorFast) and not self.server_args.disable_fast_image_processor ): - kwargs["device"] = "cuda" if not _is_npu else "npu" + if not _is_npu: + kwargs["device"] = "cuda" + elif processor.__class__.__name__ not in { + "Qwen2_5_VLProcessor", + "Qwen3VLProcessor", + }: + # Note: for qwen-vl, processor has some reshape issue because of dims restriction on Ascend. + kwargs["device"] = "npu" result = processor.__call__( text=[input_text], padding=True, diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 845baff56..2aa1e9031 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -1840,7 +1840,7 @@ class ServerArgs: parser.add_argument( "--mm-attention-backend", type=str, - choices=["sdpa", "fa3", "triton_attn"], + choices=["sdpa", "fa3", "triton_attn", "ascend_attn"], default=ServerArgs.mm_attention_backend, help="Set multimodal attention backend.", )