diff --git a/python/sglang/srt/layers/activation.py b/python/sglang/srt/layers/activation.py
index 15c2ba077..4c7620669 100644
--- a/python/sglang/srt/layers/activation.py
+++ b/python/sglang/srt/layers/activation.py
@@ -103,6 +103,15 @@ class GeluAndMul(CustomOp):
             raise RuntimeError("GeluAndMul only support tanh or none")
         return out
 
+    def forward_npu(self, x: torch.Tensor) -> torch.Tensor:
+        y_npu, gelu_npu = torch_npu.npu_geglu(
+            x,
+            dim=-1,
+            approximate=1 if self.approximate == "tanh" else 0,
+            activate_left=True,
+        )
+        return y_npu
+
 
 class NewGELU(CustomOp):
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
@@ -137,6 +146,9 @@ class QuickGELU(CustomOp):
         gelu_quick(x, out)
         return out
 
+    def forward_npu(self, x: torch.Tensor) -> torch.Tensor:
+        return torch_npu.npu_fast_gelu(x)
+
 
 class ScaledActivation(nn.Module):
     """An activation function with post-scale parameters.
diff --git a/python/sglang/srt/layers/attention/ascend_backend.py b/python/sglang/srt/layers/attention/ascend_backend.py
index 0f826d2df..d4ede0a4c 100644
--- a/python/sglang/srt/layers/attention/ascend_backend.py
+++ b/python/sglang/srt/layers/attention/ascend_backend.py
@@ -64,7 +64,7 @@ class AscendAttnBackend(AttentionBackend):
         if self.use_mla:
             self.kv_lora_rank = model_runner.model_config.kv_lora_rank
             self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
-            self.native_attn = TorchNativeAttnBackend(model_runner)
+        self.native_attn = TorchNativeAttnBackend(model_runner)
         self.graph_metadata = {}
         self.max_context_len = model_runner.model_config.context_len
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
@@ -180,7 +180,7 @@ class AscendAttnBackend(AttentionBackend):
 
             if self.use_fia:
                 """FIA will support multi-bs in the later version of CANN"""
-                q = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
+                q = q.reshape(-1, layer.tp_q_head_num, layer.qk_head_dim)
                 attn_output = torch.empty(
                     (q.size(0), layer.tp_q_head_num, layer.v_head_dim),
                     device=q.device,
@@ -208,26 +208,61 @@ class AscendAttnBackend(AttentionBackend):
                 )
 
             else:
-                query = q.view(-1, layer.tp_q_head_num * layer.qk_head_dim)
-                attn_output = torch.empty(
-                    (query.shape[0], layer.tp_q_head_num * layer.v_head_dim),
-                    dtype=query.dtype,
-                    device=query.device,
-                )
+                if layer.qk_head_dim <= 128:
+                    query = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)
+                    attn_output = torch.empty(
+                        (query.shape[0], layer.tp_q_head_num * layer.v_head_dim),
+                        dtype=query.dtype,
+                        device=query.device,
+                    )
 
-                torch_npu._npu_flash_attention_qlens(
-                    query=query,
-                    key_cache=k_cache,
-                    value_cache=v_cache,
-                    mask=self.mask,
-                    block_table=self.forward_metadata.block_tables,
-                    seq_len=self.forward_metadata.extend_seq_lens_cpu_int,
-                    context_lens=self.forward_metadata.seq_lens_cpu_int,
-                    scale_value=layer.scaling,
-                    num_heads=layer.tp_q_head_num,
-                    num_kv_heads=layer.tp_k_head_num,
-                    out=attn_output,
-                )
+                    torch_npu._npu_flash_attention_qlens(
+                        query=query,
+                        key_cache=k_cache,
+                        value_cache=v_cache,
+                        mask=self.mask,
+                        block_table=self.forward_metadata.block_tables,
+                        seq_len=self.forward_metadata.extend_seq_lens_cpu_int,
+                        context_lens=self.forward_metadata.seq_lens_cpu_int,
+                        scale_value=layer.scaling,
+                        num_heads=layer.tp_q_head_num,
+                        num_kv_heads=layer.tp_k_head_num,
+                        out=attn_output,
+                    )
+                else:
+                    if layer.qk_head_dim != layer.v_head_dim:
+                        attn_output = q.new_empty(
+                            (q.shape[0], layer.tp_q_head_num * layer.v_head_dim)
+                        )
+                    else:
+                        attn_output = torch.empty_like(q)
+
+                    use_gqa = layer.tp_q_head_num != layer.tp_k_head_num
+
+                    q_ = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
+                    o_ = attn_output.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+
+                    causal = True
+                    if (
+                        layer.is_cross_attention
+                        or layer.attn_type == AttentionType.ENCODER_ONLY
+                    ):
+                        causal = False
+
+                    self.native_attn._run_sdpa_forward_extend(
+                        q_,
+                        o_,
+                        k_cache.view(-1, layer.tp_k_head_num, layer.qk_head_dim),
+                        v_cache.view(-1, layer.tp_v_head_num, layer.v_head_dim),
+                        forward_batch.req_to_token_pool.req_to_token,
+                        forward_batch.req_pool_indices,
+                        forward_batch.seq_lens,
+                        forward_batch.extend_prefix_lens,
+                        forward_batch.extend_seq_lens,
+                        scaling=layer.scaling,
+                        enable_gqa=use_gqa,
+                        causal=causal,
+                    )
         else:
             assert (
                 layer.qk_head_dim != layer.v_head_dim
@@ -283,7 +318,7 @@ class AscendAttnBackend(AttentionBackend):
             v_cache = forward_batch.token_to_kv_pool.get_value_buffer(
                 layer.layer_id
             ).view(-1, self.page_size, layer.tp_v_head_num * layer.v_head_dim)
-            query = q.view(-1, 1, layer.tp_q_head_num * layer.qk_head_dim)
+            query = q.reshape(-1, 1, layer.tp_q_head_num * layer.qk_head_dim)
             if self.forward_metadata.seq_lens_cpu_int is None:
                 actual_seq_len_kv = self.forward_metadata.seq_lens_cpu_list
             else:
@@ -439,7 +474,8 @@ class AscendAttnBackend(AttentionBackend):
                 scale=layer.scaling,
             )
         else:
-            query = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
+            query = q.reshape(-1, layer.tp_q_head_num, layer.qk_head_dim)
+            num_tokens = query.shape[0]
             attn_output = torch.empty(
                 (num_tokens, layer.tp_q_head_num, layer.v_head_dim),
                 dtype=query.dtype,
diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py
index a77747351..cf8ccf4d1 100644
--- a/python/sglang/srt/layers/layernorm.py
+++ b/python/sglang/srt/layers/layernorm.py
@@ -53,7 +53,7 @@ elif _is_hip:
 
 logger = logging.getLogger(__name__)
 
-if is_npu():
+if _is_npu:
     import torch_npu
 
 
@@ -266,23 +266,48 @@ class GemmaRMSNorm(CustomOp):
         out = gemma_rmsnorm(x, self.weight.data, self.variance_epsilon)
         return out
 
+    def forward_npu(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        orig_dtype = x.dtype
+        if residual is not None:
+            x = x + residual
+            residual = x
 
-class Gemma3RMSNorm(nn.Module):
+        x = x.float()
+        variance = torch.mean(torch.pow(x, 2), dim=-1, keepdim=True)
+        x = x * torch.rsqrt(variance + self.variance_epsilon)
+        x = x * (1.0 + self.weight.float())
+        x = x.to(orig_dtype)
+        return x if residual is None else (x, residual)
+
+
+class Gemma3RMSNorm(CustomOp):
     def __init__(self, dim: int, eps: float = 1e-6):
         super().__init__()
         self.eps = eps
         self.weight = nn.Parameter(torch.zeros(dim))
 
+    # Re-dispatch
     def _norm(self, x):
         return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
 
-    def forward(self, x):
+    def forward_native(self, x):
         output = self._norm(x.float())
         # Llama does x.to(float16) * w whilst Gemma3 is (x * w).to(float16)
         # See https://github.com/huggingface/transformers/pull/29402
         output = output * (1.0 + self.weight.float())
         return output.type_as(x)
 
+    def forward_cuda(self, x):
+        return self.forward_native(x)
+
+    def forward_npu(self, x):
+        output, _ = torch_npu.npu_gemma_rms_norm(x, self.weight, self.eps)
+        return output
+
     def extra_repr(self):
         return f"{tuple(self.weight.shape)}, eps={self.eps}"
diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py
index f3d82539f..7cffccf6b 100644
--- a/python/sglang/srt/layers/rotary_embedding.py
+++ b/python/sglang/srt/layers/rotary_embedding.py
@@ -1876,7 +1876,7 @@ def rotate_half(x):
     return torch.cat((-x2, x1), dim=-1)
 
 
-def apply_rotary_pos_emb(
+def apply_rotary_pos_emb_native(
     q: torch.Tensor,
     k: torch.Tensor,
     cos: torch.Tensor,
@@ -1899,6 +1899,33 @@ def apply_rotary_pos_emb(
     return q_embed, k_embed
 
 
+def apply_rotary_pos_emb_npu(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    cos: torch.Tensor,
+    sin: torch.Tensor,
+    unsqueeze_dim=1,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    if q.shape[1] != 128:
+        return apply_rotary_pos_emb_native(q, k, cos, sin, unsqueeze_dim)
+    cos = cos.unsqueeze(unsqueeze_dim)
+    cos = torch.transpose(cos, 1, 2)
+    sin = sin.unsqueeze(unsqueeze_dim)
+    sin = torch.transpose(sin, 1, 2)
+    q = torch.transpose(q, 1, 2)
+    k = torch.transpose(k, 1, 2)
+    q_embed, k_embed = torch_npu.npu_apply_rotary_pos_emb(q, k, cos, sin)
+    q_embed = torch.transpose(q_embed, 1, 2)
+    k_embed = torch.transpose(k_embed, 1, 2)
+    return q_embed, k_embed
+
+
+if _is_npu:
+    apply_rotary_pos_emb = apply_rotary_pos_emb_npu
+else:
+    apply_rotary_pos_emb = apply_rotary_pos_emb_native
+
+
 def get_rope_cpu(
     head_size: int,
     rotary_dim: int,
diff --git a/python/sglang/srt/managers/mm_utils.py b/python/sglang/srt/managers/mm_utils.py
index 7d4ae186a..bedf50a66 100644
--- a/python/sglang/srt/managers/mm_utils.py
+++ b/python/sglang/srt/managers/mm_utils.py
@@ -20,9 +20,11 @@ from sglang.srt.managers.schedule_batch import (
 )
 from sglang.srt.mem_cache.multimodal_cache import MultiModalCache
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.utils import flatten_nested_list, print_warning_once
+from sglang.srt.utils import flatten_nested_list, is_npu, print_warning_once
 from sglang.utils import logger
 
+_is_npu = is_npu()
+
 # NOTE: Using the shared logger from sglang.utils instead of creating a module-specific logger
 # to ensure consistent logging behavior across the codebase. This prevents issues with log
 # propagation that can cause some log messages (like 'server is fired up') to not appear
@@ -486,6 +488,8 @@ def get_embedding_and_mask(
     if embedding is None:
         return None, None
     # 2. Get mask
+    if _is_npu:
+        torch.npu.current_stream().synchronize()
     special_multimodal_mask = _get_multimodal_mask(input_ids, placeholder_tensor)
     # 3. Adjust embedding length if needed
     embedding = _adjust_embedding_length(embedding, special_multimodal_mask, logger)
diff --git a/python/sglang/srt/multimodal/processors/base_processor.py b/python/sglang/srt/multimodal/processors/base_processor.py
index d650535cb..cc14f691f 100644
--- a/python/sglang/srt/multimodal/processors/base_processor.py
+++ b/python/sglang/srt/multimodal/processors/base_processor.py
@@ -13,7 +13,9 @@ from PIL import Image
 from transformers import BaseImageProcessorFast
 
 from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
-from sglang.srt.utils import load_audio, load_image, load_video, logger
+from sglang.srt.utils import is_npu, load_audio, load_image, load_video, logger
+
+_is_npu = is_npu()
 
 
 @dataclasses.dataclass
@@ -232,7 +234,7 @@ class BaseMultimodalProcessor(ABC):
             and isinstance(processor.image_processor, BaseImageProcessorFast)
             and not self.server_args.disable_fast_image_processor
         ):
-            kwargs["device"] = "cuda"
+            kwargs["device"] = "cuda" if not _is_npu else "npu"
         result = processor.__call__(
             text=[input_text],
             padding=True,
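
Reviewer note 1 (not part of the patch): the new `forward_npu` methods only take effect because of sglang's `CustomOp` dispatch, which is also why `Gemma3RMSNorm` switches its base class from `nn.Module` to `CustomOp`. Below is a minimal sketch of that mechanism, assuming a simplified backend check; the real class lives in `python/sglang/srt/custom_op.py`, covers more backends, and its selection logic may differ in detail.

```python
# Illustrative stand-in for sglang's CustomOp dispatch -- not the upstream code.
# Subclasses override forward_native / forward_cuda / forward_npu; forward()
# is bound once, at construction time, to the implementation for the current
# platform, so call sites never need to know which backend is active.
import torch
import torch.nn as nn


def _is_npu_available() -> bool:
    # Assumption: importing torch_npu attaches an "npu" module to torch.
    return hasattr(torch, "npu") and torch.npu.is_available()


class CustomOp(nn.Module):
    def __init__(self):
        super().__init__()
        if _is_npu_available():
            self._forward_method = self.forward_npu
        elif torch.cuda.is_available():
            self._forward_method = self.forward_cuda
        else:
            self._forward_method = self.forward_native

    def forward(self, *args, **kwargs):
        return self._forward_method(*args, **kwargs)

    def forward_native(self, *args, **kwargs):
        # Reference implementation; subclasses must provide it.
        raise NotImplementedError

    def forward_cuda(self, *args, **kwargs):
        # Default: fall back to the reference path.
        return self.forward_native(*args, **kwargs)

    def forward_npu(self, *args, **kwargs):
        # Default NPU behavior: also fall back to the reference path,
        # so only ops with a real NPU kernel need to override this.
        return self.forward_native(*args, **kwargs)
```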
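Reviewer note 2 (not part of the patch): a parity harness for the rotary-embedding split. The guard in `apply_rotary_pos_emb_npu` falls back to the native path unless `q.shape[1] == 128`; under the assumed HF-style `(batch, heads, seq, head_dim)` layout with `unsqueeze_dim=1`, that means the fused `torch_npu.npu_apply_rotary_pos_emb` kernel is only exercised with 128 heads on an Ascend device. The layout, shapes, and tolerances here are assumptions, not something the diff pins down.

```python
# Sanity check: the NPU rotary path should match the native one.
# Requires torch_npu and an Ascend device to exercise the fused kernel;
# with q.shape[1] != 128 the NPU function falls back to native and the
# comparison passes trivially.
import torch

try:  # torch_npu registers the "npu" device type on import (assumption: installed)
    import torch_npu  # noqa: F401
except ImportError:
    pass

from sglang.srt.layers.rotary_embedding import (
    apply_rotary_pos_emb_native,
    apply_rotary_pos_emb_npu,
)


def check_rope_parity(device: str = "npu") -> None:
    torch.manual_seed(0)
    # heads == 128 so q.shape[1] == 128 and the fused kernel is taken.
    bsz, heads, seq, head_dim = 2, 128, 16, 64
    q = torch.randn(bsz, heads, seq, head_dim, device=device)
    k = torch.randn(bsz, heads, seq, head_dim, device=device)
    cos = torch.randn(bsz, seq, head_dim, device=device)
    sin = torch.randn(bsz, seq, head_dim, device=device)

    q_ref, k_ref = apply_rotary_pos_emb_native(q, k, cos, sin)
    q_npu, k_npu = apply_rotary_pos_emb_npu(q, k, cos, sin)

    # Loose tolerances: the fused kernel may reorder float accumulation.
    torch.testing.assert_close(q_ref, q_npu, rtol=2e-2, atol=2e-2)
    torch.testing.assert_close(k_ref, k_npu, rtol=2e-2, atol=2e-2)
```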