[Feature] Optimize DeepSeek's DeepEP on Ascend NPU (#8355)
Co-authored-by: ronnie_zheng <zl19940307@163.com> Co-authored-by: Hexq0210 <hexq0809521@gmail.com>
This commit is contained in:
@@ -50,6 +50,8 @@ from sglang.srt.utils import (
|
|||||||
supports_custom_op,
|
supports_custom_op,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
_is_npu = is_npu()
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class GraphCaptureContext:
|
class GraphCaptureContext:
|
||||||
@@ -591,7 +593,7 @@ class GroupCoordinator:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def all_gather_into_tensor(self, output: torch.Tensor, input: torch.Tensor):
|
def all_gather_into_tensor(self, output: torch.Tensor, input: torch.Tensor):
|
||||||
if not supports_custom_op():
|
if _is_npu or not supports_custom_op():
|
||||||
self._all_gather_into_tensor(output, input)
|
self._all_gather_into_tensor(output, input)
|
||||||
else:
|
else:
|
||||||
torch.ops.sglang.reg_all_gather_into_tensor(
|
torch.ops.sglang.reg_all_gather_into_tensor(
|
||||||
@@ -1127,7 +1129,7 @@ def init_model_parallel_group(
|
|||||||
group_ranks=group_ranks,
|
group_ranks=group_ranks,
|
||||||
local_rank=local_rank,
|
local_rank=local_rank,
|
||||||
torch_distributed_backend=backend,
|
torch_distributed_backend=backend,
|
||||||
use_pynccl=not is_npu(),
|
use_pynccl=not _is_npu,
|
||||||
use_pymscclpp=use_mscclpp_allreduce,
|
use_pymscclpp=use_mscclpp_allreduce,
|
||||||
use_custom_allreduce=use_custom_allreduce,
|
use_custom_allreduce=use_custom_allreduce,
|
||||||
use_hpu_communicator=True,
|
use_hpu_communicator=True,
|
||||||
|
|||||||
@@ -75,6 +75,9 @@ class AscendAttnBackend(AttentionBackend):
|
|||||||
)
|
)
|
||||||
self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int()
|
self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int()
|
||||||
|
|
||||||
|
def get_cuda_graph_seq_len_fill_value(self):
    """Return the sequence-length fill value reported by this backend.

    The value is the constant ``1``.
    NOTE(review): presumably consumed by the graph-capture code to pad
    unused sequence-length slots -- confirm against the caller.
    """
    return 1
|
||||||
|
|
||||||
def forward_extend(
|
def forward_extend(
|
||||||
self,
|
self,
|
||||||
q,
|
q,
|
||||||
|
|||||||
@@ -34,6 +34,7 @@ from sglang.srt.utils import ceil_div, dispose_tensor, get_bool_env_var, is_hip,
|
|||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.layers.moe.token_dispatcher import (
|
from sglang.srt.layers.moe.token_dispatcher import (
|
||||||
|
AscendDeepEPLLOutput,
|
||||||
DeepEPLLOutput,
|
DeepEPLLOutput,
|
||||||
DeepEPNormalOutput,
|
DeepEPNormalOutput,
|
||||||
DispatchOutput,
|
DispatchOutput,
|
||||||
@@ -387,7 +388,8 @@ class DeepEPMoE(EPMoE):
|
|||||||
return_recv_hook=True,
|
return_recv_hook=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.deepep_mode.enable_low_latency():
|
if self.deepep_mode.enable_low_latency() and not _is_npu:
|
||||||
|
# NPU supports low_latency deepep without deepgemm
|
||||||
assert (
|
assert (
|
||||||
deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
|
deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
|
||||||
), f"DeepEP {self.deepep_mode} mode requires deep_gemm"
|
), f"DeepEP {self.deepep_mode} mode requires deep_gemm"
|
||||||
@@ -404,7 +406,7 @@ class DeepEPMoE(EPMoE):
|
|||||||
)
|
)
|
||||||
# the last one is invalid rank_id
|
# the last one is invalid rank_id
|
||||||
self.expert_mask[:-1] = 1
|
self.expert_mask[:-1] = 1
|
||||||
else:
|
elif not _is_npu:
|
||||||
self.w13_weight_fp8 = (
|
self.w13_weight_fp8 = (
|
||||||
self.w13_weight,
|
self.w13_weight,
|
||||||
(
|
(
|
||||||
@@ -459,6 +461,8 @@ class DeepEPMoE(EPMoE):
|
|||||||
if _use_aiter:
|
if _use_aiter:
|
||||||
# in forward_aiter, we skip token permutation and unpermutation, which have been fused inside aiter kernel
|
# in forward_aiter, we skip token permutation and unpermutation, which have been fused inside aiter kernel
|
||||||
return self.forward_aiter(dispatch_output)
|
return self.forward_aiter(dispatch_output)
|
||||||
|
if _is_npu:
|
||||||
|
return self.forward_npu(dispatch_output)
|
||||||
if dispatch_output.format.is_deepep_normal():
|
if dispatch_output.format.is_deepep_normal():
|
||||||
assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
|
assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
|
||||||
return self.forward_deepgemm_contiguous(dispatch_output)
|
return self.forward_deepgemm_contiguous(dispatch_output)
|
||||||
@@ -723,6 +727,60 @@ class DeepEPMoE(EPMoE):
|
|||||||
|
|
||||||
return down_output
|
return down_output
|
||||||
|
|
||||||
|
def forward_npu(
    self,
    dispatch_output: "AscendDeepEPLLOutput",
):
    """Run the expert MLP on Ascend NPU for a low-latency DeepEP dispatch.

    Pipeline: grouped matmul with ``w13_weight`` (gate/up projection) ->
    ``npu_swiglu`` -> ``npu_dynamic_quant`` -> grouped matmul with
    ``w2_weight`` (down projection).

    Args:
        dispatch_output: Six-element dispatch result. Element 0 is a pair
            whose index 0 holds the quantized activations and index 1 the
            per-token scales; element 4 (``seg_indptr``) is passed to the
            grouped matmuls as ``group_list``. The other elements are not
            used here.

    Returns:
        The bfloat16 output of the down-projection grouped matmul.
    """
    if TYPE_CHECKING:
        assert isinstance(dispatch_output, AscendDeepEPLLOutput)
    # Only the activations and the per-expert grouping are consumed here;
    # unpacking all six slots still enforces the expected tuple length.
    hidden_states, _, _, _, seg_indptr, _ = dispatch_output
    assert self.quant_method is not None
    assert self.activation == "silu"

    # NOTE: Ascend's Dispatch & Combine does not support FP16
    output_dtype = torch.bfloat16

    pertoken_scale = hidden_states[1]
    hidden_states = hidden_states[0]

    # NOTE(review): group_list_type=1 is understood to mean that group_list
    # holds per-group token counts rather than a cumulative sum -- confirm
    # against the torch_npu documentation and against what the dispatcher
    # stores in seg_indptr.
    group_list_type = 1
    seg_indptr = seg_indptr.to(torch.int64)

    # Imported lazily so the module can still be loaded without torch_npu.
    import torch_npu

    # gmm1: gate_up_proj
    hidden_states = torch_npu.npu_grouped_matmul(
        x=[hidden_states],
        weight=[self.w13_weight],
        scale=[self.w13_weight_scale.to(output_dtype)],
        per_token_scale=[pertoken_scale],
        split_item=2,
        group_list_type=group_list_type,
        group_type=0,
        group_list=seg_indptr,
        output_dtype=output_dtype,
    )[0]

    # act_fn: swiglu
    hidden_states = torch_npu.npu_swiglu(hidden_states)

    # Re-quantize the activation output; the returned scales feed gmm2.
    hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(hidden_states)

    # gmm2: down_proj
    hidden_states = torch_npu.npu_grouped_matmul(
        x=[hidden_states],
        weight=[self.w2_weight],
        scale=[self.w2_weight_scale.to(output_dtype)],
        per_token_scale=[swiglu_out_scale],
        split_item=2,
        group_list_type=group_list_type,
        group_type=0,
        group_list=seg_indptr,
        output_dtype=output_dtype,
    )[0]

    return hidden_states
|
||||||
|
|
||||||
|
|
||||||
def get_moe_impl_class():
|
def get_moe_impl_class():
|
||||||
if global_server_args_dict["moe_a2a_backend"].is_deepep():
|
if global_server_args_dict["moe_a2a_backend"].is_deepep():
|
||||||
|
|||||||
@@ -23,14 +23,23 @@ from sglang.srt.layers.moe.token_dispatcher.base_dispatcher import (
|
|||||||
from sglang.srt.layers.moe.utils import DeepEPMode
|
from sglang.srt.layers.moe.utils import DeepEPMode
|
||||||
from sglang.srt.layers.quantization import deep_gemm_wrapper
|
from sglang.srt.layers.quantization import deep_gemm_wrapper
|
||||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||||
from sglang.srt.utils import get_bool_env_var, get_int_env_var, is_hip, load_json_config
|
from sglang.srt.utils import (
|
||||||
|
get_bool_env_var,
|
||||||
|
get_int_env_var,
|
||||||
|
is_hip,
|
||||||
|
is_npu,
|
||||||
|
load_json_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
_is_npu = is_npu()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from deep_ep import Buffer, Config
|
from deep_ep import Buffer, Config
|
||||||
|
|
||||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
if not _is_npu:
|
||||||
sglang_per_token_group_quant_fp8,
|
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||||
)
|
sglang_per_token_group_quant_fp8,
|
||||||
|
)
|
||||||
|
|
||||||
use_deepep = True
|
use_deepep = True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
@@ -80,8 +89,24 @@ class DeepEPLLOutput(NamedTuple):
|
|||||||
return DispatchOutputFormat.deepep_ll
|
return DispatchOutputFormat.deepep_ll
|
||||||
|
|
||||||
|
|
||||||
|
class AscendDeepEPLLOutput(NamedTuple):
    """AscendDeepEP low latency dispatch output.

    Same leading fields as the regular low-latency output, plus an extra
    ``seg_indptr`` slot inserted before ``expected_m``.
    """

    # Pair of tensors; DeepEPMoE.forward_npu reads index 0 as the quantized
    # activations and index 1 as the per-token scales.
    hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor]
    topk_idx: torch.Tensor
    topk_weights: torch.Tensor
    masked_m: torch.Tensor
    # Filled from the dispatcher's ``handle[1]`` and later passed to the NPU
    # grouped matmul as ``group_list``.
    seg_indptr: torch.Tensor
    expected_m: int

    @property
    def format(self) -> DispatchOutputFormat:
        # Reported as the ordinary low-latency format so format-based
        # dispatch treats both output types alike.
        return DispatchOutputFormat.deepep_ll
|
||||||
|
|
||||||
|
|
||||||
assert isinstance(DeepEPNormalOutput, DispatchOutput)
|
assert isinstance(DeepEPNormalOutput, DispatchOutput)
|
||||||
assert isinstance(DeepEPLLOutput, DispatchOutput)
|
assert isinstance(DeepEPLLOutput, DispatchOutput)
|
||||||
|
assert isinstance(AscendDeepEPLLOutput, DispatchOutput)
|
||||||
|
|
||||||
|
|
||||||
class DeepEPDispatchMode(IntEnum):
|
class DeepEPDispatchMode(IntEnum):
|
||||||
@@ -150,19 +175,20 @@ class DeepEPBuffer:
|
|||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
total_num_sms = torch.cuda.get_device_properties(
|
if not _is_npu:
|
||||||
device="cuda"
|
total_num_sms = torch.cuda.get_device_properties(
|
||||||
).multi_processor_count
|
device="cuda"
|
||||||
if (
|
).multi_processor_count
|
||||||
(deepep_mode != DeepEPMode.LOW_LATENCY)
|
if (
|
||||||
and not global_server_args_dict["enable_two_batch_overlap"]
|
(deepep_mode != DeepEPMode.LOW_LATENCY)
|
||||||
and (DeepEPConfig.get_instance().num_sms < total_num_sms // 2)
|
and not global_server_args_dict["enable_two_batch_overlap"]
|
||||||
):
|
and (DeepEPConfig.get_instance().num_sms < total_num_sms // 2)
|
||||||
logger.warning(
|
):
|
||||||
f"Only use {DeepEPConfig.get_instance().num_sms} SMs for DeepEP communication. "
|
logger.warning(
|
||||||
f"This may result in highly suboptimal performance. "
|
f"Only use {DeepEPConfig.get_instance().num_sms} SMs for DeepEP communication. "
|
||||||
f"Consider using --deepep-config to change the behavior."
|
f"This may result in highly suboptimal performance. "
|
||||||
)
|
f"Consider using --deepep-config to change the behavior."
|
||||||
|
)
|
||||||
|
|
||||||
cls._buffer = Buffer(
|
cls._buffer = Buffer(
|
||||||
group,
|
group,
|
||||||
@@ -507,13 +533,24 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
|
|||||||
masked_m
|
masked_m
|
||||||
)
|
)
|
||||||
|
|
||||||
return DeepEPLLOutput(
|
if _is_npu:
|
||||||
hidden_states,
|
deepep_output = AscendDeepEPLLOutput(
|
||||||
topk_idx,
|
hidden_states,
|
||||||
topk_weights,
|
topk_idx,
|
||||||
masked_m,
|
topk_weights,
|
||||||
expected_m,
|
masked_m,
|
||||||
)
|
self.handle[1],
|
||||||
|
expected_m,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
deepep_output = DeepEPLLOutput(
|
||||||
|
hidden_states,
|
||||||
|
topk_idx,
|
||||||
|
topk_weights,
|
||||||
|
masked_m,
|
||||||
|
expected_m,
|
||||||
|
)
|
||||||
|
return deepep_output
|
||||||
|
|
||||||
def _dispatch_core(
|
def _dispatch_core(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -250,10 +250,11 @@ class TopK(CustomOp):
|
|||||||
|
|
||||||
# NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
|
# NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
|
||||||
if global_num_experts == 256:
|
if global_num_experts == 256:
|
||||||
|
router_logits = router_logits.to(torch.float32)
|
||||||
return torch_npu.npu_moe_gating_top_k(
|
return torch_npu.npu_moe_gating_top_k(
|
||||||
router_logits,
|
router_logits,
|
||||||
k=self.top_k,
|
k=self.top_k,
|
||||||
bias=self.correction_bias,
|
bias=self.correction_bias.to(torch.float32),
|
||||||
k_group=self.topk_group,
|
k_group=self.topk_group,
|
||||||
group_count=self.num_expert_group,
|
group_count=self.num_expert_group,
|
||||||
group_select_mode=1,
|
group_select_mode=1,
|
||||||
|
|||||||
@@ -3,7 +3,18 @@ from __future__ import annotations
|
|||||||
import importlib
|
import importlib
|
||||||
import sys
|
import sys
|
||||||
from types import MappingProxyType
|
from types import MappingProxyType
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Union, cast
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Any,
|
||||||
|
Callable,
|
||||||
|
Dict,
|
||||||
|
List,
|
||||||
|
Mapping,
|
||||||
|
Optional,
|
||||||
|
Tuple,
|
||||||
|
Union,
|
||||||
|
cast,
|
||||||
|
)
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.nn.parameter import Parameter
|
from torch.nn.parameter import Parameter
|
||||||
@@ -79,22 +90,16 @@ def npu_wrapper_rmsnorm_forward(func):
|
|||||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||||
if not x.is_contiguous():
|
if not x.is_contiguous():
|
||||||
x = x.contiguous()
|
x = x.contiguous()
|
||||||
original_dtype = x.dtype
|
|
||||||
x = x.to(torch.float32)
|
|
||||||
if residual is not None:
|
if residual is not None:
|
||||||
x = x + residual.to(torch.float32)
|
out, _, residual_out = torch_npu.npu_add_rms_norm(
|
||||||
residual = x.to(original_dtype)
|
residual, x, self.weight.data, self.variance_epsilon
|
||||||
|
)
|
||||||
|
out = out + self.bias
|
||||||
|
return out.to(x.dtype), residual_out
|
||||||
|
|
||||||
x = (
|
out = torch_npu.npu_rms_norm(x, self.weight.data, self.variance_epsilon)[0]
|
||||||
torch_npu.npu_rms_norm(
|
out = out + self.bias
|
||||||
x, self.weight.to(torch.float32), self.variance_epsilon
|
return out.to(x.dtype)
|
||||||
)[0]
|
|
||||||
+ self.bias
|
|
||||||
)
|
|
||||||
|
|
||||||
if residual is None:
|
|
||||||
return x.to(original_dtype)
|
|
||||||
return x.to(original_dtype), residual
|
|
||||||
|
|
||||||
return _rmsnorm_forward_oot
|
return _rmsnorm_forward_oot
|
||||||
|
|
||||||
@@ -571,8 +576,10 @@ class NPU_W8A8LinearMethodImpl:
|
|||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
bias: Optional[torch.Tensor] = None,
|
bias: Optional[torch.Tensor] = None,
|
||||||
tp_rank: Optional[int] = 0,
|
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
# To prevent import loops
|
||||||
|
from sglang.srt.layers.linear import RowParallelLinear
|
||||||
|
|
||||||
original_dtype = x.dtype
|
original_dtype = x.dtype
|
||||||
if original_dtype != torch.int8:
|
if original_dtype != torch.int8:
|
||||||
x = torch_npu.npu_quantize(
|
x = torch_npu.npu_quantize(
|
||||||
@@ -583,8 +590,12 @@ class NPU_W8A8LinearMethodImpl:
|
|||||||
-1,
|
-1,
|
||||||
True,
|
True,
|
||||||
)
|
)
|
||||||
|
# Only fuse bias add into GEMM for rank 0 (this ensures that
|
||||||
quant_bias = layer.quant_bias if tp_rank == 0 else None
|
# bias will not get added more than once in Attention TP>1 case)
|
||||||
|
if isinstance(layer, RowParallelLinear) and layer.tp_rank > 0:
|
||||||
|
quant_bias = None
|
||||||
|
else:
|
||||||
|
quant_bias = layer.quant_bias
|
||||||
return torch_npu.npu_quant_matmul(
|
return torch_npu.npu_quant_matmul(
|
||||||
x,
|
x,
|
||||||
layer.weight,
|
layer.weight,
|
||||||
@@ -651,13 +662,21 @@ class NPU_W8A8LinearMethodMTImpl:
|
|||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
bias: Optional[torch.Tensor] = None,
|
bias: Optional[torch.Tensor] = None,
|
||||||
tp_rank: Optional[int] = 0,
|
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
# To prevent import loops
|
||||||
|
from sglang.srt.layers.linear import RowParallelLinear
|
||||||
|
|
||||||
original_dtype = x.dtype
|
original_dtype = x.dtype
|
||||||
if original_dtype != torch.int8:
|
if original_dtype != torch.int8:
|
||||||
x = quant_per_tensor(x, layer.input_scale, layer.input_offset)
|
x = quant_per_tensor(x, layer.input_scale, layer.input_offset)
|
||||||
|
|
||||||
quant_bias = layer.quant_bias if tp_rank == 0 else None
|
# Only fuse bias add into GEMM for rank 0 (this ensures that
|
||||||
|
# bias will not get added more than once in Attention TP>1 case)
|
||||||
|
if isinstance(layer, RowParallelLinear) and layer.tp_rank > 0:
|
||||||
|
quant_bias = None
|
||||||
|
else:
|
||||||
|
quant_bias = layer.quant_bias
|
||||||
|
|
||||||
return ops.quant_matmul(
|
return ops.quant_matmul(
|
||||||
x=x, weight=layer.weight, deq_scale=layer.deq_scale, deq_bias=quant_bias
|
x=x, weight=layer.weight, deq_scale=layer.deq_scale, deq_bias=quant_bias
|
||||||
)
|
)
|
||||||
@@ -737,11 +756,6 @@ class NPU_W8A8LinearMethod(LinearMethodBase):
|
|||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
bias: Optional[torch.Tensor] = None,
|
bias: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
from sglang.srt.layers.linear import RowParallelLinear
|
|
||||||
|
|
||||||
if isinstance(layer, RowParallelLinear):
|
|
||||||
tp_rank = get_tensor_model_parallel_rank()
|
|
||||||
return self.quant_method.apply(layer, x, bias, tp_rank)
|
|
||||||
return self.quant_method.apply(layer, x, bias)
|
return self.quant_method.apply(layer, x, bias)
|
||||||
|
|
||||||
|
|
||||||
@@ -780,7 +794,6 @@ class NPU_W8A8DynamicLinearMethodImpl:
|
|||||||
tp_rank: Optional[int] = 0,
|
tp_rank: Optional[int] = 0,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
original_dtype = x.dtype
|
original_dtype = x.dtype
|
||||||
# use ATB quantize
|
|
||||||
quant_out, dynamic_scale = torch_npu.npu_dynamic_quant(x)
|
quant_out, dynamic_scale = torch_npu.npu_dynamic_quant(x)
|
||||||
return torch_npu.npu_quant_matmul(
|
return torch_npu.npu_quant_matmul(
|
||||||
quant_out,
|
quant_out,
|
||||||
@@ -863,11 +876,6 @@ class NPU_W8A8DynamicLinearMethod(LinearMethodBase):
|
|||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
bias: Optional[torch.Tensor] = None,
|
bias: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
from sglang.srt.layers.linear import RowParallelLinear
|
|
||||||
|
|
||||||
if isinstance(layer, RowParallelLinear):
|
|
||||||
tp_rank = get_tensor_model_parallel_rank()
|
|
||||||
return self.quant_method.apply(layer, x, bias, tp_rank)
|
|
||||||
return self.quant_method.apply(layer, x, bias)
|
return self.quant_method.apply(layer, x, bias)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -680,7 +680,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Re-dispatch
|
# Re-dispatch
|
||||||
if _is_hip or _is_npu:
|
if _is_hip:
|
||||||
self._forward_method = self.forward_native
|
self._forward_method = self.forward_native
|
||||||
|
|
||||||
def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
|
def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
|
||||||
@@ -765,6 +765,46 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
|
|||||||
key = key_rot
|
key = key_rot
|
||||||
return query.to(dtype), key.to(dtype)
|
return query.to(dtype), key.to(dtype)
|
||||||
|
|
||||||
|
def forward_npu(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary embedding through ``torch_npu.npu_mrope``.

    Falls back to ``forward_native`` when ``query.shape[1] * query.shape[2]``
    exceeds 4096. Otherwise the first ``rotary_dim`` channels of the last
    dimension of ``query`` and ``key`` are rotated by the NPU kernel and any
    remaining channels are passed through unchanged.

    Args:
        positions: Position ids; ``offsets`` is added to them when given.
        query: Indexed as a 3-D tensor (dimension 0 is the token count).
        key: Sliced and reshaped the same way as ``query``.
        offsets: Optional tensor added to ``positions``.

    Returns:
        The ``(query, key)`` pair after rotation.
    """
    # NOTE: now npu_mrope can only support `numQHeads*headSize <= 4096` pattern,
    # and generalization to more scenarios will be supported in the future.
    if query.shape[1] * query.shape[2] > 4096:
        return self.forward_native(positions, query, key, offsets)

    num_tokens = query.shape[0]
    rotary_mode = "half" if self.is_neox_style else "interleave"
    # Keep the cache on the same device as the positions before the kernel call.
    self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(positions.device)

    # Only the leading rotary_dim channels are rotated; the rest pass through.
    has_pass_through = self.rotary_dim < self.head_size
    query_rot = query[..., : self.rotary_dim]
    key_rot = key[..., : self.rotary_dim]
    if has_pass_through:
        query_pass = query[..., self.rotary_dim :]
        key_pass = key[..., self.rotary_dim :]

    # The kernel is fed 2-D (num_tokens, heads * rotary_dim) inputs.
    query_rot, key_rot = torch_npu.npu_mrope(
        torch.add(positions, offsets) if offsets is not None else positions,
        query_rot.reshape(num_tokens, -1),
        key_rot.reshape(num_tokens, -1),
        self.cos_sin_cache,
        self.rotary_dim,
        mrope_section=[0, 0, 0],
        rotary_mode=rotary_mode,
    )
    # Restore the per-head layout.
    query_rot = query_rot.reshape(num_tokens, -1, self.rotary_dim)
    key_rot = key_rot.reshape(num_tokens, -1, self.rotary_dim)

    if has_pass_through:
        query = torch.cat((query_rot, query_pass), dim=-1)
        key = torch.cat((key_rot, key_pass), dim=-1)
    else:
        query = query_rot
        key = key_rot
    return query, key
|
||||||
|
|
||||||
def forward_cpu(
|
def forward_cpu(
|
||||||
self,
|
self,
|
||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
|
|||||||
Reference in New Issue
Block a user