[feat] Ascend NPU Gemma-3-12b and Gemma-3-27b support (#8909)
@@ -103,6 +103,15 @@ class GeluAndMul(CustomOp):
             raise RuntimeError("GeluAndMul only support tanh or none")
         return out
 
+    def forward_npu(self, x: torch.Tensor) -> torch.Tensor:
+        y_npu, gelu_npu = torch_npu.npu_geglu(
+            x,
+            dim=-1,
+            approximate=1 if self.approximate == "tanh" else 0,
+            activate_left=True,
+        )
+        return y_npu
+
 
 class NewGELU(CustomOp):
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
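Note (editor's sketch, not part of the commit): forward_npu routes GeluAndMul through the fused npu_geglu kernel with activate_left=True, i.e. GELU on the left half of the last dimension multiplied by the right half. A minimal eager parity check, assuming a working torch_npu install and an NPU device; the reference helper and tolerances are illustrative assumptions, not the project's test code:

    # Hypothetical parity check for GeluAndMul.forward_npu (illustrative only).
    import torch
    import torch.nn.functional as F

    def geglu_reference(x: torch.Tensor, approximate: str = "tanh") -> torch.Tensor:
        # Eager reference: gelu(left half) * right half, matching activate_left=True.
        d = x.shape[-1] // 2
        return F.gelu(x[..., :d], approximate=approximate) * x[..., d:]

    # x = torch.randn(8, 4096, device="npu", dtype=torch.float16)
    # out = GeluAndMul(approximate="tanh").forward_npu(x)
    # torch.testing.assert_close(out, geglu_reference(x), rtol=2e-2, atol=2e-2)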
@@ -137,6 +146,9 @@ class QuickGELU(CustomOp):
         gelu_quick(x, out)
         return out
 
+    def forward_npu(self, x: torch.Tensor) -> torch.Tensor:
+        return torch_npu.npu_fast_gelu(x)
+
 
 class ScaledActivation(nn.Module):
     """An activation function with post-scale parameters.
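Note (editor's sketch): QuickGELU's eager definition is x * sigmoid(1.702 * x); the new forward_npu delegates to the fused npu_fast_gelu kernel, which is assumed here to approximate the same curve. A small reference for an on-device comparison:

    # Eager QuickGELU reference (for comparison only; not part of the diff).
    import torch

    def quick_gelu_reference(x: torch.Tensor) -> torch.Tensor:
        return x * torch.sigmoid(1.702 * x)

    # out = QuickGELU().forward_npu(x.to("npu"))
    # torch.testing.assert_close(out.cpu(), quick_gelu_reference(x), rtol=2e-2, atol=2e-2)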
@@ -64,7 +64,7 @@ class AscendAttnBackend(AttentionBackend):
         if self.use_mla:
             self.kv_lora_rank = model_runner.model_config.kv_lora_rank
             self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
-            self.native_attn = TorchNativeAttnBackend(model_runner)
+        self.native_attn = TorchNativeAttnBackend(model_runner)
         self.graph_metadata = {}
         self.max_context_len = model_runner.model_config.context_len
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
@@ -180,7 +180,7 @@ class AscendAttnBackend(AttentionBackend):
 
             if self.use_fia:
                 """FIA will support multi-bs in the later version of CANN"""
-                q = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
+                q = q.reshape(-1, layer.tp_q_head_num, layer.qk_head_dim)
                 attn_output = torch.empty(
                     (q.size(0), layer.tp_q_head_num, layer.v_head_dim),
                     device=q.device,
@@ -208,26 +208,61 @@ class AscendAttnBackend(AttentionBackend):
                 )
 
             else:
-                query = q.view(-1, layer.tp_q_head_num * layer.qk_head_dim)
-                attn_output = torch.empty(
-                    (query.shape[0], layer.tp_q_head_num * layer.v_head_dim),
-                    dtype=query.dtype,
-                    device=query.device,
-                )
-
-                torch_npu._npu_flash_attention_qlens(
-                    query=query,
-                    key_cache=k_cache,
-                    value_cache=v_cache,
-                    mask=self.mask,
-                    block_table=self.forward_metadata.block_tables,
-                    seq_len=self.forward_metadata.extend_seq_lens_cpu_int,
-                    context_lens=self.forward_metadata.seq_lens_cpu_int,
-                    scale_value=layer.scaling,
-                    num_heads=layer.tp_q_head_num,
-                    num_kv_heads=layer.tp_k_head_num,
-                    out=attn_output,
-                )
+                if layer.qk_head_dim <= 128:
+                    query = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)
+                    attn_output = torch.empty(
+                        (query.shape[0], layer.tp_q_head_num * layer.v_head_dim),
+                        dtype=query.dtype,
+                        device=query.device,
+                    )
+
+                    torch_npu._npu_flash_attention_qlens(
+                        query=query,
+                        key_cache=k_cache,
+                        value_cache=v_cache,
+                        mask=self.mask,
+                        block_table=self.forward_metadata.block_tables,
+                        seq_len=self.forward_metadata.extend_seq_lens_cpu_int,
+                        context_lens=self.forward_metadata.seq_lens_cpu_int,
+                        scale_value=layer.scaling,
+                        num_heads=layer.tp_q_head_num,
+                        num_kv_heads=layer.tp_k_head_num,
+                        out=attn_output,
+                    )
+                else:
+                    if layer.qk_head_dim != layer.v_head_dim:
+                        attn_output = q.new_empty(
+                            (q.shape[0], layer.tp_q_head_num * layer.v_head_dim)
+                        )
+                    else:
+                        attn_output = torch.empty_like(q)
+
+                    use_gqa = layer.tp_q_head_num != layer.tp_k_head_num
+
+                    q_ = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
+                    o_ = attn_output.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+
+                    causal = True
+                    if (
+                        layer.is_cross_attention
+                        or layer.attn_type == AttentionType.ENCODER_ONLY
+                    ):
+                        causal = False
+
+                    self.native_attn._run_sdpa_forward_extend(
+                        q_,
+                        o_,
+                        k_cache.view(-1, layer.tp_k_head_num, layer.qk_head_dim),
+                        v_cache.view(-1, layer.tp_v_head_num, layer.v_head_dim),
+                        forward_batch.req_to_token_pool.req_to_token,
+                        forward_batch.req_pool_indices,
+                        forward_batch.seq_lens,
+                        forward_batch.extend_prefix_lens,
+                        forward_batch.extend_seq_lens,
+                        scaling=layer.scaling,
+                        enable_gqa=use_gqa,
+                        causal=causal,
+                    )
         else:
             assert (
                 layer.qk_head_dim != layer.v_head_dim
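Note (editor's sketch): the rewritten extend branch now dispatches on layer.qk_head_dim. The fused _npu_flash_attention_qlens kernel is kept for head dims up to 128, while larger heads (presumably covering Gemma-3 configurations that exceed that limit) fall back to a torch-native SDPA path via self.native_attn, with GQA enabled when query and KV head counts differ and causality disabled for cross-attention or encoder-only layers. A condensed, self-contained sketch of that fallback idea; it omits the paged KV-cache indexing and per-request extend lengths that _run_sdpa_forward_extend actually handles:

    # Minimal SDPA fallback sketch (illustrative; not the backend's real code path).
    import torch
    import torch.nn.functional as F

    def sdpa_fallback(q, k, v, causal=True):
        # q: [tokens, q_heads, head_dim]; k, v: [tokens, kv_heads, head_dim]
        q_heads, kv_heads = q.shape[1], k.shape[1]
        if q_heads != kv_heads:
            # Grouped-query attention: repeat KV heads to match the query heads.
            k = k.repeat_interleave(q_heads // kv_heads, dim=1)
            v = v.repeat_interleave(q_heads // kv_heads, dim=1)
        # SDPA expects [batch, heads, seq, head_dim].
        q, k, v = (t.transpose(0, 1).unsqueeze(0) for t in (q, k, v))
        out = F.scaled_dot_product_attention(q, k, v, is_causal=causal)
        return out.squeeze(0).transpose(0, 1)  # back to [tokens, heads, head_dim]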
@@ -283,7 +318,7 @@ class AscendAttnBackend(AttentionBackend):
         v_cache = forward_batch.token_to_kv_pool.get_value_buffer(
             layer.layer_id
         ).view(-1, self.page_size, layer.tp_v_head_num * layer.v_head_dim)
-        query = q.view(-1, 1, layer.tp_q_head_num * layer.qk_head_dim)
+        query = q.reshape(-1, 1, layer.tp_q_head_num * layer.qk_head_dim)
         if self.forward_metadata.seq_lens_cpu_int is None:
             actual_seq_len_kv = self.forward_metadata.seq_lens_cpu_list
         else:
@@ -439,7 +474,8 @@ class AscendAttnBackend(AttentionBackend):
                 scale=layer.scaling,
             )
         else:
-            query = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
+            query = q.reshape(-1, layer.tp_q_head_num, layer.qk_head_dim)
+            num_tokens = query.shape[0]
             attn_output = torch.empty(
                 (num_tokens, layer.tp_q_head_num, layer.v_head_dim),
                 dtype=query.dtype,
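Note (editor's sketch): several call sites above switch q.view(...) to q.reshape(...). The distinction matters when q is non-contiguous (for example after a transpose or a narrow slice): view raises an error, whereas reshape falls back to making a copy. A tiny illustration:

    # Why reshape instead of view (illustrative).
    import torch

    t = torch.randn(4, 8).t()   # transposed -> non-contiguous
    # t.view(-1)                # would raise a RuntimeError
    flat = t.reshape(-1)        # reshape copies when a view is impossible
    assert flat.shape == (32,)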
@@ -53,7 +53,7 @@ elif _is_hip:
 
 logger = logging.getLogger(__name__)
 
-if is_npu():
+if _is_npu:
     import torch_npu
 
 
@@ -266,23 +266,48 @@ class GemmaRMSNorm(CustomOp):
         out = gemma_rmsnorm(x, self.weight.data, self.variance_epsilon)
         return out
 
+    def forward_npu(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        orig_dtype = x.dtype
+        if residual is not None:
+            x = x + residual
+            residual = x
+
+        x = x.float()
+        variance = torch_npu.mean(torch_npu.pow(x, 2), dim=-1, keepdim=True)
+        x = x * torch_npu.rsqrt(variance + self.variance_epsilon)
+        x = x * (1.0 + self.weight.float())
+        x = x.to(orig_dtype)
+        return x if residual is None else (x, residual)
+
 
-class Gemma3RMSNorm(nn.Module):
+class Gemma3RMSNorm(CustomOp):
     def __init__(self, dim: int, eps: float = 1e-6):
         super().__init__()
         self.eps = eps
         self.weight = nn.Parameter(torch.zeros(dim))
+        # Re-dispatch
 
     def _norm(self, x):
         return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
 
-    def forward(self, x):
+    def forward_native(self, x):
         output = self._norm(x.float())
         # Llama does x.to(float16) * w whilst Gemma3 is (x * w).to(float16)
         # See https://github.com/huggingface/transformers/pull/29402
         output = output * (1.0 + self.weight.float())
         return output.type_as(x)
 
+    def forward_cuda(self, x):
+        return self.forward_native(x)
+
+    def forward_npu(self, x):
+        output, _ = torch_npu.npu_gemma_rms_norm(x, self.weight, self.eps)
+        return output
+
     def extra_repr(self):
         return f"{tuple(self.weight.shape)}, eps={self.eps}"
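Note (editor's sketch): as the inline comment says, Gemma normalizes in float32, scales by (1 + weight) rather than weight, and only then casts back to the activation dtype, unlike Llama which casts before applying the weight. An eager reference of that intended semantics, useful when spot-checking npu_gemma_rms_norm output on device; the comparison itself and any tolerances are assumptions, not part of the commit:

    # Eager Gemma3 RMSNorm reference (for comparison only).
    import torch

    def gemma3_rmsnorm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
        x_f = x.float()
        x_f = x_f * torch.rsqrt(x_f.pow(2).mean(-1, keepdim=True) + eps)
        # Multiply by (1 + w) in float32, then cast back to the input dtype.
        return (x_f * (1.0 + weight.float())).type_as(x)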
@@ -1876,7 +1876,7 @@ def rotate_half(x):
     return torch.cat((-x2, x1), dim=-1)
 
 
-def apply_rotary_pos_emb(
+def apply_rotary_pos_emb_native(
     q: torch.Tensor,
     k: torch.Tensor,
     cos: torch.Tensor,
@@ -1899,6 +1899,33 @@ def apply_rotary_pos_emb(
     return q_embed, k_embed
 
 
+def apply_rotary_pos_emb_npu(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    cos: torch.Tensor,
+    sin: torch.Tensor,
+    unsqueeze_dim=1,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    if q.shape[1] != 128:
+        return apply_rotary_pos_emb_native(q, k, cos, sin, unsqueeze_dim)
+    cos = cos.unsqueeze(unsqueeze_dim)
+    cos = torch.transpose(cos, 1, 2)
+    sin = sin.unsqueeze(unsqueeze_dim)
+    sin = torch.transpose(sin, 1, 2)
+    q = torch.transpose(q, 1, 2)
+    k = torch.transpose(k, 1, 2)
+    q_embed, k_embed = torch_npu.npu_apply_rotary_pos_emb(q, k, cos, sin)
+    q_embed = torch.transpose(q_embed, 1, 2)
+    k_embed = torch.transpose(k_embed, 1, 2)
+    return q_embed, k_embed
+
+
+if _is_npu:
+    apply_rotary_pos_emb = apply_rotary_pos_emb_npu
+else:
+    apply_rotary_pos_emb = apply_rotary_pos_emb_native
+
+
 def get_rope_cpu(
     head_size: int,
     rotary_dim: int,
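Note (editor's sketch): apply_rotary_pos_emb_npu only reorders dimensions around the fused npu_apply_rotary_pos_emb kernel (which apparently expects the rotated axis in a different position) and falls back to the renamed native path when the shape check fails. The rotate-half math both paths are expected to reproduce is the standard formulation, summarized here for reference; the helper below is a reference restatement, not the project's exact code:

    # Standard rotate-half RoPE reference.
    import torch

    def rotate_half(x: torch.Tensor) -> torch.Tensor:
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    def rope_reference(q, k, cos, sin, unsqueeze_dim=1):
        cos = cos.unsqueeze(unsqueeze_dim)
        sin = sin.unsqueeze(unsqueeze_dim)
        q_embed = (q * cos) + (rotate_half(q) * sin)
        k_embed = (k * cos) + (rotate_half(k) * sin)
        return q_embed, k_embed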
@@ -20,9 +20,11 @@ from sglang.srt.managers.schedule_batch import (
 )
 from sglang.srt.mem_cache.multimodal_cache import MultiModalCache
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.utils import flatten_nested_list, print_warning_once
+from sglang.srt.utils import flatten_nested_list, is_npu, print_warning_once
 from sglang.utils import logger
 
+_is_npu = is_npu()
+
 # NOTE: Using the shared logger from sglang.utils instead of creating a module-specific logger
 # to ensure consistent logging behavior across the codebase. This prevents issues with log
 # propagation that can cause some log messages (like 'server is fired up') to not appear
@@ -486,6 +488,8 @@ def get_embedding_and_mask(
     if embedding is None:
         return None, None
     # 2. Get mask
+    if _is_npu:
+        torch.npu.current_stream().synchronize()
     special_multimodal_mask = _get_multimodal_mask(input_ids, placeholder_tensor)
     # 3. Adjust embedding length if needed
     embedding = _adjust_embedding_length(embedding, special_multimodal_mask, logger)
@@ -13,7 +13,9 @@ from PIL import Image
 from transformers import BaseImageProcessorFast
 
 from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
-from sglang.srt.utils import load_audio, load_image, load_video, logger
+from sglang.srt.utils import is_npu, load_audio, load_image, load_video, logger
+
+_is_npu = is_npu()
 
 
 @dataclasses.dataclass
@@ -232,7 +234,7 @@ class BaseMultimodalProcessor(ABC):
             and isinstance(processor.image_processor, BaseImageProcessorFast)
             and not self.server_args.disable_fast_image_processor
         ):
-            kwargs["device"] = "cuda"
+            kwargs["device"] = "cuda" if not _is_npu else "npu"
             result = processor.__call__(
                 text=[input_text],
                 padding=True,