[feat][torchair] support super kernel feat for quantized dsr1 (#3485)
### What this PR does / why we need it?
Port #1916 and #2157 to the master branch to fuse operators in DeepSeek MoE layers, which can reduce scheduling overhead on devices. Note that this feature is valid only when `tp_size = 1` and `multistream_overlap_shared_expert` is enabled with torchair graph mode.

### Does this PR introduce _any_ user-facing change?
Users can enable this feature with `--additional-config '{"torchair_graph_config":{"enabled":true, "enable_super_kernel":true}, "multistream_overlap_shared_expert":true}'`.

### How was this patch tested?
E2E DeepSeek serving with 2P1D disaggregated prefill scenarios.

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: linfeng-yuan <1102311262@qq.com>
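For illustration, the same options can be passed through the offline entry point as well. A minimal sketch, assuming an Ascend NPU environment with vllm-ascend installed; the model path and sampling settings are placeholders, not part of this PR:

```python
from vllm import LLM, SamplingParams

# Hedged sketch: "path/to/deepseek-r1-w8a8" is a placeholder for a dynamic
# w8a8 quantized DeepSeek checkpoint.
llm = LLM(
    model="path/to/deepseek-r1-w8a8",
    tensor_parallel_size=1,  # super kernel requires tp_size = 1
    additional_config={
        "torchair_graph_config": {
            "enabled": True,
            "enable_super_kernel": True,
        },
        "multistream_overlap_shared_expert": True,
    },
)
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)
```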
@@ -58,6 +58,7 @@ The details of each config option are as follows:
 | `graph_batch_sizes` | list[int] | `[]` | The batch size for torchair graph cache |
 | `graph_batch_sizes_init` | bool | `False` | Init graph batch size dynamically if `graph_batch_sizes` is empty |
 | `enable_kv_nz`| bool | `False` | Whether to enable kvcache NZ layout. This option only takes effects on models using MLA (e.g., DeepSeek). |
+| `enable_super_kernel` | bool | `False` | Whether to enable super kernel to fuse operators in deepseek moe layers. This option only takes effects on moe models using dynamic w8a8 quantization.|

 **ascend_scheduler_config**
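For context, the options in this table combine into a single `torchair_graph_config` block. An illustrative (non-normative) sketch with example values:

```python
# Illustrative values only; the defaults are shown in the table above.
torchair_graph_config = {
    "enabled": True,
    "graph_batch_sizes": [1, 8, 16],  # explicit graph-cache batch sizes...
    "graph_batch_sizes_init": False,  # ...or derive them when the list is empty
    "enable_kv_nz": False,            # kvcache NZ layout, MLA models only
    "enable_super_kernel": True,      # new in this PR; see the constraints below
}
```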
@@ -56,17 +56,18 @@ class TestAscendUnquantizedLinearMethod(TestBase):

     def setUp(self):
         self.method = AscendUnquantizedLinearMethod()
+        self.layer = mock.MagicMock()
+        mock_dtype = mock.PropertyMock(return_value=torch.float16)
+        type(self.layer.weight.data).dtype = mock_dtype

     @mock.patch("vllm_ascend.ops.linear.is_enable_nz")
     @mock.patch("torch_npu.npu_format_cast")
     @mock.patch("torch.version")
     def test_process_weights_after_loading_is_8_3_enable_nz(
             self, mock_version, mock_format_cast, mock_is_nz):
-        layer = mock.MagicMock()
         mock_version.cann = "8.3.RC1"
         mock_is_nz.return_value = 1
-        self.method.process_weights_after_loading(layer)
+        self.method.process_weights_after_loading(self.layer)
         mock_format_cast.assert_called_once()

     @mock.patch("vllm_ascend.ops.linear.is_enable_nz")
@@ -74,23 +75,19 @@ class TestAscendUnquantizedLinearMethod(TestBase):
     @mock.patch("torch_npu.npu_format_cast")
     @mock.patch("torch.version")
     def test_process_weights_after_loading_is_8_3_disable_nz(
             self, mock_version, mock_format_cast, mock_is_nz):
-        layer = mock.MagicMock()
         mock_version.cann = "8.3.RC1"
         mock_is_nz.return_value = 0
-        self.method.process_weights_after_loading(layer)
+        self.method.process_weights_after_loading(self.layer)
         mock_format_cast.assert_not_called()

     @mock.patch("vllm_ascend.ops.linear.is_enable_nz")
     @mock.patch("torch.version")
     def test_process_weights_after_loading_not_8_3(self, mock_version,
                                                    mock_is_nz):
-        layer = mock.MagicMock()
         mock_version.cann = "8.2.RC1"
         mock_is_nz.return_value = 1
         # Should not raise exception
-        self.method.process_weights_after_loading(layer)
+        self.method.process_weights_after_loading(self.layer)


 class TestAscendRowParallelLinear(BaseLinearTest):
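The shared `self.layer` fixture works because `dtype` is resolved through the descriptor protocol: a `PropertyMock` must be attached to `type(...)`, not to the instance. A self-contained sketch of the same pattern used in the refactored `setUp`:

```python
from unittest import mock

import torch

layer = mock.MagicMock()
# PropertyMock is attached to the type so that `layer.weight.data.dtype`
# resolves through the descriptor protocol instead of creating a child mock.
mock_dtype = mock.PropertyMock(return_value=torch.float16)
type(layer.weight.data).dtype = mock_dtype

assert layer.weight.data.dtype == torch.float16
```

This gives the NZ-cast tests a weight with a real-looking dtype, which the tightened `process_weights_after_loading` condition now inspects.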
@@ -37,7 +37,8 @@ class AscendConfig:

         torchair_graph_config = additional_config.get("torchair_graph_config",
                                                       {})
-        self.torchair_graph_config = TorchairGraphConfig(torchair_graph_config)
+        self.torchair_graph_config = TorchairGraphConfig(
+            torchair_graph_config, vllm_config, additional_config)

         ascend_scheduler_config = additional_config.get(
             "ascend_scheduler_config", {})
@@ -133,7 +134,7 @@ class TorchairGraphConfig:
     Configuration Object for torchair_graph_config from additional_config
     """

-    def __init__(self, torchair_graph_config):
+    def __init__(self, torchair_graph_config, vllm_config, additional_config):
         self.enabled = torchair_graph_config.get("enabled", False)
         self.mode = torchair_graph_config.get("mode", '')
         self.use_cached_graph = torchair_graph_config.get(
@@ -151,6 +152,8 @@ class TorchairGraphConfig:
         self.enable_frozen_parameter = torchair_graph_config.get(
             "enable_frozen_parameter", True)
         self.enable_kv_nz = torchair_graph_config.get("enable_kv_nz", False)
+        self.enable_super_kernel = torchair_graph_config.get(
+            "enable_super_kernel", False)

         if not isinstance(self.graph_batch_sizes, list):
             raise TypeError("graph_batch_sizes must be list[int]")
@@ -186,6 +189,20 @@ class TorchairGraphConfig:
                 raise RuntimeError(
                     "enable_kv_nz is valid only when Torchair graph mode is enabled"
                 )
+            if self.enable_super_kernel:
+                raise RuntimeError(
+                    "enable_super_kernel is valid only when Torchair graph mode is enabled"
+                )
+        if self.enable_super_kernel:
+            if vllm_config.parallel_config.tensor_parallel_size != 1:
+                raise RuntimeError(
+                    "enable_super_kernel is valid only when tensor_parallel_size is 1"
+                )
+            if not additional_config.get("multistream_overlap_shared_expert",
+                                         False):
+                raise RuntimeError(
+                    "enable_super_kernel is valid only when multistream_overlap_shared_expert is enabled"
+                )
         if self.use_cached_kv_cache_bytes and not self.use_cached_graph:
             raise RuntimeError(
                 "use_cached_kv_cache_bytes is valid only when Torchair graph mode and use_cached_graph are enabled"
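A condensed, runnable re-implementation of the new checks, simplified for illustration (the real code reads `vllm_config.parallel_config` and the raw `additional_config` dict):

```python
def validate_super_kernel(enable_super_kernel: bool,
                          torchair_graph_enabled: bool,
                          tensor_parallel_size: int,
                          multistream_overlap_shared_expert: bool) -> None:
    """Simplified mirror of the TorchairGraphConfig checks above."""
    if not enable_super_kernel:
        return
    if not torchair_graph_enabled:
        raise RuntimeError(
            "enable_super_kernel is valid only when Torchair graph mode is enabled")
    if tensor_parallel_size != 1:
        raise RuntimeError(
            "enable_super_kernel is valid only when tensor_parallel_size is 1")
    if not multistream_overlap_shared_expert:
        raise RuntimeError(
            "enable_super_kernel is valid only when multistream_overlap_shared_expert is enabled")

validate_super_kernel(True, True, 1, True)   # passes
# validate_super_kernel(True, True, 2, True) would raise: tp_size must be 1
```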
@@ -44,7 +44,8 @@ class AscendUnquantizedLinearMethod(UnquantizedLinearMethod):

     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         super().process_weights_after_loading(layer)
-        if is_enable_nz() and torch.version.cann.startswith("8.3"):
+        if (is_enable_nz() and torch.version.cann.startswith("8.3") and
+                layer.weight.data.dtype in [torch.float16, torch.bfloat16]):
             layer.weight.data = torch_npu.npu_format_cast(
                 layer.weight.data, ACL_FORMAT_FRACTAL_NZ)
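The effect of the widened condition, restated as a standalone predicate. This is a hypothetical helper for illustration only; the real method calls `torch_npu.npu_format_cast` when the condition holds:

```python
import torch

def should_cast_to_nz(weight: torch.Tensor, cann_version: str,
                      nz_enabled: bool) -> bool:
    # Only float16/bfloat16 weights are converted to FRACTAL_NZ; quantized
    # (e.g. int8) weights now skip the cast entirely.
    return (nz_enabled and cann_version.startswith("8.3")
            and weight.dtype in (torch.float16, torch.bfloat16))

assert should_cast_to_nz(torch.zeros(2, dtype=torch.float16), "8.3.RC1", True)
assert not should_cast_to_nz(torch.zeros(2, dtype=torch.int8), "8.3.RC1", True)
```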
@@ -328,14 +328,22 @@ class TorchairDeepseekV2MoE(nn.Module):
             ascend_config.multistream_overlap_shared_expert and \
             self.torchair_graph_enabled

+        self.enable_super_kernel = ascend_config.torchair_graph_config.enable_super_kernel
+        self.params_dtype = torch.float32 if self.enable_super_kernel else \
+            torch.get_default_dtype()
+        # Converting the gate weight to fp32 adapts it to the super kernel feature.
+        # The super kernel currently cannot fuse operators such as Cast, StridedSlice, and Add.
+        # In the moe stage, a Cast would interrupt the fusion of the super kernel. To avoid
+        # this problem, the dtype is fixed here in the initialization stage.
         self.gate = ReplicatedLinear(config.hidden_size,
                                      config.n_routed_experts,
                                      bias=False,
                                      quant_config=None,
+                                     params_dtype=self.params_dtype,
                                      prefix=f"{prefix}.gate")
         if config.topk_method == "noaux_tc":
             self.gate.e_score_correction_bias = nn.Parameter(
-                torch.empty(config.n_routed_experts))
+                torch.empty(config.n_routed_experts, dtype=self.params_dtype))
         else:
             self.gate.e_score_correction_bias = None
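Why fp32 matters here, in a minimal sketch with a plain `nn.Linear` standing in for `ReplicatedLinear`: when the gate weight is already float32, feeding it `hidden_states.float()` leaves no dtype mismatch, so no Cast node lands inside the region the super kernel tries to fuse.

```python
import torch
import torch.nn as nn

enable_super_kernel = True
# Mirrors the init-time choice above: fp32 gate when the super kernel is on.
params_dtype = torch.float32 if enable_super_kernel else torch.get_default_dtype()
gate = nn.Linear(16, 8, bias=False, dtype=params_dtype)

hidden_states = torch.randn(2, 16, dtype=torch.bfloat16)
router_logits = gate(hidden_states.float() if enable_super_kernel else hidden_states)
assert router_logits.dtype == torch.float32
```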
@@ -48,7 +48,8 @@ from vllm_ascend.eplb.core.eplb_utils import (determine_default_expert_map,
 from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
 from vllm_ascend.quantization.quant_config import AscendFusedMoEMethod
 from vllm_ascend.torchair.ops.sequence_parallel import MetadataForPadding
-from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
+from vllm_ascend.torchair.utils import (npu_stream_switch, npu_wait_tensor,
+                                        super_kernel)
 from vllm_ascend.utils import (AscendSocVersion, dispose_tensor,
                                get_all_reduce_merge_state,
                                get_ascend_soc_version,
@@ -990,6 +991,7 @@ class TorchairAscendFusedMoE(FusedMoE):
         )
         TorchairAscendFusedMoE.moe_counter += 1
         self.moe_instance_id = TorchairAscendFusedMoE.moe_counter
+        self.prefix = prefix

         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
@@ -1096,6 +1098,7 @@ class TorchairAscendFusedMoE(FusedMoE):
         self.multistream_overlap_shared_expert = \
             ascend_config.multistream_overlap_shared_expert and \
             self.torchair_graph_enabled
+        self.enable_super_kernel = ascend_config.torchair_graph_config.enable_super_kernel
         self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp

         if self.scoring_func != "softmax" and not self.use_grouped_topk:
@@ -1192,8 +1195,16 @@ class TorchairAscendFusedMoE(FusedMoE):
         quantized_x_for_share, dynamic_scale_for_share = None, None
         from vllm_ascend.torchair.quantization.torchair_w8a8_dynamic import \
             TorchairAscendW8A8DynamicFusedMoEMethod
+        running_in_super_kernel = self.enable_super_kernel and fused_moe_state == FusedMoEState.MC2

         if self.multistream_overlap_shared_expert:
-            if not self.rm_router_logits:
-                router_logits, _ = gate(hidden_states)
+            with super_kernel(self.prefix,
+                              "stream-fusion=1",
+                              enabled=running_in_super_kernel):
+                if not self.rm_router_logits:
+                    if self.enable_super_kernel:
+                        router_logits, _ = gate(hidden_states.float())
+                    else:
+                        router_logits, _ = gate(hidden_states)
             if hasattr(self.quant_method, "quant_method") and \
                     isinstance(self.quant_method.quant_method,
@@ -1305,6 +1316,8 @@ class TorchairAscendFusedMoE(FusedMoE):
             mc2_mask=mc2_mask,
             quantized_x_for_share=quantized_x_for_share,
             dynamic_scale_for_share=dynamic_scale_for_share,
+            prefix=self.prefix,
+            running_in_super_kernel=running_in_super_kernel,
         )

         if shared_experts:
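The enablement condition introduced above, restated as a runnable fragment; the enum is a simplified stand-in for `vllm_ascend.ascend_forward_context.FusedMoEState`:

```python
from enum import Enum, auto

class FusedMoEState(Enum):  # simplified stand-in for the real enum
    MC2 = auto()
    All2All = auto()
    AllGatherEP = auto()

enable_super_kernel = True
fused_moe_state = FusedMoEState.MC2
# The super-kernel scope is opened only for MC2 dispatch; any other state
# takes the unfused path even when the config flag is set.
running_in_super_kernel = enable_super_kernel and fused_moe_state == FusedMoEState.MC2
assert running_in_super_kernel
```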
@@ -26,7 +26,8 @@ from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.ascend_forward_context import FusedMoEState
 from vllm_ascend.distributed.parallel_state import get_mc2_group
 from vllm_ascend.torchair.ops.torchair_fused_moe import torchair_select_experts
-from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
+from vllm_ascend.torchair.utils import (npu_stream_switch, npu_wait_tensor,
+                                        super_kernel)
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, AscendSocVersion,
                                dispose_tensor, get_ascend_soc_version,
                                is_enable_nz,
@@ -927,6 +928,8 @@ class TorchairAscendW8A8DynamicFusedMoEMethod:
         shared_experts: Optional[Any] = None,
         quantized_x_for_share: Optional[Any] = None,
         dynamic_scale_for_share: Optional[Any] = None,
+        prefix: str = "",
+        running_in_super_kernel: bool = False,
         **kwargs,
     ) -> torch.Tensor:
         assert router_logits.shape[
@@ -934,6 +937,14 @@ class TorchairAscendW8A8DynamicFusedMoEMethod:

         is_deepseek_v3_r1 = global_num_experts - global_redundant_expert_num == 256

+        fused_moe_state = get_forward_context().fused_moe_state
+        if self.enable_shared_expert_dp and fused_moe_state == FusedMoEState.MC2:
+            fused_moe_state = FusedMoEState.All2All
+        shared_gate_up, shared_dequant_scale = None, None
+
+        with super_kernel(prefix,
+                          "stream-fusion=1",
+                          enabled=running_in_super_kernel):
             # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
             if is_deepseek_v3_r1:
                 topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
@@ -964,10 +975,6 @@ class TorchairAscendW8A8DynamicFusedMoEMethod:
                     e_score_correction_bias=e_score_correction_bias,
                 )

-        fused_moe_state = get_forward_context().fused_moe_state
-        if self.enable_shared_expert_dp and fused_moe_state == FusedMoEState.MC2:
-            fused_moe_state = FusedMoEState.All2All
-        shared_gate_up, shared_dequant_scale = None, None
             if shared_experts is not None and fused_moe_state == FusedMoEState.MC2:
                 with npu_stream_switch("moe_secondary", 0):
                     npu_wait_tensor(quantized_x_for_share, router_logits)
@@ -981,8 +988,8 @@ class TorchairAscendW8A8DynamicFusedMoEMethod:
-        # currently it is only activated when doing profile runs.
-        if enable_force_load_balance:
-            topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
+            # currently it is only activated when doing profile runs.
+            if enable_force_load_balance:
+                topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)

-        topk_weights = topk_weights.to(x.dtype)
+            topk_weights = topk_weights.to(x.dtype)

         if fused_moe_state == FusedMoEState.AllGatherEP:
             return torchair_fused_experts_with_allgather(
                 hidden_states=x,
@@ -995,6 +1002,9 @@ class TorchairAscendW8A8DynamicFusedMoEMethod:
                 top_k=top_k,
                 expert_map=expert_map)
         elif fused_moe_state == FusedMoEState.MC2:
+            with super_kernel(prefix,
+                              "stream-fusion=1",
+                              enabled=running_in_super_kernel):
                 return torchair_fused_experts_with_mc2(
                     hidden_states=x,
                     w1=layer.w13_weight,
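Note the hoist in the `@@ -934` / `@@ -964` hunks: `fused_moe_state` is now computed before the fused region is entered, so the dispatch decision is resolved outside the scope and only compute ops sit inside it. The shape of that refactor as a generic runnable sketch, with a dummy stand-in for the torchair scope:

```python
from contextlib import contextmanager, nullcontext

@contextmanager
def fusion_scope(prefix: str, option: str):
    # Dummy stand-in for the torchair super-kernel scope; on NPU the real
    # scope marks the region named `prefix` for operator fusion.
    yield

def routed_dispatch(state: str, prefix: str, super_kernel_on: bool) -> str:
    # Decide the dispatch mode *before* opening the scope, mirroring the
    # hunks above, so control flow is not captured inside the fused region.
    enabled = super_kernel_on and state == "MC2"
    ctx = fusion_scope(prefix, "stream-fusion=1") if enabled else nullcontext()
    with ctx:
        return "mc2_path" if state == "MC2" else "all2all_path"

assert routed_dispatch("MC2", "layers.0.mlp", True) == "mc2_path"
assert routed_dispatch("AllGatherEP", "layers.0.mlp", True) == "all2all_path"
```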
@@ -6,6 +6,7 @@ from dataclasses import dataclass

 import torch
 import torch_npu
+from torchair.scope import super_kernel as _super_kernel

 try:
     # Recent release of torchair has moved these ops to `.scope`.
@@ -231,3 +232,7 @@ def torchair_ops_patch():
     AscendRMSNorm.forward_oot = torchair_layernorm.torchair_rmsnorm_forward_oot  # type: ignore[method-assign]
     AscendSiluAndMul.forward_oot = torchair_activation.torchair_silu_and_mul_forward_oot  # type: ignore[method-assign]
     AscendVocabParallelEmbedding.forward = vocab_embedding_forward  # type: ignore[method-assign]
+
+
+def super_kernel(prefix: str, option: str, enabled: bool = True):
+    return _super_kernel(prefix, option) if enabled else nullcontext()
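This wrapper lets call sites pass `enabled=` instead of branching around the scope, degrading to a no-op context when the feature is off. A self-contained sketch of the same pattern, with a dummy context manager in place of `torchair.scope.super_kernel`:

```python
from contextlib import contextmanager, nullcontext

@contextmanager
def _super_kernel(prefix: str, option: str):
    # Dummy stand-in for torchair.scope.super_kernel; on NPU the real scope
    # fuses the ops issued under the region named `prefix`.
    yield

def super_kernel(prefix: str, option: str, enabled: bool = True):
    return _super_kernel(prefix, option) if enabled else nullcontext()

with super_kernel("model.layers.0.mlp", "stream-fusion=1", enabled=True):
    pass  # ops issued here would be fused into one super kernel on-device
```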