[feature] Ascend quantization support (#7791)
Co-authored-by: ichernob <ichernobnn@gmail.com>
Co-authored-by: liupeng <liupeng374@huawei.com>
@@ -413,7 +413,9 @@ class ModelConfig:
        quant_cfg = self._parse_quant_hf_config()

        if quant_cfg is not None:
            quant_method = quant_cfg.get("quant_method", "").lower()
            quant_method = quant_cfg.get(
                "quant_method", "" if not self.quantization else self.quantization
            ).lower()

            # Detect which checkpoint it is
            for _, method in QUANTIZATION_METHODS.items():
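The new fallback matters for checkpoints whose quantization config carries no "quant_method" key (as Ascend W8A8 checkpoints may not): the value then comes from the user-supplied --quantization argument instead of an empty string. A minimal sketch of just that lookup, using hypothetical config dicts:

```python
# Hypothetical quantization sections; only the dict.get fallback is illustrated.
gptq_cfg = {"quant_method": "gptq", "bits": 4}
ascend_cfg = {}  # quantization section without a "quant_method" key

def resolve_quant_method(quant_cfg, cli_quantization=None):
    # Mirrors the updated lookup: prefer the checkpoint's own quant_method,
    # otherwise fall back to the value passed on the command line.
    return quant_cfg.get(
        "quant_method", "" if not cli_quantization else cli_quantization
    ).lower()

print(resolve_quant_method(gptq_cfg))                 # -> "gptq"
print(resolve_quant_method(ascend_cfg, "w8a8_int8"))  # -> "w8a8_int8"
print(resolve_quant_method(ascend_cfg))               # -> ""
```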
@@ -34,6 +34,7 @@ from sglang.srt.layers.quantization.base_config import (
from sglang.srt.utils import (
    cpu_has_amx_support,
    is_cpu,
    is_npu,
    set_weight_attrs,
    use_intel_amx_backend,
)
@@ -60,6 +61,7 @@ WEIGHT_LOADER_V2_SUPPORTED = [

_is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()
_is_npu = is_npu()


def adjust_marlin_shard(param, shard_size, shard_offset):
@@ -297,6 +299,14 @@ class ReplicatedLinear(LinearBase):
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        # The per-tensor quant scale must be 1-dimensional
        if _is_npu:
            if param.size() != loaded_weight.size() and param.size(0) == 1:
                if torch.allclose(loaded_weight, loaded_weight[0]):
                    loaded_weight = loaded_weight[:1]
                else:
                    raise ValueError(f"{loaded_weight} are not all equal")

        assert param.size() == loaded_weight.size()
        param.data.copy_(loaded_weight)
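The NPU branch above handles Ascend checkpoints that store a per-tensor scale broadcast across output channels: if every element is identical, the loader collapses it to a single-element tensor so it fits the per-tensor parameter slot. A standalone sketch of that collapse logic, plain PyTorch with hypothetical tensors:

```python
import torch

def collapse_per_tensor_scale(param: torch.Tensor, loaded: torch.Tensor) -> torch.Tensor:
    """Collapse a broadcast per-tensor scale to shape (1,), as the NPU branch does."""
    if param.size() != loaded.size() and param.size(0) == 1:
        if torch.allclose(loaded, loaded[0]):
            return loaded[:1]
        raise ValueError("scale values are not all equal, cannot treat as per-tensor")
    return loaded

param = torch.empty(1)                # per-tensor scale slot on the module
loaded = torch.full((4096,), 0.0123)  # checkpoint stores the scale repeated per channel
print(collapse_per_tensor_scale(param, loaded).shape)  # torch.Size([1])
```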
@@ -12,7 +12,6 @@ from sglang.srt.distributed import (
)
from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
from sglang.srt.layers.moe.ep_moe.kernels import (
    ep_gather,
    ep_scatter,
@@ -65,6 +64,8 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
if not _is_npu:
    from sgl_kernel import silu_and_mul

    from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe

if _is_hip:
    from vllm._custom_ops import scaled_fp8_quant


@@ -850,7 +850,7 @@ class FusedMoE(torch.nn.Module):
            return

        # Case weight scales and zero_points
        if "scale" in weight_name or "zero" in weight_name:
        if "scale" in weight_name or "zero" in weight_name or "offset" in weight_name:
            # load the weight scales and zp based on the quantization scheme
            # supported weight scales/zp can be found in
            # FusedMoeWeightScaleSupported
@@ -308,7 +308,7 @@ def biased_grouped_topk_gpu(
    renormalize: bool,
    num_expert_group: int = 0,
    topk_group: int = 0,
    compiled: bool = True,
    compiled: bool = not _is_npu,
    num_fused_shared_experts: int = 0,
    routed_scaling_factor: Optional[float] = None,
    num_token_non_padded: Optional[torch.Tensor] = None,
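`compiled` now defaults to `not _is_npu`, i.e. the `torch.compile` path is only taken on backends that support it and the eager path is used on Ascend. A small sketch of the same default-argument gate, with a crude stand-in for sglang's `is_npu` helper:

```python
import torch

_is_npu = hasattr(torch, "npu")  # crude stand-in for sglang.srt.utils.is_npu()

def topk_fn(scores: torch.Tensor, k: int):
    return torch.topk(scores, k, dim=-1)

def grouped_topk(scores: torch.Tensor, k: int, compiled: bool = not _is_npu):
    # Route through torch.compile only where the backend supports it;
    # on Ascend the eager path is used, matching the new default.
    fn = torch.compile(topk_fn) if compiled else topk_fn
    return fn(scores, k)

values, ids = grouped_topk(torch.randn(2, 8), k=2)
print(values.shape, ids.shape)
```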
@@ -116,8 +116,7 @@ class MoeWNA16Config(QuantizationConfig):

    @classmethod
    def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:
        can_convert = cls.is_moe_wna16_compatible(hf_quant_cfg)
        if can_convert and user_quant == "moe_wna16":
        if user_quant == "moe_wna16" and cls.is_moe_wna16_compatible(hf_quant_cfg):
            return cls.get_name()
        return None
@@ -1,21 +1,37 @@
from typing import Any, Callable, Dict, List, Optional
import importlib
import sys
from types import MappingProxyType
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union

import torch
from torch.nn.parameter import Parameter

from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.distributed import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
from sglang.srt.layers.linear import LinearMethodBase
from sglang.srt.layers.parameter import ChannelQuantScaleParameter, ModelWeightParameter
from sglang.srt.layers.linear import (
    LinearMethodBase,
    RowParallelLinear,
    UnquantizedLinearMethod,
)
from sglang.srt.layers.parameter import (
    ChannelQuantScaleParameter,
    ModelWeightParameter,
    PerTensorScaleParameter,
)
from sglang.srt.layers.quantization.base_config import (
    QuantizationConfig,
    QuantizeMethodBase,
)
from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
from sglang.srt.utils import (
    apply_module_patch,
    cpu_has_amx_support,
    is_cpu,
    is_cuda,
    is_npu,
    set_weight_attrs,
    use_intel_amx_backend,
)
@@ -25,6 +41,134 @@ _is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()
if _is_cuda:
    from sgl_kernel import int8_scaled_mm
_is_npu = is_npu()

if _is_npu:
    import torch_npu

    try:
        from mindie_turbo import _ops as ops
        from mindie_turbo.quantize.quant_utils import quant_per_tensor
    except ImportError:
        useMindIETurbo = False
    else:
        useMindIETurbo = True


# func refers to RMSNorm.__init__
def npu_wrapper_rmsnorm_init(func):
    def init(self, hidden_size: int, **extra_args) -> None:
        func(self, hidden_size, **extra_args)
        self.ignore_anti = True
        # The Ascend w8a8_int8 quantization requires adding a bias in rmsnorm
        self.bias = torch.nn.Parameter(torch.zeros(hidden_size), requires_grad=False)

    return init


# func refers to RMSNorm.forward_oot
def npu_wrapper_rmsnorm_forward(func):
    def _rmsnorm_forward_oot(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        if not x.is_contiguous():
            x = x.contiguous()
        original_dtype = x.dtype
        x = x.to(torch.float32)
        if residual is not None:
            x = x + residual.to(torch.float32)
            residual = x.to(original_dtype)

        x = (
            torch_npu.npu_rms_norm(
                x, self.weight.to(torch.float32), self.variance_epsilon
            )[0]
            + self.bias
        )

        if residual is None:
            return x.to(original_dtype)
        return x.to(original_dtype), residual

    return _rmsnorm_forward_oot
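These two wrappers are applied at runtime via apply_module_patch (see W8A8Int8Config.__init__ below) so that RMSNorm gains an extra bias term when the Ascend checkpoint's quant description contains "norm.bias" entries. A self-contained sketch of the same wrap-and-rebind idea with a toy RMSNorm; the class and wrapper names here are hypothetical stand-ins, not sglang's:

```python
import torch

class TinyRMSNorm(torch.nn.Module):
    """A stand-in for sglang's RMSNorm, used only to demonstrate the wrappers."""

    def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        var = x.float().pow(2).mean(-1, keepdim=True)
        return (x.float() * torch.rsqrt(var + self.variance_epsilon)).to(x.dtype) * self.weight

def add_bias_init(func):
    # same shape as npu_wrapper_rmsnorm_init: call the original __init__, then add a bias
    def init(self, hidden_size: int, **extra_args) -> None:
        func(self, hidden_size, **extra_args)
        self.bias = torch.nn.Parameter(torch.zeros(hidden_size), requires_grad=False)
    return init

def add_bias_forward(func):
    # same shape as npu_wrapper_rmsnorm_forward: run the original forward, then add the bias
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return func(self, x) + self.bias
    return forward

# Patch the class in place, as apply_module_patch does for sglang's RMSNorm.
TinyRMSNorm.__init__ = add_bias_init(TinyRMSNorm.__init__)
TinyRMSNorm.forward = add_bias_forward(TinyRMSNorm.forward)

norm = TinyRMSNorm(8)
print(norm(torch.randn(2, 8)).shape)  # torch.Size([2, 8])
```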
def npu_fused_experts(
    hidden_states: torch.Tensor,
    w13: torch.Tensor,
    w13_scale: torch.Tensor,
    w2: torch.Tensor,
    w2_scale: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    top_k: int,
):
    original_shape = hidden_states.shape
    original_dtype = hidden_states.dtype
    scale_dtype = original_dtype if original_dtype == torch.bfloat16 else torch.float32
    if len(original_shape) == 3:
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
    num_tokens = hidden_states.shape[0]
    num_experts = w13.shape[0]
    row_idx_len = num_tokens * top_k
    row_idx = (
        torch.arange(0, row_idx_len, dtype=torch.int32, device=topk_weights.device)
        .view(top_k, -1)
        .permute(1, 0)
        .contiguous()
    )
    hidden_states, expanded_row_idx, expanded_expert_idx = (
        torch_npu.npu_moe_init_routing(
            hidden_states, row_idx=row_idx, expert_idx=topk_ids, active_num=num_tokens
        )
    )
    expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
        expanded_expert_idx, num_experts
    )
    expert_tokens = expert_tokens.to(torch.int64)
    # gmm1: gate_up_proj
    hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states)
    hidden_states = torch_npu.npu_grouped_matmul(
        x=[hidden_states],
        weight=[w13],
        scale=[w13_scale.to(scale_dtype)],
        per_token_scale=[pertoken_scale],
        split_item=2,
        group_list_type=0,
        group_type=0,
        group_list=expert_tokens,
        output_dtype=original_dtype,
    )[0]
    # act_fn: swiglu
    hidden_states = torch_npu.npu_swiglu(hidden_states)
    hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states)
    # gmm2: down_proj
    hidden_states = torch_npu.npu_grouped_matmul(
        x=[hidden_states],
        weight=[w2],
        scale=[w2_scale.to(scale_dtype)],
        per_token_scale=[pertoken_scale],
        split_item=2,
        group_list_type=0,
        group_type=0,
        group_list=expert_tokens,
        output_dtype=original_dtype,
    )[0]

    final_hidden_states = torch_npu.npu_moe_finalize_routing(
        hidden_states,
        skip1=None,
        skip2=None,
        bias=None,
        scales=topk_weights,
        expanded_src_to_dst_row=expanded_row_idx,
        export_for_source_row=topk_ids,
    )
    if len(original_shape) == 3:
        final_hidden_states = final_hidden_states.view(original_shape)
    return final_hidden_states


class W8A8Int8Config(QuantizationConfig):
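Both grouped matmuls above consume int8 activations produced on the fly by `npu_dynamic_quant`, i.e. per-token symmetric dynamic quantization. A plain-PyTorch reference of what that step computes conceptually; this is an illustration I am adding, not the NPU kernel itself:

```python
import torch

def dynamic_quant_per_token(x: torch.Tensor):
    """Reference for per-token symmetric int8 quantization: int8 values
    plus one float scale per row, analogous to what the dynamic-quant op returns."""
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    q = torch.clamp(torch.round(x / scale), -128, 127).to(torch.int8)
    return q, scale.squeeze(-1)

x = torch.randn(4, 16)
q, scale = dynamic_quant_per_token(x)
x_hat = q.float() * scale.unsqueeze(-1)  # dequantize to check the round trip
print((x - x_hat).abs().max())           # small quantization error
```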
@@ -34,16 +178,47 @@ class W8A8Int8Config(QuantizationConfig):
    - Activation: dynamic, per-token, symmetric
    """

    def __init__(self):
        pass
    def __init__(self, quant_config: Dict[str, Any]):
        super().__init__()
        self.quant_description = quant_config
        self.is_dynamic = quant_config.get("is_dynamic", False)
        if _is_npu:
            if (
                "packed_modules_mapping" in quant_config
                and quant_config["packed_modules_mapping"] is not None
            ):
                self.packed_modules_mapping = quant_config["packed_modules_mapping"]

            # Ascend w8a8_int8 quantization adds a bias in rmsnorm; use wrappers
            # to isolate the effect between models
            for name in self.quant_description.keys():
                if "norm.bias" in name:
                    apply_module_patch(
                        "sglang.srt.layers.layernorm.RMSNorm",
                        "__init__",
                        [npu_wrapper_rmsnorm_init],
                    )
                    apply_module_patch(
                        "sglang.srt.layers.layernorm.RMSNorm",
                        "forward_npu",
                        [npu_wrapper_rmsnorm_forward],
                    )

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.float16, torch.bfloat16]
        return (
            [torch.float16, torch.bfloat16]
            if not _is_npu
            else [torch.int8, torch.float16, torch.bfloat16]
        )

    @classmethod
    def get_min_capability(cls) -> int:
        return 75
        if _is_npu:
            raise NotImplementedError(
                'NPU hardware does not support the "get_min_capability" feature.'
            )
        else:
            return 75

    @classmethod
    def get_name(self) -> str:
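The config object is now built from the checkpoint's quantization description (see `from_config` below), which maps parameter names to their quantization type and may carry the packed-module mapping. A hypothetical, trimmed-down example of such a description; the layer names and values are illustrative only, the keys consumed by the code above are `packed_modules_mapping`, per-parameter `*.weight` tags, and `norm.bias` entries:

```python
# Hypothetical Ascend-style quant description (illustrative, not a real checkpoint).
quant_description = {
    "packed_modules_mapping": {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    },
    "model.layers.0.self_attn.q_proj.weight": "W8A8",
    "model.layers.0.mlp.gate_proj.weight": "W8A8_DYNAMIC",
    "model.layers.0.mlp.up_proj.weight": "W8A8_DYNAMIC",
    "lm_head.weight": "FLOAT",  # left unquantized
    # Presence of a "norm.bias" key is what triggers the RMSNorm patching above.
    "model.layers.0.post_attention_layernorm.bias": "FLOAT",
}
# cfg = W8A8Int8Config(quant_description)  # normally built via from_config(config)
```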
@@ -55,7 +230,7 @@ class W8A8Int8Config(QuantizationConfig):

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "W8A8Int8Config":
        return cls()
        return cls(config)

    def get_quant_method(
        self,
@@ -65,11 +240,65 @@ class W8A8Int8Config(QuantizationConfig):
        from sglang.srt.layers.linear import LinearBase
        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE

        if isinstance(layer, LinearBase):
            return W8A8Int8LinearMethod(self)
        elif isinstance(layer, FusedMoE):
            return W8A8Int8MoEMethod(self)
        return None
        if _is_npu:
            if isinstance(layer, LinearBase):
                prefix_in_quant_config = prefix
                proj_name = prefix.split(".")[-1]
                if proj_name in self.packed_modules_mapping:
                    prefix_in_quant_config = prefix.replace(
                        proj_name, self.packed_modules_mapping[proj_name][0]
                    )
                self.is_dynamic = (
                    self.quant_description[prefix_in_quant_config + ".weight"]
                    == "W8A8_DYNAMIC"
                )
                if self.is_layer_skipped(prefix, self.packed_modules_mapping):
                    return UnquantizedLinearMethod()
                return (
                    NPU_W8A8DynamicLinearMethod(self)
                    if self.is_dynamic
                    else NPU_W8A8LinearMethod(self)
                )
            elif isinstance(layer, FusedMoE):
                return NPU_W8A8MoEMethod(self)
            return None
        else:
            if isinstance(layer, LinearBase):
                return W8A8Int8LinearMethod(self)
            elif isinstance(layer, FusedMoE):
                return W8A8Int8MoEMethod(self)
            return None

    def is_layer_skipped(
        self, prefix: str, fused_mapping: Mapping[str, List[str]] = MappingProxyType({})
    ):
        # adapted from vllm.model_executor.layers.quantization.utils.quant_utils.is_layer_skipped
        proj_name = prefix.split(".")[-1]
        if proj_name in fused_mapping:
            shard_prefixes = [
                prefix.replace(proj_name, shard_proj_name)
                for shard_proj_name in fused_mapping[proj_name]
            ]

            is_skipped = None
            for shard_prefix in shard_prefixes:
                is_shard_skipped = (
                    self.quant_description[shard_prefix + ".weight"] == "FLOAT"
                )

                if is_skipped is None:
                    is_skipped = is_shard_skipped
                elif is_shard_skipped != is_skipped:
                    raise ValueError(
                        f"Detected some but not all shards of {prefix} "
                        "are quantized. All shards of fused layers "
                        "need to have the same precision."
                    )
        else:
            is_skipped = self.quant_description[prefix + ".weight"] == "FLOAT"

        assert is_skipped is not None
        return is_skipped

    def get_scaled_act_names(self) -> List[str]:
        return []
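On NPU, get_quant_method resolves a fused layer prefix (e.g. qkv_proj) back to one of its source projections before consulting the quant description, then picks the dynamic, static, or unquantized path. A standalone sketch of that lookup with toy data; the prefixes and the returned strings are only labels for illustration:

```python
packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
quant_description = {
    "model.layers.0.self_attn.q_proj.weight": "W8A8_DYNAMIC",
    "model.layers.0.self_attn.k_proj.weight": "W8A8_DYNAMIC",
    "model.layers.0.self_attn.v_proj.weight": "W8A8_DYNAMIC",
    "model.layers.0.mlp.down_proj.weight": "FLOAT",
}

def pick_linear_method(prefix: str) -> str:
    proj_name = prefix.split(".")[-1]
    lookup = prefix
    if proj_name in packed_modules_mapping:
        # a fused module is looked up via its first source projection
        lookup = prefix.replace(proj_name, packed_modules_mapping[proj_name][0])
    kind = quant_description[lookup + ".weight"]
    if kind == "FLOAT":
        return "UnquantizedLinearMethod"
    return "NPU_W8A8DynamicLinearMethod" if kind == "W8A8_DYNAMIC" else "NPU_W8A8LinearMethod"

print(pick_linear_method("model.layers.0.self_attn.qkv_proj"))  # NPU_W8A8DynamicLinearMethod
print(pick_linear_method("model.layers.0.mlp.down_proj"))       # UnquantizedLinearMethod
```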
@@ -321,3 +550,498 @@ class W8A8Int8MoEMethod:
            no_combine=no_combine,
            routed_scaling_factor=routed_scaling_factor,
        )


class NPU_W8A8LinearMethodImpl:
    """Linear method for NPU W8A8."""

    def __init__(self) -> None:
        # aclnn quant matmul requires matrix B to be transposed; set to True by default.
        self.transpose_weight = True

    @staticmethod
    def get_weight(
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype = torch.bfloat16,
    ) -> Dict[str, Any]:
        params_dict = {"weight": torch.empty(output_size, input_size, dtype=torch.int8)}
        return params_dict

    @staticmethod
    def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
        params_dict = {}
        params_dict["input_scale"] = torch.empty(1, dtype=params_dtype)
        params_dict["input_offset"] = torch.empty(1, dtype=torch.int8)
        return params_dict

    @staticmethod
    def get_perchannel_param(
        output_size: int,
        params_dtype: torch.dtype,
    ) -> Dict[str, Any]:
        params_dict = {}
        params_dict["quant_bias"] = torch.empty(output_size, dtype=torch.int32)
        if params_dtype == torch.bfloat16:
            params_dict["deq_scale"] = torch.empty(output_size, dtype=torch.float32)
        elif params_dtype == torch.float16:
            params_dict["deq_scale"] = torch.empty(output_size, dtype=torch.int64)
        params_dict["weight_scale"] = torch.empty(output_size, 1, dtype=params_dtype)
        params_dict["weight_offset"] = torch.empty(output_size, 1, dtype=params_dtype)
        return params_dict

    @staticmethod
    def apply(
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
        tp_rank: Optional[int] = 0,
    ) -> torch.Tensor:
        original_dtype = x.dtype
        if original_dtype != torch.int8:
            x = torch_npu.npu_quantize(
                x,
                layer.aclnn_input_scale,
                layer.aclnn_input_offset,
                torch.qint8,
                -1,
                True,
            )

        quant_bias = layer.quant_bias if tp_rank == 0 else None
        return torch_npu.npu_quant_matmul(
            x,
            layer.weight,
            layer.deq_scale,
            bias=quant_bias,
            output_dtype=original_dtype,
        )

    def process_weights_after_loading(self, layer):
        expanding_factor = layer.weight.data.shape[1]
        layer.aclnn_input_scale = torch.nn.Parameter(
            layer.input_scale.data.repeat(expanding_factor).to(device="npu"),
            requires_grad=False,
        )
        layer.aclnn_input_offset = torch.nn.Parameter(
            layer.input_offset.data.repeat(expanding_factor).to(device="npu"),
            requires_grad=False,
        )
        if self.transpose_weight:
            layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
        layer.weight_scale.data = torch.flatten(layer.weight_scale.data)
        layer.weight_offset.data = torch.flatten(layer.weight_offset.data)
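The static path quantizes activations once with a fixed per-tensor input scale and dequantizes the int32 accumulator with a per-channel deq scale (commonly input_scale * weight_scale) plus an int32 quant bias. A plain-PyTorch reference of what this computes numerically, with the int8 GEMM emulated in 64-bit integer accumulation; this is my illustration under those assumptions, not the aclnn kernel:

```python
import torch

def w8a8_static_linear_reference(
    x: torch.Tensor,           # [tokens, in] float activations
    weight_q: torch.Tensor,    # [out, in] int8 weights
    input_scale: float,        # per-tensor activation scale
    deq_scale: torch.Tensor,   # [out] per-channel dequant scale (assumed input_scale * weight_scale)
    quant_bias: torch.Tensor,  # [out] int32 bias folded into the quantized domain
) -> torch.Tensor:
    x_q = torch.clamp(torch.round(x / input_scale), -128, 127).to(torch.int8)
    acc = x_q.to(torch.int64) @ weight_q.to(torch.int64).T + quant_bias  # integer accumulate
    return acc.to(torch.float32) * deq_scale                              # dequantize per channel

x = torch.randn(4, 64)
w = torch.randn(32, 64)
input_scale = x.abs().max().item() / 127.0
weight_scale = w.abs().amax(dim=1) / 127.0
w_q = torch.clamp(torch.round(w / weight_scale[:, None]), -128, 127).to(torch.int8)
y = w8a8_static_linear_reference(
    x, w_q, input_scale, input_scale * weight_scale, torch.zeros(32, dtype=torch.int32)
)
print((y - x @ w.T).abs().max())  # close to the float result, up to quantization error
```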
class NPU_W8A8LinearMethodMTImpl:
    """Linear method for NPU W8A8 using MindIE Turbo."""

    def __init__(self) -> None:
        self.transpose_weight = True

    @staticmethod
    def get_weight(
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype = torch.bfloat16,
    ) -> Dict[str, Any]:
        params_dict = {"weight": torch.empty(output_size, input_size, dtype=torch.int8)}
        return params_dict

    @staticmethod
    def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
        params_dict = {}
        params_dict["input_scale"] = torch.empty(1, dtype=params_dtype)
        params_dict["input_offset"] = torch.empty(1, dtype=torch.int8)
        return params_dict

    @staticmethod
    def get_perchannel_param(
        output_size: int,
        params_dtype: torch.dtype,
    ) -> Dict[str, Any]:
        params_dict = {}
        params_dict["quant_bias"] = torch.empty(output_size, dtype=torch.int32)
        if params_dtype == torch.bfloat16:
            params_dict["deq_scale"] = torch.empty(output_size, dtype=torch.float32)
        elif params_dtype == torch.float16:
            params_dict["deq_scale"] = torch.empty(output_size, dtype=torch.int64)
        params_dict["weight_scale"] = torch.empty(output_size, 1, dtype=params_dtype)
        params_dict["weight_offset"] = torch.empty(output_size, 1, dtype=params_dtype)
        return params_dict

    @staticmethod
    def apply(
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
        tp_rank: Optional[int] = 0,
    ) -> torch.Tensor:
        original_dtype = x.dtype
        if original_dtype != torch.int8:
            x = quant_per_tensor(x, layer.input_scale, layer.input_offset)

        quant_bias = layer.quant_bias if tp_rank == 0 else None
        return ops.quant_matmul(
            x=x, weight=layer.weight, deq_scale=layer.deq_scale, deq_bias=quant_bias
        )

    def process_weights_after_loading(self, layer):
        layer.aclnn_deq_scale = torch.nn.Parameter(
            torch_npu.npu_trans_quant_param(layer.deq_scale.npu()).to(device="npu"),
            requires_grad=False,
        )


class NPU_W8A8LinearMethod(LinearMethodBase):
    """Linear method for NPU quantization.

    This class searches for the specific quantization
    implementation supported on NPU hardware for linear methods.

    Args:
        quant_config: The NPU quantization config.
    """

    def __init__(self, quantization_config: W8A8Int8Config) -> None:
        self.quantization_config = quantization_config
        self.quant_method = (
            NPU_W8A8LinearMethodMTImpl()
            if useMindIETurbo
            else NPU_W8A8LinearMethodImpl()
        )

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ) -> None:
        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")

        weight_dict = self.quant_method.get_weight(
            input_size_per_partition, output_size_per_partition, params_dtype
        )
        for weight_name, weight_param in weight_dict.items():
            param = torch.nn.Parameter(weight_param, requires_grad=False)
            set_weight_attrs(param, {"input_dim": 1, "output_dim": 0})
            layer.register_parameter(weight_name, param)
            set_weight_attrs(param, extra_weight_attrs)

        pertensor_dict = self.quant_method.get_pertensor_param(params_dtype)
        for pertensor_name, pertensor_param in pertensor_dict.items():
            param = PerTensorScaleParameter(
                data=pertensor_param, weight_loader=weight_loader
            )
            # disable warning
            param.ignore_warning = True
            layer.register_parameter(pertensor_name, param)

        perchannel_dict = self.quant_method.get_perchannel_param(
            output_size_per_partition, params_dtype
        )
        for perchannel_name, perchannel_param in perchannel_dict.items():
            param = torch.nn.Parameter(perchannel_param, requires_grad=False)
            set_weight_attrs(param, {"output_dim": 0})
            layer.register_parameter(perchannel_name, param)
            set_weight_attrs(param, extra_weight_attrs)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        if hasattr(self.quant_method, "process_weights_after_loading"):
            self.quant_method.process_weights_after_loading(layer)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if isinstance(layer, RowParallelLinear):
            tp_rank = get_tensor_model_parallel_rank()
            return self.quant_method.apply(layer, x, bias, tp_rank)
        return self.quant_method.apply(layer, x, bias)


class NPU_W8A8DynamicLinearMethodImpl:
    """Linear method for NPU W8A8_DYNAMIC."""

    def __init__(self):
        self.transpose_weight = True

    @staticmethod
    def get_weight(
        input_size: int, output_size: int, params_dtype: torch.dtype
    ) -> Dict[str, Any]:
        params_dict = {"weight": torch.empty(output_size, input_size, dtype=torch.int8)}
        return params_dict

    @staticmethod
    def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
        return {}

    @staticmethod
    def get_perchannel_param(
        output_size: int,
        params_dtype: torch.dtype,
    ) -> Dict[str, Any]:
        params_dict = {}
        params_dict["weight_scale"] = torch.empty(output_size, 1, dtype=params_dtype)
        params_dict["weight_offset"] = torch.empty(output_size, 1, dtype=params_dtype)
        return params_dict

    @staticmethod
    def apply(
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
        tp_rank: Optional[int] = 0,
    ) -> torch.Tensor:
        original_dtype = x.dtype
        # use ATB quantize
        quant_out, dynamic_scale = torch_npu.npu_dynamic_quant(x)
        return torch_npu.npu_quant_matmul(
            quant_out,
            layer.weight,
            layer.weight_scale,
            pertoken_scale=dynamic_scale,
            bias=bias,
            output_dtype=original_dtype,
        )

    def process_weights_after_loading(self, layer):
        if self.transpose_weight:
            layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
        layer.weight_scale.data = layer.weight_scale.data.flatten()
        layer.weight_scale_fp32 = layer.weight_scale.data.to(torch.float32)
        layer.weight_offset.data = layer.weight_offset.data.flatten()
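Note that the dynamic implementation registers no per-tensor input parameters (get_pertensor_param returns an empty dict): the activation scale is computed per token at runtime rather than calibrated offline. A toy comparison I am adding to illustrate why that matters when token magnitudes vary widely:

```python
import torch

x = torch.randn(8, 128) * torch.logspace(-2, 0, 8).unsqueeze(-1)  # rows with very different ranges

# Static W8A8: one activation scale for the whole tensor (calibrated offline).
static_scale = x.abs().max() / 127.0
x_static = torch.clamp(torch.round(x / static_scale), -128, 127) * static_scale

# Dynamic W8A8: one scale per token, computed on the fly (no input_scale parameter needed).
token_scale = x.abs().amax(dim=-1, keepdim=True) / 127.0
x_dynamic = torch.clamp(torch.round(x / token_scale), -128, 127) * token_scale

print((x - x_static).abs().mean())   # larger error on small-magnitude rows
print((x - x_dynamic).abs().mean())  # per-token scaling tracks each row's range
```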
class NPU_W8A8DynamicLinearMethod(LinearMethodBase):
    """Linear method for NPU quantization.

    This class searches for the specific quantization
    implementations supported on NPU hardware for linear methods.

    Args:
        quant_config: The NPU quantization config.
    """

    def __init__(self, quantization_config: W8A8Int8Config) -> None:
        self.quantization_config = quantization_config
        self.quant_method = NPU_W8A8DynamicLinearMethodImpl()

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ) -> None:
        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")

        weight_dict = self.quant_method.get_weight(
            input_size_per_partition, output_size_per_partition, params_dtype
        )
        for weight_name, weight_param in weight_dict.items():
            param = torch.nn.Parameter(weight_param, requires_grad=False)
            set_weight_attrs(param, {"input_dim": 1, "output_dim": 0})
            layer.register_parameter(weight_name, param)
            set_weight_attrs(param, extra_weight_attrs)

        pertensor_dict = self.quant_method.get_pertensor_param(params_dtype)
        for pertensor_name, pertensor_param in pertensor_dict.items():
            param = PerTensorScaleParameter(
                data=pertensor_param, weight_loader=weight_loader
            )
            # disable warning
            param.ignore_warning = True
            layer.register_parameter(pertensor_name, param)

        perchannel_dict = self.quant_method.get_perchannel_param(
            output_size_per_partition, params_dtype
        )
        for perchannel_name, perchannel_param in perchannel_dict.items():
            param = torch.nn.Parameter(perchannel_param, requires_grad=False)
            set_weight_attrs(param, {"output_dim": 0})
            layer.register_parameter(perchannel_name, param)
            set_weight_attrs(param, extra_weight_attrs)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        if hasattr(self.quant_method, "process_weights_after_loading"):
            self.quant_method.process_weights_after_loading(layer)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if isinstance(layer, RowParallelLinear):
            tp_rank = get_tensor_model_parallel_rank()
            return self.quant_method.apply(layer, x, bias, tp_rank)
        return self.quant_method.apply(layer, x, bias)


class NPU_W8A8MoEMethod:
    """MoE method for NPU quantization.

    This class searches for the specific quantization
    implementations supported on NPU hardware for MoE methods.

    Args:
        quant_config: The NPU quantization config.
    """

    def __init__(self, quantization_config: W8A8Int8Config) -> None:
        self.quantization_config = quantization_config
        self.quant_method = self

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size: List[int],
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ) -> None:
        from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported

        self.num_experts = num_experts
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
        )

        # weight
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts, 2 * intermediate_size, hidden_size, dtype=torch.int8
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)
        w2_weight = torch.nn.Parameter(
            torch.empty(num_experts, hidden_size, intermediate_size, dtype=torch.int8),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)
        # scale
        w13_weight_scale = torch.nn.Parameter(
            torch.empty(num_experts, 2 * intermediate_size, 1, dtype=torch.float32),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)
        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
        w2_weight_scale = torch.nn.Parameter(
            torch.empty(num_experts, hidden_size, 1, dtype=torch.float32),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight_scale", w2_weight_scale)
        set_weight_attrs(w2_weight_scale, extra_weight_attrs)
        # offset
        w13_weight_offset = torch.nn.Parameter(
            torch.empty(num_experts, 2 * intermediate_size, 1, dtype=torch.float32),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight_offset", w13_weight_offset)
        set_weight_attrs(w13_weight_offset, extra_weight_attrs)
        w2_weight_offset = torch.nn.Parameter(
            torch.empty(num_experts, hidden_size, 1, dtype=torch.float32),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight_offset", w2_weight_offset)
        set_weight_attrs(w2_weight_offset, extra_weight_attrs)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        layer.w13_weight = Parameter(
            layer.w13_weight.data.transpose(1, 2).contiguous(), requires_grad=False
        )
        layer.w2_weight = Parameter(
            layer.w2_weight.data.transpose(1, 2).contiguous(), requires_grad=False
        )
        layer.w13_weight_scale = Parameter(
            layer.w13_weight_scale.data.squeeze(-1).contiguous(), requires_grad=False
        )
        layer.w2_weight_scale = Parameter(
            layer.w2_weight_scale.data.squeeze(-1).contiguous(), requires_grad=False
        )
        layer.w13_weight_offset = Parameter(
            layer.w13_weight_offset.data.squeeze(-1).contiguous(), requires_grad=False
        )
        layer.w2_weight_offset = Parameter(
            layer.w2_weight_offset.data.squeeze(-1).contiguous(), requires_grad=False
        )

    def apply(
        self,
        layer,
        x,
        router_logits,
        top_k,
        renormalize,
        use_grouped_topk,
        topk_group,
        num_expert_group,
        num_fused_shared_experts,
        custom_routing_function,
        correction_bias,
        activation,
        apply_router_weight_on_input,
        routed_scaling_factor,
        **kwargs,
    ) -> torch.Tensor:
        from sglang.srt.layers.moe.topk import select_experts

        global_num_experts = router_logits.shape[-1]
        # NOTE: for now, npu_moe_gating_top_k only supports the `group_count=256` pattern
        if global_num_experts == 256:
            topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
                router_logits,
                k=top_k,
                bias=correction_bias,
                k_group=topk_group,
                group_count=num_expert_group,
                group_select_mode=1,
                renorm=0,
                norm_type=1,
                routed_scaling_factor=1,
                eps=float(1e-20),
            )
        else:
            topk_weights, topk_ids = select_experts(
                hidden_states=x,
                router_logits=router_logits,
                use_grouped_topk=use_grouped_topk,
                top_k=top_k,
                renormalize=renormalize,
                topk_group=topk_group,
                num_expert_group=num_expert_group,
                num_fused_shared_experts=num_fused_shared_experts,
                custom_routing_function=custom_routing_function,
                correction_bias=correction_bias,
                torch_native=True,
                routed_scaling_factor=routed_scaling_factor,
            )
        topk_ids = topk_ids.to(torch.int32)
        topk_weights = topk_weights.to(x.dtype)
        return npu_fused_experts(
            hidden_states=x,
            w13=layer.w13_weight,
            w13_scale=layer.w13_weight_scale,
            w2=layer.w2_weight,
            w2_scale=layer.w2_weight_scale,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            top_k=top_k,
        )
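On Ascend, routing for the 256-expert layout goes through the fused gating op, while everything else falls back to the generic select_experts path running in native PyTorch. A minimal plain-torch sketch of softmax top-k routing with renormalization; this is my own illustration of the idea, not sglang's select_experts:

```python
import torch

def naive_topk_routing(router_logits: torch.Tensor, top_k: int, renormalize: bool = True):
    """Pick top-k experts per token from router logits and return (weights, ids)."""
    probs = torch.softmax(router_logits, dim=-1)
    topk_weights, topk_ids = torch.topk(probs, top_k, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids.to(torch.int32)

logits = torch.randn(4, 64)      # 4 tokens, 64 experts
weights, ids = naive_topk_routing(logits, top_k=2)
print(weights.shape, ids.shape)  # torch.Size([4, 2]) torch.Size([4, 2])
```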
@@ -34,16 +34,18 @@ import torch
import torch.distributed as dist
import triton
import triton.language as tl
from sgl_kernel.kvcacheio import transfer_kv_per_layer, transfer_kv_per_layer_mla

from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.utils import get_bool_env_var, is_cuda, next_power_of_2
from sglang.srt.utils import get_bool_env_var, is_cuda, is_npu, next_power_of_2

logger = logging.getLogger(__name__)

GB = 1024 * 1024 * 1024
_is_cuda = is_cuda()
_is_npu = is_npu()
if not _is_npu:
    from sgl_kernel.kvcacheio import transfer_kv_per_layer, transfer_kv_per_layer_mla


class ReqToTokenPool:

@@ -64,10 +64,13 @@ from sglang.srt.model_loader.weight_utils import (
from sglang.srt.utils import (
    get_bool_env_var,
    get_device_capability,
    is_npu,
    is_pin_memory_available,
    set_weight_attrs,
)

_is_npu = is_npu()


@contextmanager
def device_loading_context(module: torch.nn.Module, target_device: torch.device):
@@ -127,18 +130,19 @@ def _get_quantization_config(
    # (yizhang2077) workaround for nvidia/Llama-4-Maverick-17B-128E-Eagle3
    if quant_config is None:
        return None
    major, minor = get_device_capability()
    if not _is_npu:
        major, minor = get_device_capability()

    if major is not None and minor is not None:
        assert 0 <= minor < 10
        capability = major * 10 + minor
        if capability < quant_config.get_min_capability():
            raise ValueError(
                f"The quantization method {model_config.quantization} "
                "is not supported for the current GPU. "
                f"Minimum capability: {quant_config.get_min_capability()}. "
                f"Current capability: {capability}."
            )
        if major is not None and minor is not None:
            assert 0 <= minor < 10
            capability = major * 10 + minor
            if capability < quant_config.get_min_capability():
                raise ValueError(
                    f"The quantization method {model_config.quantization} "
                    "is not supported for the current GPU. "
                    f"Minimum capability: {quant_config.get_min_capability()}. "
                    f"Current capability: {capability}."
                )
    supported_dtypes = quant_config.get_supported_act_dtypes()
    if model_config.dtype not in supported_dtypes:
        raise ValueError(
@@ -157,6 +161,13 @@ def _initialize_model(
    """Initialize a model with the given configurations."""
    model_class, _ = get_model_architecture(model_config)
    packed_modules_mapping = getattr(model_class, "packed_modules_mapping", {})
    if _is_npu:
        packed_modules_mapping["fused_qkv_a_proj_with_mqa"] = [
            "q_a_proj",
            "kv_a_proj_with_mqa",
        ]
        packed_modules_mapping["qkv_proj"] = ["q_proj", "k_proj", "v_proj"]
        packed_modules_mapping["gate_up_proj"] = ["gate_proj", "up_proj"]
    quant_config = _get_quantization_config(
        model_config, load_config, packed_modules_mapping
    )
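On NPU the loader injects the common fused-module mappings before building the quant config, so prefixes of fused modules created by sglang (qkv_proj, gate_up_proj, fused_qkv_a_proj_with_mqa) can be traced back to the separate projections listed in the Ascend checkpoint's quant description. A toy illustration of expanding such a prefix, as is_layer_skipped does; the model prefix is hypothetical:

```python
packed_modules_mapping = {
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    "gate_up_proj": ["gate_proj", "up_proj"],
}

def expand_fused_prefix(prefix: str) -> list[str]:
    """Expand a fused module prefix into the checkpoint-side prefixes it covers."""
    proj_name = prefix.split(".")[-1]
    sources = packed_modules_mapping.get(proj_name, [proj_name])
    return [prefix.replace(proj_name, s) for s in sources]

print(expand_fused_prefix("model.layers.3.mlp.gate_up_proj"))
# ['model.layers.3.mlp.gate_proj', 'model.layers.3.mlp.up_proj']
```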
@@ -575,6 +575,8 @@ class LlamaForCausalLM(nn.Module):
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
@@ -407,6 +407,8 @@ class QuantMixtralForCausalLM(nn.Module):
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
@@ -418,6 +420,8 @@ class QuantMixtralForCausalLM(nn.Module):
                # Skip experts that are not assigned to this worker.
                if "block_sparse_moe.experts." in name and name not in params_dict:
                    continue
                if name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
@@ -538,6 +538,8 @@ class Qwen2ForCausalLM(nn.Module):
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
@@ -197,7 +197,7 @@ def get_int_env_var(name: str, default: int = 0) -> int:


def support_triton(backend: str) -> bool:
    return backend not in ["torch_native", "intel_amx"]
    return backend not in ["torch_native", "intel_amx", "ascend"]


try:
@@ -2782,3 +2782,101 @@ def lru_cache_frozenset(maxsize=128):
        return wrapper

    return decorator


def apply_module_patch(target_module, target_function, wrappers):
    original_module, original_function = parse_module_path(
        target_module, target_function, False
    )

    original_function_id = id(original_function)

    candidate = original_function
    for wrapper in wrappers:
        candidate = wrapper(candidate)
    if target_function is not None:
        setattr(original_module, target_function, candidate)

    for key, value in sys.modules.copy().items():
        if (
            target_function is not None
            and hasattr(value, target_function)
            and id(getattr(value, target_function)) == original_function_id
        ):
            setattr(value, target_function, candidate)


def parse_module_path(module_path, function_name, create_dummy):
    from importlib.machinery import ModuleSpec

    def create_dummy_module(full_path, parent=None):
        """Create and register a placeholder module"""
        dummy = types.ModuleType(full_path)
        dummy.__file__ = "vllm_ascend.dummy_module.py"
        dummy.__spec__ = ModuleSpec(full_path, None)
        sys.modules[full_path] = dummy
        if parent:
            setattr(parent, full_path.split(".")[-1], dummy)
        return dummy

    def create_placeholder_function(func_name):
        """Create dummy function that raises when called"""

        def placeholder(*args, **kwargs):
            raise NotImplementedError(f"Function {func_name} is a placeholder")

        placeholder.__name__ = func_name
        return placeholder

    modules = module_path.split(".")
    current_module = None
    processed_path = []

    for idx, part in enumerate(modules):
        current_path = ".".join(modules[: idx + 1])
        parent_path = ".".join(modules[:idx]) if idx > 0 else None

        try:
            current_module = importlib.import_module(current_path)
        except ModuleNotFoundError:
            # Handle missing module
            parent = importlib.import_module(parent_path) if parent_path else None
            if parent and hasattr(parent, part):
                # Use existing attribute from parent
                current_module = getattr(parent, part)
                # Check for early function resolution
                if function_name and hasattr(current_module, function_name):
                    return current_module, getattr(current_module, function_name)
                if function_name and create_dummy:
                    ph_func = create_placeholder_function(function_name)
                    setattr(current_module, function_name, ph_func)
                    return current_module, ph_func
                if function_name:
                    raise AttributeError(
                        f"Function {function_name} missing in {current_path}"
                    )
            else:
                if not create_dummy:
                    raise
                # Create and register dummy module
                current_module = create_dummy_module(
                    current_path,
                    parent=(
                        importlib.import_module(parent_path) if parent_path else None
                    ),
                )

        processed_path.append(part)

    # Final function handling
    final_module = sys.modules[module_path]
    if function_name is not None:
        if not hasattr(final_module, function_name):
            if create_dummy:
                ph_func = create_placeholder_function(function_name)
                setattr(final_module, function_name, ph_func)
            else:
                setattr(final_module, function_name, None)
        return final_module, getattr(final_module, function_name)

    return final_module, None
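apply_module_patch resolves a dotted path, wraps the target callable with each wrapper in order, and then rebinds every module in sys.modules that re-exported the original object. A self-contained sketch of the same wrap-and-rebind idea on a toy module registered in sys.modules; the module and function names are made up for illustration:

```python
import sys
import types

# A toy module standing in for the patch target.
toy = types.ModuleType("toy_layernorm")
toy.rms_norm = lambda x: x          # "original" function
sys.modules["toy_layernorm"] = toy

def add_logging(func):
    def wrapped(x):
        print("patched call")
        return func(x)
    return wrapped

# What apply_module_patch("toy_layernorm", "rms_norm", [add_logging]) boils down to:
original = getattr(sys.modules["toy_layernorm"], "rms_norm")
candidate = original
for wrapper in [add_logging]:
    candidate = wrapper(candidate)
setattr(sys.modules["toy_layernorm"], "rms_norm", candidate)

import toy_layernorm                    # resolves from sys.modules
print(toy_layernorm.rms_norm(41) + 1)   # prints "patched call", then 42
```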