[feature]Ascend quantization support (#7791)
Co-authored-by: ichernob <ichernobnn@gmail.com> Co-authored-by: liupeng <liupeng374@huawei.com>
This commit is contained in:
@@ -413,7 +413,9 @@ class ModelConfig:
|
|||||||
quant_cfg = self._parse_quant_hf_config()
|
quant_cfg = self._parse_quant_hf_config()
|
||||||
|
|
||||||
if quant_cfg is not None:
|
if quant_cfg is not None:
|
||||||
quant_method = quant_cfg.get("quant_method", "").lower()
|
quant_method = quant_cfg.get(
|
||||||
|
"quant_method", "" if not self.quantization else self.quantization
|
||||||
|
).lower()
|
||||||
|
|
||||||
# Detect which checkpoint is it
|
# Detect which checkpoint is it
|
||||||
for _, method in QUANTIZATION_METHODS.items():
|
for _, method in QUANTIZATION_METHODS.items():
|
||||||
|
|||||||
@@ -34,6 +34,7 @@ from sglang.srt.layers.quantization.base_config import (
|
|||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
cpu_has_amx_support,
|
cpu_has_amx_support,
|
||||||
is_cpu,
|
is_cpu,
|
||||||
|
is_npu,
|
||||||
set_weight_attrs,
|
set_weight_attrs,
|
||||||
use_intel_amx_backend,
|
use_intel_amx_backend,
|
||||||
)
|
)
|
||||||
@@ -60,6 +61,7 @@ WEIGHT_LOADER_V2_SUPPORTED = [
|
|||||||
|
|
||||||
_is_cpu_amx_available = cpu_has_amx_support()
|
_is_cpu_amx_available = cpu_has_amx_support()
|
||||||
_is_cpu = is_cpu()
|
_is_cpu = is_cpu()
|
||||||
|
_is_npu = is_npu()
|
||||||
|
|
||||||
|
|
||||||
def adjust_marlin_shard(param, shard_size, shard_offset):
|
def adjust_marlin_shard(param, shard_size, shard_offset):
|
||||||
@@ -297,6 +299,14 @@ class ReplicatedLinear(LinearBase):
|
|||||||
if len(loaded_weight.shape) == 0:
|
if len(loaded_weight.shape) == 0:
|
||||||
loaded_weight = loaded_weight.reshape(1)
|
loaded_weight = loaded_weight.reshape(1)
|
||||||
|
|
||||||
|
# The per-tensor quant-scale must be 1 dimension
|
||||||
|
if _is_npu:
|
||||||
|
if param.size() != loaded_weight.size() and param.size(0) == 1:
|
||||||
|
if torch.allclose(loaded_weight, loaded_weight[0]):
|
||||||
|
loaded_weight = loaded_weight[:1]
|
||||||
|
else:
|
||||||
|
raise ValueError(f"{loaded_weight} are not all equal")
|
||||||
|
|
||||||
assert param.size() == loaded_weight.size()
|
assert param.size() == loaded_weight.size()
|
||||||
param.data.copy_(loaded_weight)
|
param.data.copy_(loaded_weight)
|
||||||
|
|
||||||
|
|||||||
@@ -12,7 +12,6 @@ from sglang.srt.distributed import (
|
|||||||
)
|
)
|
||||||
from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
|
from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
|
||||||
from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
|
from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
|
||||||
from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
|
|
||||||
from sglang.srt.layers.moe.ep_moe.kernels import (
|
from sglang.srt.layers.moe.ep_moe.kernels import (
|
||||||
ep_gather,
|
ep_gather,
|
||||||
ep_scatter,
|
ep_scatter,
|
||||||
@@ -65,6 +64,8 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
|
|||||||
if not _is_npu:
|
if not _is_npu:
|
||||||
from sgl_kernel import silu_and_mul
|
from sgl_kernel import silu_and_mul
|
||||||
|
|
||||||
|
from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
|
||||||
|
|
||||||
if _is_hip:
|
if _is_hip:
|
||||||
from vllm._custom_ops import scaled_fp8_quant
|
from vllm._custom_ops import scaled_fp8_quant
|
||||||
|
|
||||||
|
|||||||
@@ -850,7 +850,7 @@ class FusedMoE(torch.nn.Module):
|
|||||||
return
|
return
|
||||||
|
|
||||||
# Case weight scales and zero_points
|
# Case weight scales and zero_points
|
||||||
if "scale" in weight_name or "zero" in weight_name:
|
if "scale" in weight_name or "zero" in weight_name or "offset" in weight_name:
|
||||||
# load the weight scales and zp based on the quantization scheme
|
# load the weight scales and zp based on the quantization scheme
|
||||||
# supported weight scales/zp can be found in
|
# supported weight scales/zp can be found in
|
||||||
# FusedMoeWeightScaleSupported
|
# FusedMoeWeightScaleSupported
|
||||||
|
|||||||
@@ -308,7 +308,7 @@ def biased_grouped_topk_gpu(
|
|||||||
renormalize: bool,
|
renormalize: bool,
|
||||||
num_expert_group: int = 0,
|
num_expert_group: int = 0,
|
||||||
topk_group: int = 0,
|
topk_group: int = 0,
|
||||||
compiled: bool = True,
|
compiled: bool = not _is_npu,
|
||||||
num_fused_shared_experts: int = 0,
|
num_fused_shared_experts: int = 0,
|
||||||
routed_scaling_factor: Optional[float] = None,
|
routed_scaling_factor: Optional[float] = None,
|
||||||
num_token_non_padded: Optional[torch.Tensor] = None,
|
num_token_non_padded: Optional[torch.Tensor] = None,
|
||||||
|
|||||||
@@ -116,8 +116,7 @@ class MoeWNA16Config(QuantizationConfig):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:
|
def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:
|
||||||
can_convert = cls.is_moe_wna16_compatible(hf_quant_cfg)
|
if user_quant == "moe_wna16" and cls.is_moe_wna16_compatible(hf_quant_cfg):
|
||||||
if can_convert and user_quant == "moe_wna16":
|
|
||||||
return cls.get_name()
|
return cls.get_name()
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|||||||
@@ -1,21 +1,37 @@
|
|||||||
from typing import Any, Callable, Dict, List, Optional
|
import importlib
|
||||||
|
import sys
|
||||||
|
from types import MappingProxyType
|
||||||
|
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.nn.parameter import Parameter
|
from torch.nn.parameter import Parameter
|
||||||
|
|
||||||
from sglang.srt.distributed import get_tensor_model_parallel_world_size
|
from sglang.srt.distributed import (
|
||||||
|
get_tensor_model_parallel_rank,
|
||||||
|
get_tensor_model_parallel_world_size,
|
||||||
|
)
|
||||||
from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
|
from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
|
||||||
from sglang.srt.layers.linear import LinearMethodBase
|
from sglang.srt.layers.linear import (
|
||||||
from sglang.srt.layers.parameter import ChannelQuantScaleParameter, ModelWeightParameter
|
LinearMethodBase,
|
||||||
|
RowParallelLinear,
|
||||||
|
UnquantizedLinearMethod,
|
||||||
|
)
|
||||||
|
from sglang.srt.layers.parameter import (
|
||||||
|
ChannelQuantScaleParameter,
|
||||||
|
ModelWeightParameter,
|
||||||
|
PerTensorScaleParameter,
|
||||||
|
)
|
||||||
from sglang.srt.layers.quantization.base_config import (
|
from sglang.srt.layers.quantization.base_config import (
|
||||||
QuantizationConfig,
|
QuantizationConfig,
|
||||||
QuantizeMethodBase,
|
QuantizeMethodBase,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
|
from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
|
apply_module_patch,
|
||||||
cpu_has_amx_support,
|
cpu_has_amx_support,
|
||||||
is_cpu,
|
is_cpu,
|
||||||
is_cuda,
|
is_cuda,
|
||||||
|
is_npu,
|
||||||
set_weight_attrs,
|
set_weight_attrs,
|
||||||
use_intel_amx_backend,
|
use_intel_amx_backend,
|
||||||
)
|
)
|
||||||
@@ -25,6 +41,134 @@ _is_cpu_amx_available = cpu_has_amx_support()
|
|||||||
_is_cpu = is_cpu()
|
_is_cpu = is_cpu()
|
||||||
if _is_cuda:
|
if _is_cuda:
|
||||||
from sgl_kernel import int8_scaled_mm
|
from sgl_kernel import int8_scaled_mm
|
||||||
|
_is_npu = is_npu()
|
||||||
|
|
||||||
|
if _is_npu:
|
||||||
|
import torch_npu
|
||||||
|
|
||||||
|
try:
|
||||||
|
from mindie_turbo import _ops as ops
|
||||||
|
from mindie_turbo.quantize.quant_utils import quant_per_tensor
|
||||||
|
except ImportError:
|
||||||
|
useMindIETurbo = False
|
||||||
|
else:
|
||||||
|
useMindIETurbo = True
|
||||||
|
|
||||||
|
|
||||||
|
# func refers to RMSNorm.__init__
|
||||||
|
def npu_wrapper_rmsnorm_init(func):
    """Wrap an ``RMSNorm.__init__``-style callable for Ascend W8A8 checkpoints.

    The returned initializer delegates to ``func`` and then attaches two extra
    attributes to the instance: an ``ignore_anti`` flag set to ``True`` and a
    frozen, zero-filled ``bias`` parameter with ``hidden_size`` elements.

    Args:
        func: The initializer being wrapped; it is invoked as
            ``func(self, hidden_size, **extra_args)``.

    Returns:
        The wrapping initializer.
    """

    def init(self, hidden_size: int, **extra_args) -> None:
        # Delegate first so the attributes below are layered on top of
        # whatever the wrapped initializer sets up.
        func(self, hidden_size, **extra_args)
        self.ignore_anti = True

        # The Ascend w8a8_int8 quantization requires adding a bias in rmsnorm
        zero_bias = torch.zeros(hidden_size)
        self.bias = torch.nn.Parameter(zero_bias, requires_grad=False)

    return init
|
||||||
|
|
||||||
|
|
||||||
|
# func refers to RMSNorm.forward_oot
|
||||||
|
def npu_wrapper_rmsnorm_forward(func):
    """Build a replacement RMSNorm forward that also adds the layer's ``bias``.

    ``func`` is accepted but is not called by the returned function; the
    normalization is re-implemented on top of ``torch_npu.npu_rms_norm``.

    The returned forward:
      * makes ``x`` contiguous and promotes it to float32,
      * when ``residual`` is given, adds it (in float32) before normalizing and
        returns the sum, cast back to the input dtype, as the new residual,
      * normalizes with ``self.weight`` / ``self.variance_epsilon`` and adds
        ``self.bias``,
      * casts the result back to the dtype ``x`` had on entry.

    Returns:
        A function ``(self, x, residual=None)`` that returns the normalized
        tensor, or a ``(normalized, residual)`` pair when ``residual`` was
        supplied.
    """

    def _rmsnorm_forward_oot(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        in_dtype = x.dtype
        has_residual = residual is not None

        if not x.is_contiguous():
            x = x.contiguous()
        x = x.to(torch.float32)

        if has_residual:
            # The pre-normalization sum is what gets handed back as residual.
            x = x + residual.to(torch.float32)
            residual = x.to(in_dtype)

        normed = torch_npu.npu_rms_norm(
            x, self.weight.to(torch.float32), self.variance_epsilon
        )[0]
        x = normed + self.bias

        out = x.to(in_dtype)
        if has_residual:
            return out, residual
        return out

    return _rmsnorm_forward_oot
|
||||||
|
|
||||||
|
|
||||||
|
def npu_fused_experts(
    hidden_states: torch.Tensor,
    w13: torch.Tensor,
    w13_scale: torch.Tensor,
    w2: torch.Tensor,
    w2_scale: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    top_k: int,
):
    """Run a quantized mixture-of-experts MLP through torch_npu kernels.

    Pipeline, in order: route tokens to experts, dynamically quantize, grouped
    matmul against ``w13`` (gate/up projection), SwiGLU, dynamically quantize
    again, grouped matmul against ``w2`` (down projection), then combine the
    expert outputs with ``topk_weights``.

    Args:
        hidden_states: Input activations. A 3-D input is flattened to 2-D for
            the computation and the result is viewed back to the input shape.
        w13: Gate/up projection weights; ``w13.shape[0]`` is taken as the
            number of experts.
        w13_scale: Scale passed to the first grouped matmul.
        w2: Down projection weights.
        w2_scale: Scale passed to the second grouped matmul.
        topk_weights: Routing weights used when finalizing; its device is also
            used for the row-index tensor.
        topk_ids: Selected expert ids per token.
        top_k: Number of experts selected per token.

    Returns:
        The combined expert output, with the same shape as a 3-D input when one
        was supplied.
    """
    input_shape = hidden_states.shape
    out_dtype = hidden_states.dtype
    # Scales are fed as bfloat16 only when the activations are bfloat16;
    # every other activation dtype uses float32 scales.
    scale_dtype = torch.bfloat16 if out_dtype == torch.bfloat16 else torch.float32
    was_3d = len(input_shape) == 3

    if was_3d:
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])

    num_tokens = hidden_states.shape[0]
    num_experts = w13.shape[0]

    row_idx = torch.arange(
        0, num_tokens * top_k, dtype=torch.int32, device=topk_weights.device
    )
    row_idx = row_idx.view(top_k, -1).permute(1, 0).contiguous()

    hidden_states, expanded_row_idx, expanded_expert_idx = (
        torch_npu.npu_moe_init_routing(
            hidden_states, row_idx=row_idx, expert_idx=topk_ids, active_num=num_tokens
        )
    )
    expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
        expanded_expert_idx, num_experts
    ).to(torch.int64)

    # Arguments shared by both grouped matmuls.
    gmm_common = dict(
        split_item=2,
        group_list_type=0,
        group_type=0,
        group_list=expert_tokens,
        output_dtype=out_dtype,
    )

    # gmm1: gate_up_proj
    # `hidden_states` is rebound at every step so that intermediates are
    # released as soon as the next stage has consumed them.
    hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states)
    hidden_states = torch_npu.npu_grouped_matmul(
        x=[hidden_states],
        weight=[w13],
        scale=[w13_scale.to(scale_dtype)],
        per_token_scale=[pertoken_scale],
        **gmm_common,
    )[0]

    # act_fn: swiglu
    hidden_states = torch_npu.npu_swiglu(hidden_states)
    hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states)

    # gmm2: down_proj
    hidden_states = torch_npu.npu_grouped_matmul(
        x=[hidden_states],
        weight=[w2],
        scale=[w2_scale.to(scale_dtype)],
        per_token_scale=[pertoken_scale],
        **gmm_common,
    )[0]

    final_hidden_states = torch_npu.npu_moe_finalize_routing(
        hidden_states,
        skip1=None,
        skip2=None,
        bias=None,
        scales=topk_weights,
        expanded_src_to_dst_row=expanded_row_idx,
        export_for_source_row=topk_ids,
    )
    if was_3d:
        final_hidden_states = final_hidden_states.view(input_shape)
    return final_hidden_states
|
||||||
|
|
||||||
|
|
||||||
class W8A8Int8Config(QuantizationConfig):
|
class W8A8Int8Config(QuantizationConfig):
|
||||||
@@ -34,16 +178,47 @@ class W8A8Int8Config(QuantizationConfig):
|
|||||||
- Activation: dynamic, per-token, symmetric
|
- Activation: dynamic, per-token, symmetric
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self, quant_config: Dict[str, Any]):
|
||||||
pass
|
super().__init__()
|
||||||
|
self.quant_description = quant_config
|
||||||
|
self.is_dynamic = quant_config.get("is_dynamic", False)
|
||||||
|
if _is_npu:
|
||||||
|
if (
|
||||||
|
"packed_modules_mapping" in quant_config
|
||||||
|
and quant_config["packed_modules_mapping"] is not None
|
||||||
|
):
|
||||||
|
self.packed_modules_mapping = quant_config["packed_modules_mapping"]
|
||||||
|
|
||||||
|
# Ascend w8a8_int8 quantization with bias, use wrappers to isolate the effects between models
|
||||||
|
for name in self.quant_description.keys():
|
||||||
|
if "norm.bias" in name:
|
||||||
|
apply_module_patch(
|
||||||
|
"sglang.srt.layers.layernorm.RMSNorm",
|
||||||
|
"__init__",
|
||||||
|
[npu_wrapper_rmsnorm_init],
|
||||||
|
)
|
||||||
|
apply_module_patch(
|
||||||
|
"sglang.srt.layers.layernorm.RMSNorm",
|
||||||
|
"forward_npu",
|
||||||
|
[npu_wrapper_rmsnorm_forward],
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
|
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
|
||||||
return [torch.float16, torch.bfloat16]
|
return (
|
||||||
|
[torch.float16, torch.bfloat16]
|
||||||
|
if not _is_npu
|
||||||
|
else [torch.int8, torch.float16, torch.bfloat16]
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_min_capability(cls) -> int:
|
def get_min_capability(cls) -> int:
|
||||||
return 75
|
if _is_npu:
|
||||||
|
raise NotImplementedError(
|
||||||
|
'NPU hardware does not support "get_min_capability" feature.'
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return 75
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_name(self) -> str:
|
def get_name(self) -> str:
|
||||||
@@ -55,7 +230,7 @@ class W8A8Int8Config(QuantizationConfig):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_config(cls, config: Dict[str, Any]) -> "W8A8Int8Config":
|
def from_config(cls, config: Dict[str, Any]) -> "W8A8Int8Config":
|
||||||
return cls()
|
return cls(config)
|
||||||
|
|
||||||
def get_quant_method(
|
def get_quant_method(
|
||||||
self,
|
self,
|
||||||
@@ -65,11 +240,65 @@ class W8A8Int8Config(QuantizationConfig):
|
|||||||
from sglang.srt.layers.linear import LinearBase
|
from sglang.srt.layers.linear import LinearBase
|
||||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||||
|
|
||||||
if isinstance(layer, LinearBase):
|
if _is_npu:
|
||||||
return W8A8Int8LinearMethod(self)
|
if isinstance(layer, LinearBase):
|
||||||
elif isinstance(layer, FusedMoE):
|
prefix_in_quant_config = prefix
|
||||||
return W8A8Int8MoEMethod(self)
|
proj_name = prefix.split(".")[-1]
|
||||||
return None
|
if proj_name in self.packed_modules_mapping:
|
||||||
|
prefix_in_quant_config = prefix.replace(
|
||||||
|
proj_name, self.packed_modules_mapping[proj_name][0]
|
||||||
|
)
|
||||||
|
self.is_dynamic = (
|
||||||
|
self.quant_description[prefix_in_quant_config + ".weight"]
|
||||||
|
== "W8A8_DYNAMIC"
|
||||||
|
)
|
||||||
|
if self.is_layer_skipped(prefix, self.packed_modules_mapping):
|
||||||
|
return UnquantizedLinearMethod()
|
||||||
|
return (
|
||||||
|
NPU_W8A8DynamicLinearMethod(self)
|
||||||
|
if self.is_dynamic
|
||||||
|
else NPU_W8A8LinearMethod(self)
|
||||||
|
)
|
||||||
|
elif isinstance(layer, FusedMoE):
|
||||||
|
return NPU_W8A8MoEMethod(self)
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
if isinstance(layer, LinearBase):
|
||||||
|
return W8A8Int8LinearMethod(self)
|
||||||
|
elif isinstance(layer, FusedMoE):
|
||||||
|
return W8A8Int8MoEMethod(self)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def is_layer_skipped(
    self, prefix: str, fused_mapping: Mapping[str, List[str]] = MappingProxyType({})
):
    """Report whether the layer at ``prefix`` is marked ``"FLOAT"`` (unquantized).

    The decision is read from ``self.quant_description`` under the key
    ``<layer prefix>.weight``. When the last dotted component of ``prefix`` is a
    key of ``fused_mapping``, every listed shard name is substituted into
    ``prefix`` and all shards must agree.

    Args:
        prefix: Dotted module path of the layer.
        fused_mapping: Maps a fused projection name to the names of the shards
            it is built from.

    Returns:
        ``True`` when the layer (or all of its shards) is described as
        ``"FLOAT"``, ``False`` otherwise.

    Raises:
        ValueError: If only some of the shards of a fused layer are ``"FLOAT"``.
    """
    # adapted from vllm.model_executor.layers.quantization.utils.quant_utils.is_layer_skipped
    leaf_name = prefix.split(".")[-1]

    if leaf_name not in fused_mapping:
        skipped = self.quant_description[prefix + ".weight"] == "FLOAT"
    else:
        # Resolve every shard key up front, then compare their precision.
        shard_keys = [
            prefix.replace(leaf_name, shard_name) + ".weight"
            for shard_name in fused_mapping[leaf_name]
        ]

        skipped = None
        for shard_key in shard_keys:
            shard_skipped = self.quant_description[shard_key] == "FLOAT"
            if skipped is None:
                skipped = shard_skipped
                continue
            if shard_skipped != skipped:
                raise ValueError(
                    f"Detected some but not all shards of {prefix} "
                    "are quantized. All shards of fused layers "
                    "to have the same precision."
                )

    assert skipped is not None
    return skipped
|
||||||
|
|
||||||
def get_scaled_act_names(self) -> List[str]:
|
def get_scaled_act_names(self) -> List[str]:
|
||||||
return []
|
return []
|
||||||
@@ -321,3 +550,498 @@ class W8A8Int8MoEMethod:
|
|||||||
no_combine=no_combine,
|
no_combine=no_combine,
|
||||||
routed_scaling_factor=routed_scaling_factor,
|
routed_scaling_factor=routed_scaling_factor,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class NPU_W8A8LinearMethodImpl:
    """W8A8 linear implementation backed directly by ``torch_npu`` kernels.

    Activations that are not already int8 are quantized with the layer's
    expanded input scale/offset, then multiplied with the int8 weight through
    ``torch_npu.npu_quant_matmul``.
    """

    def __init__(self) -> None:
        # aclnn quant matmul requires to transpose matrix B, set to true by default.
        self.transpose_weight = True

    @staticmethod
    def get_weight(
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype = torch.bfloat16,
    ) -> Dict[str, Any]:
        """Return the uninitialized int8 weight, shaped (output_size, input_size).

        ``params_dtype`` is accepted but does not influence the weight dtype.
        """
        return {"weight": torch.empty(output_size, input_size, dtype=torch.int8)}

    @staticmethod
    def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
        """Return the one-element activation scale and int8 offset tensors."""
        return {
            "input_scale": torch.empty(1, dtype=params_dtype),
            "input_offset": torch.empty(1, dtype=torch.int8),
        }

    @staticmethod
    def get_perchannel_param(
        output_size: int,
        params_dtype: torch.dtype,
    ) -> Dict[str, Any]:
        """Return the per-output-channel tensors.

        ``deq_scale`` is float32 for bfloat16 models and int64 for float16
        models; for any other ``params_dtype`` no ``deq_scale`` entry is
        produced.
        """
        per_channel: Dict[str, Any] = {
            "quant_bias": torch.empty(output_size, dtype=torch.int32)
        }

        deq_scale_dtype = {
            torch.bfloat16: torch.float32,
            torch.float16: torch.int64,
        }.get(params_dtype)
        if deq_scale_dtype is not None:
            per_channel["deq_scale"] = torch.empty(output_size, dtype=deq_scale_dtype)

        per_channel["weight_scale"] = torch.empty(output_size, 1, dtype=params_dtype)
        per_channel["weight_offset"] = torch.empty(output_size, 1, dtype=params_dtype)
        return per_channel

    @staticmethod
    def apply(
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
        tp_rank: Optional[int] = 0,
    ) -> torch.Tensor:
        """Quantize ``x`` if needed and run the quantized matmul.

        ``bias`` is accepted but not used here; the layer's ``quant_bias`` is
        what is passed to the kernel, and only when ``tp_rank`` is 0
        (presumably so it is applied once across tensor-parallel ranks; confirm
        against the caller).
        """
        out_dtype = x.dtype
        if out_dtype != torch.int8:
            x = torch_npu.npu_quantize(
                x,
                layer.aclnn_input_scale,
                layer.aclnn_input_offset,
                torch.qint8,
                -1,
                True,
            )

        quant_bias = None
        if tp_rank == 0:
            quant_bias = layer.quant_bias

        return torch_npu.npu_quant_matmul(
            x,
            layer.weight,
            layer.deq_scale,
            bias=quant_bias,
            output_dtype=out_dtype,
        )

    def process_weights_after_loading(self, layer):
        """Prepare loaded tensors for the kernels used by :meth:`apply`.

        The one-element input scale/offset are repeated to the length of the
        weight's second dimension and stored on the NPU as ``aclnn_input_scale``
        and ``aclnn_input_offset``; the weight is transposed when
        ``transpose_weight`` is set, and the weight scale/offset are flattened.
        """
        repeat_count = layer.weight.data.shape[1]

        expanded_scale = layer.input_scale.data.repeat(repeat_count).to(device="npu")
        layer.aclnn_input_scale = torch.nn.Parameter(
            expanded_scale, requires_grad=False
        )

        expanded_offset = layer.input_offset.data.repeat(repeat_count).to(device="npu")
        layer.aclnn_input_offset = torch.nn.Parameter(
            expanded_offset, requires_grad=False
        )

        if self.transpose_weight:
            layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()

        layer.weight_scale.data = layer.weight_scale.data.flatten()
        layer.weight_offset.data = layer.weight_offset.data.flatten()
|
||||||
|
|
||||||
|
|
||||||
|
class NPU_W8A8LinearMethodMTImpl:
    """W8A8 linear implementation that routes through the MindIE Turbo helpers.

    It creates the same parameters as the plain torch_npu implementation but
    quantizes with ``quant_per_tensor`` and multiplies with
    ``ops.quant_matmul``.
    """

    def __init__(self) -> None:
        self.transpose_weight = True

    @staticmethod
    def get_weight(
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype = torch.bfloat16,
    ) -> Dict[str, Any]:
        """Return the uninitialized int8 weight, shaped (output_size, input_size).

        ``params_dtype`` is accepted but does not influence the weight dtype.
        """
        return {"weight": torch.empty(output_size, input_size, dtype=torch.int8)}

    @staticmethod
    def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
        """Return the one-element activation scale and int8 offset tensors."""
        return {
            "input_scale": torch.empty(1, dtype=params_dtype),
            "input_offset": torch.empty(1, dtype=torch.int8),
        }

    @staticmethod
    def get_perchannel_param(
        output_size: int,
        params_dtype: torch.dtype,
    ) -> Dict[str, Any]:
        """Return the per-output-channel tensors.

        ``deq_scale`` is float32 for bfloat16 models and int64 for float16
        models; for any other ``params_dtype`` no ``deq_scale`` entry is
        produced.
        """
        per_channel: Dict[str, Any] = {
            "quant_bias": torch.empty(output_size, dtype=torch.int32)
        }

        deq_scale_dtype = {
            torch.bfloat16: torch.float32,
            torch.float16: torch.int64,
        }.get(params_dtype)
        if deq_scale_dtype is not None:
            per_channel["deq_scale"] = torch.empty(output_size, dtype=deq_scale_dtype)

        per_channel["weight_scale"] = torch.empty(output_size, 1, dtype=params_dtype)
        per_channel["weight_offset"] = torch.empty(output_size, 1, dtype=params_dtype)
        return per_channel

    @staticmethod
    def apply(
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
        tp_rank: Optional[int] = 0,
    ) -> torch.Tensor:
        """Quantize ``x`` if needed and run the MindIE Turbo quantized matmul.

        ``bias`` is accepted but not used here; the layer's ``quant_bias`` is
        forwarded as ``deq_bias`` only when ``tp_rank`` is 0.
        """
        if x.dtype != torch.int8:
            x = quant_per_tensor(x, layer.input_scale, layer.input_offset)

        quant_bias = None
        if tp_rank == 0:
            quant_bias = layer.quant_bias

        return ops.quant_matmul(
            x=x, weight=layer.weight, deq_scale=layer.deq_scale, deq_bias=quant_bias
        )

    def process_weights_after_loading(self, layer):
        """Store a converted copy of ``deq_scale`` on the layer as ``aclnn_deq_scale``.

        NOTE(review): ``apply`` reads ``layer.deq_scale`` rather than the
        ``aclnn_deq_scale`` created here, and ``transpose_weight`` is not
        consulted; confirm this is what the MindIE Turbo path expects.
        """
        converted = torch_npu.npu_trans_quant_param(layer.deq_scale.npu())
        layer.aclnn_deq_scale = torch.nn.Parameter(
            converted.to(device="npu"),
            requires_grad=False,
        )
|
||||||
|
|
||||||
|
|
||||||
|
class NPU_W8A8LinearMethod(LinearMethodBase):
    """Linear method for static W8A8 quantization on NPU.

    The actual tensor layouts and kernels are supplied by a backend object
    stored in ``self.quant_method``: the MindIE Turbo implementation when
    ``useMindIETurbo`` is true, otherwise the plain torch_npu implementation.

    Args:
        quantization_config: The NPU quantization config.
    """

    def __init__(self, quantization_config: W8A8Int8Config) -> None:
        self.quantization_config = quantization_config
        backend_cls = (
            NPU_W8A8LinearMethodMTImpl if useMindIETurbo else NPU_W8A8LinearMethodImpl
        )
        self.quant_method = backend_cls()

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ) -> None:
        """Register the weight, per-tensor and per-channel parameters on ``layer``.

        The tensors come from the backend's ``get_weight``,
        ``get_pertensor_param`` and ``get_perchannel_param`` and are registered
        in that order under the names the backend returns.
        """
        backend = self.quant_method
        out_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")

        # Quantized weight: sharded along both input and output dims.
        weights = backend.get_weight(
            input_size_per_partition, out_per_partition, params_dtype
        )
        for name, tensor in weights.items():
            weight = torch.nn.Parameter(tensor, requires_grad=False)
            set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
            layer.register_parameter(name, weight)
            set_weight_attrs(weight, extra_weight_attrs)

        # Per-tensor activation scale / offset.
        per_tensor = backend.get_pertensor_param(params_dtype)
        for name, tensor in per_tensor.items():
            scale_param = PerTensorScaleParameter(
                data=tensor, weight_loader=weight_loader
            )
            # disable warning
            scale_param.ignore_warning = True
            layer.register_parameter(name, scale_param)

        # Per-output-channel tensors: sharded along the output dim only.
        per_channel = backend.get_perchannel_param(out_per_partition, params_dtype)
        for name, tensor in per_channel.items():
            channel_param = torch.nn.Parameter(tensor, requires_grad=False)
            set_weight_attrs(channel_param, {"output_dim": 0})
            layer.register_parameter(name, channel_param)
            set_weight_attrs(channel_param, extra_weight_attrs)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Forward to the backend's post-load hook when it defines one."""
        if not hasattr(self.quant_method, "process_weights_after_loading"):
            return
        self.quant_method.process_weights_after_loading(layer)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the backend matmul; row-parallel layers also pass their TP rank."""
        if not isinstance(layer, RowParallelLinear):
            return self.quant_method.apply(layer, x, bias)

        tp_rank = get_tensor_model_parallel_rank()
        return self.quant_method.apply(layer, x, bias, tp_rank)
|
||||||
|
|
||||||
|
|
||||||
|
class NPU_W8A8DynamicLinearMethodImpl:
    """W8A8_DYNAMIC linear implementation backed by ``torch_npu`` kernels.

    Activations are quantized on the fly with ``torch_npu.npu_dynamic_quant``,
    so no per-tensor input scale/offset parameters are created.
    """

    def __init__(self):
        self.transpose_weight = True

    @staticmethod
    def get_weight(
        input_size: int, output_size: int, params_dtype: torch.dtype
    ) -> Dict[str, Any]:
        """Return the uninitialized int8 weight, shaped (output_size, input_size).

        ``params_dtype`` is accepted but does not influence the weight dtype.
        """
        return {"weight": torch.empty(output_size, input_size, dtype=torch.int8)}

    @staticmethod
    def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
        """Return an empty dict: this scheme has no per-tensor parameters."""
        return {}

    @staticmethod
    def get_perchannel_param(
        output_size: int,
        params_dtype: torch.dtype,
    ) -> Dict[str, Any]:
        """Return the (output_size, 1) weight scale and offset tensors."""
        return {
            "weight_scale": torch.empty(output_size, 1, dtype=params_dtype),
            "weight_offset": torch.empty(output_size, 1, dtype=params_dtype),
        }

    @staticmethod
    def apply(
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
        tp_rank: Optional[int] = 0,
    ) -> torch.Tensor:
        """Dynamically quantize ``x`` and run the quantized matmul.

        The result is produced in the dtype ``x`` had on entry. ``tp_rank`` is
        accepted but not used by this implementation.
        """
        out_dtype = x.dtype
        # use ATB quantize
        quantized_x, per_token_scale = torch_npu.npu_dynamic_quant(x)
        return torch_npu.npu_quant_matmul(
            quantized_x,
            layer.weight,
            layer.weight_scale,
            pertoken_scale=per_token_scale,
            bias=bias,
            output_dtype=out_dtype,
        )

    def process_weights_after_loading(self, layer):
        """Reshape loaded tensors for :meth:`apply`.

        Transposes the weight when ``transpose_weight`` is set, flattens the
        weight scale and offset, and keeps a float32 copy of the flattened
        scale as ``layer.weight_scale_fp32``.
        """
        if self.transpose_weight:
            layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()

        flat_scale = torch.flatten(layer.weight_scale.data)
        layer.weight_scale.data = flat_scale
        layer.weight_scale_fp32 = layer.weight_scale.data.to(torch.float32)

        layer.weight_offset.data = torch.flatten(layer.weight_offset.data)
|
||||||
|
|
||||||
|
|
||||||
|
class NPU_W8A8DynamicLinearMethod(LinearMethodBase):
    """Linear method for dynamic (per-token activation) W8A8 quantization on NPU.

    Tensor layouts and kernels are supplied by the
    ``NPU_W8A8DynamicLinearMethodImpl`` backend stored in ``self.quant_method``.

    Args:
        quantization_config: The NPU quantization config.
    """

    def __init__(self, quantization_config: W8A8Int8Config) -> None:
        self.quantization_config = quantization_config
        self.quant_method = NPU_W8A8DynamicLinearMethodImpl()

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ) -> None:
        """Register the weight, per-tensor and per-channel parameters on ``layer``.

        The tensors come from the backend's ``get_weight``,
        ``get_pertensor_param`` and ``get_perchannel_param`` and are registered
        in that order under the names the backend returns.
        """
        backend = self.quant_method
        out_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")

        # Quantized weight: sharded along both input and output dims.
        weights = backend.get_weight(
            input_size_per_partition, out_per_partition, params_dtype
        )
        for name, tensor in weights.items():
            weight = torch.nn.Parameter(tensor, requires_grad=False)
            set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
            layer.register_parameter(name, weight)
            set_weight_attrs(weight, extra_weight_attrs)

        # Per-tensor parameters, if the backend defines any.
        per_tensor = backend.get_pertensor_param(params_dtype)
        for name, tensor in per_tensor.items():
            scale_param = PerTensorScaleParameter(
                data=tensor, weight_loader=weight_loader
            )
            # disable warning
            scale_param.ignore_warning = True
            layer.register_parameter(name, scale_param)

        # Per-output-channel tensors: sharded along the output dim only.
        per_channel = backend.get_perchannel_param(out_per_partition, params_dtype)
        for name, tensor in per_channel.items():
            channel_param = torch.nn.Parameter(tensor, requires_grad=False)
            set_weight_attrs(channel_param, {"output_dim": 0})
            layer.register_parameter(name, channel_param)
            set_weight_attrs(channel_param, extra_weight_attrs)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Forward to the backend's post-load hook when it defines one."""
        if not hasattr(self.quant_method, "process_weights_after_loading"):
            return
        self.quant_method.process_weights_after_loading(layer)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the backend matmul; row-parallel layers also pass their TP rank."""
        if not isinstance(layer, RowParallelLinear):
            return self.quant_method.apply(layer, x, bias)

        tp_rank = get_tensor_model_parallel_rank()
        return self.quant_method.apply(layer, x, bias, tp_rank)
|
||||||
|
|
||||||
|
|
||||||
|
class NPU_W8A8MoEMethod:
    """MoE method for NPU quantization.

    Registers int8 expert weights together with float32 per-channel scales
    and offsets on the layer, and dispatches the expert computation to
    ``npu_fused_experts``.

    Args:
        quantization_config: The NPU quantization config.
    """

    def __init__(self, quantization_config: "W8A8Int8Config") -> None:
        self.quantization_config = quantization_config
        # This object acts as its own quant method, so code that goes through
        # `.quant_method` ends up back on this instance.
        self.quant_method = self

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ) -> None:
        """Allocate and register the expert weights, scales and offsets.

        Args:
            layer: Module that receives the parameters.
            num_experts: Size of the leading (expert) dimension.
            hidden_size: Hidden dimension of the experts.
            intermediate_size: Intermediate dimension of the experts. It is
                used as a scalar in the shape arithmetic below.
            params_dtype: Accepted for interface compatibility; the dtypes
                used here are fixed to int8 (weights) and float32
                (scales/offsets).
            **extra_weight_attrs: Attributes attached to every registered
                parameter; ``quant_method`` is overwritten with the
                per-channel marker.
        """
        from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported

        self.num_experts = num_experts
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
        )

        w13_rows = 2 * intermediate_size
        # (name, shape, dtype) in registration order: weights, scales, offsets.
        param_specs = (
            ("w13_weight", (num_experts, w13_rows, hidden_size), torch.int8),
            ("w2_weight", (num_experts, hidden_size, intermediate_size), torch.int8),
            ("w13_weight_scale", (num_experts, w13_rows, 1), torch.float32),
            ("w2_weight_scale", (num_experts, hidden_size, 1), torch.float32),
            ("w13_weight_offset", (num_experts, w13_rows, 1), torch.float32),
            ("w2_weight_offset", (num_experts, hidden_size, 1), torch.float32),
        )
        for name, shape, dtype in param_specs:
            param = torch.nn.Parameter(
                torch.empty(*shape, dtype=dtype), requires_grad=False
            )
            layer.register_parameter(name, param)
            set_weight_attrs(param, extra_weight_attrs)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Re-layout the loaded parameters in place on ``layer``.

        Both weight tensors get dims 1 and 2 swapped (stored contiguously);
        the scale and offset tensors lose their trailing dimension.
        """
        for name in ("w13_weight", "w2_weight"):
            data = getattr(layer, name).data.transpose(1, 2).contiguous()
            setattr(layer, name, torch.nn.Parameter(data, requires_grad=False))

        for name in (
            "w13_weight_scale",
            "w2_weight_scale",
            "w13_weight_offset",
            "w2_weight_offset",
        ):
            data = getattr(layer, name).data.squeeze(-1).contiguous()
            setattr(layer, name, torch.nn.Parameter(data, requires_grad=False))

    def apply(
        self,
        layer,
        x,
        router_logits,
        top_k,
        renormalize,
        use_grouped_topk,
        topk_group,
        num_expert_group,
        num_fused_shared_experts,
        custom_routing_function,
        correction_bias,
        activation,
        apply_router_weight_on_input,
        routed_scaling_factor,
        **kwargs,
    ) -> torch.Tensor:
        """Select experts for ``x`` and run the fused NPU expert kernel.

        ``activation``, ``apply_router_weight_on_input`` and ``**kwargs`` are
        accepted but not used in this method, and the registered weight
        offsets are not passed to ``npu_fused_experts``.

        Returns:
            The result of ``npu_fused_experts``.
        """
        from sglang.srt.layers.moe.topk import select_experts

        global_num_experts = router_logits.shape[-1]
        # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
        if global_num_experts == 256:
            # NOTE(review): this branch does not forward `renormalize`,
            # `use_grouped_topk`, `num_fused_shared_experts`,
            # `custom_routing_function` or `routed_scaling_factor`;
            # `renorm=0` and `routed_scaling_factor=1` are hard-coded.
            # Confirm this is intended.
            topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
                router_logits,
                k=top_k,
                bias=correction_bias,
                k_group=topk_group,
                group_count=num_expert_group,
                group_select_mode=1,
                renorm=0,
                norm_type=1,
                routed_scaling_factor=1,
                eps=float(1e-20),
            )
        else:
            topk_weights, topk_ids = select_experts(
                hidden_states=x,
                router_logits=router_logits,
                use_grouped_topk=use_grouped_topk,
                top_k=top_k,
                renormalize=renormalize,
                topk_group=topk_group,
                num_expert_group=num_expert_group,
                num_fused_shared_experts=num_fused_shared_experts,
                custom_routing_function=custom_routing_function,
                correction_bias=correction_bias,
                torch_native=True,
                routed_scaling_factor=routed_scaling_factor,
            )

        # Both selection paths are normalised to int32 ids and weights in the
        # activation dtype before the fused kernel is called.
        topk_ids = topk_ids.to(torch.int32)
        topk_weights = topk_weights.to(x.dtype)
        return npu_fused_experts(
            hidden_states=x,
            w13=layer.w13_weight,
            w13_scale=layer.w13_weight_scale,
            w2=layer.w2_weight,
            w2_scale=layer.w2_weight_scale,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            top_k=top_k,
        )
|
||||||
|
|||||||
@@ -34,16 +34,18 @@ import torch
|
|||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
import triton
|
import triton
|
||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
from sgl_kernel.kvcacheio import transfer_kv_per_layer, transfer_kv_per_layer_mla
|
|
||||||
|
|
||||||
from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
|
from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.utils import get_bool_env_var, is_cuda, next_power_of_2
|
from sglang.srt.utils import get_bool_env_var, is_cuda, is_npu, next_power_of_2
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
GB = 1024 * 1024 * 1024
|
GB = 1024 * 1024 * 1024
|
||||||
_is_cuda = is_cuda()
|
_is_cuda = is_cuda()
|
||||||
|
_is_npu = is_npu()
|
||||||
|
if not _is_npu:
|
||||||
|
from sgl_kernel.kvcacheio import transfer_kv_per_layer, transfer_kv_per_layer_mla
|
||||||
|
|
||||||
|
|
||||||
class ReqToTokenPool:
|
class ReqToTokenPool:
|
||||||
|
|||||||
@@ -64,10 +64,13 @@ from sglang.srt.model_loader.weight_utils import (
|
|||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
get_bool_env_var,
|
get_bool_env_var,
|
||||||
get_device_capability,
|
get_device_capability,
|
||||||
|
is_npu,
|
||||||
is_pin_memory_available,
|
is_pin_memory_available,
|
||||||
set_weight_attrs,
|
set_weight_attrs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
_is_npu = is_npu()
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def device_loading_context(module: torch.nn.Module, target_device: torch.device):
|
def device_loading_context(module: torch.nn.Module, target_device: torch.device):
|
||||||
@@ -127,18 +130,19 @@ def _get_quantization_config(
|
|||||||
# (yizhang2077) workaround for nvidia/Llama-4-Maverick-17B-128E-Eagle3
|
# (yizhang2077) workaround for nvidia/Llama-4-Maverick-17B-128E-Eagle3
|
||||||
if quant_config is None:
|
if quant_config is None:
|
||||||
return None
|
return None
|
||||||
major, minor = get_device_capability()
|
if not _is_npu:
|
||||||
|
major, minor = get_device_capability()
|
||||||
|
|
||||||
if major is not None and minor is not None:
|
if major is not None and minor is not None:
|
||||||
assert 0 <= minor < 10
|
assert 0 <= minor < 10
|
||||||
capability = major * 10 + minor
|
capability = major * 10 + minor
|
||||||
if capability < quant_config.get_min_capability():
|
if capability < quant_config.get_min_capability():
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"The quantization method {model_config.quantization} "
|
f"The quantization method {model_config.quantization} "
|
||||||
"is not supported for the current GPU. "
|
"is not supported for the current GPU. "
|
||||||
f"Minimum capability: {quant_config.get_min_capability()}. "
|
f"Minimum capability: {quant_config.get_min_capability()}. "
|
||||||
f"Current capability: {capability}."
|
f"Current capability: {capability}."
|
||||||
)
|
)
|
||||||
supported_dtypes = quant_config.get_supported_act_dtypes()
|
supported_dtypes = quant_config.get_supported_act_dtypes()
|
||||||
if model_config.dtype not in supported_dtypes:
|
if model_config.dtype not in supported_dtypes:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -157,6 +161,13 @@ def _initialize_model(
|
|||||||
"""Initialize a model with the given configurations."""
|
"""Initialize a model with the given configurations."""
|
||||||
model_class, _ = get_model_architecture(model_config)
|
model_class, _ = get_model_architecture(model_config)
|
||||||
packed_modules_mapping = getattr(model_class, "packed_modules_mapping", {})
|
packed_modules_mapping = getattr(model_class, "packed_modules_mapping", {})
|
||||||
|
if _is_npu:
|
||||||
|
packed_modules_mapping["fused_qkv_a_proj_with_mqa"] = [
|
||||||
|
"q_a_proj",
|
||||||
|
"kv_a_proj_with_mqa",
|
||||||
|
]
|
||||||
|
packed_modules_mapping["qkv_proj"] = ["q_proj", "k_proj", "v_proj"]
|
||||||
|
packed_modules_mapping["gate_up_proj"] = ["gate_proj", "up_proj"]
|
||||||
quant_config = _get_quantization_config(
|
quant_config = _get_quantization_config(
|
||||||
model_config, load_config, packed_modules_mapping
|
model_config, load_config, packed_modules_mapping
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -575,6 +575,8 @@ class LlamaForCausalLM(nn.Module):
|
|||||||
# Skip loading extra bias for GPTQ models.
|
# Skip loading extra bias for GPTQ models.
|
||||||
if name.endswith(".bias") and name not in params_dict:
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
continue
|
continue
|
||||||
|
if name not in params_dict:
|
||||||
|
continue
|
||||||
param = params_dict[name]
|
param = params_dict[name]
|
||||||
weight_loader = param.weight_loader
|
weight_loader = param.weight_loader
|
||||||
weight_loader(param, loaded_weight, shard_id)
|
weight_loader(param, loaded_weight, shard_id)
|
||||||
|
|||||||
@@ -407,6 +407,8 @@ class QuantMixtralForCausalLM(nn.Module):
|
|||||||
# Skip loading extra bias for GPTQ models.
|
# Skip loading extra bias for GPTQ models.
|
||||||
if name.endswith(".bias") and name not in params_dict:
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
continue
|
continue
|
||||||
|
if name not in params_dict:
|
||||||
|
continue
|
||||||
param = params_dict[name]
|
param = params_dict[name]
|
||||||
weight_loader = param.weight_loader
|
weight_loader = param.weight_loader
|
||||||
weight_loader(param, loaded_weight, shard_id)
|
weight_loader(param, loaded_weight, shard_id)
|
||||||
@@ -418,6 +420,8 @@ class QuantMixtralForCausalLM(nn.Module):
|
|||||||
# Skip experts that are not assigned to this worker.
|
# Skip experts that are not assigned to this worker.
|
||||||
if "block_sparse_moe.experts." in name and name not in params_dict:
|
if "block_sparse_moe.experts." in name and name not in params_dict:
|
||||||
continue
|
continue
|
||||||
|
if name not in params_dict:
|
||||||
|
continue
|
||||||
param = params_dict[name]
|
param = params_dict[name]
|
||||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||||
weight_loader(param, loaded_weight)
|
weight_loader(param, loaded_weight)
|
||||||
|
|||||||
@@ -538,6 +538,8 @@ class Qwen2ForCausalLM(nn.Module):
|
|||||||
# Skip loading extra bias for GPTQ models.
|
# Skip loading extra bias for GPTQ models.
|
||||||
if name.endswith(".bias") and name not in params_dict:
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
continue
|
continue
|
||||||
|
if name not in params_dict:
|
||||||
|
continue
|
||||||
param = params_dict[name]
|
param = params_dict[name]
|
||||||
weight_loader = param.weight_loader
|
weight_loader = param.weight_loader
|
||||||
weight_loader(param, loaded_weight, shard_id)
|
weight_loader(param, loaded_weight, shard_id)
|
||||||
|
|||||||
@@ -197,7 +197,7 @@ def get_int_env_var(name: str, default: int = 0) -> int:
|
|||||||
|
|
||||||
|
|
||||||
def support_triton(backend: str) -> bool:
|
def support_triton(backend: str) -> bool:
|
||||||
return backend not in ["torch_native", "intel_amx"]
|
return backend not in ["torch_native", "intel_amx", "ascend"]
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -2782,3 +2782,101 @@ def lru_cache_frozenset(maxsize=128):
|
|||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
def apply_module_patch(target_module, target_function, wrappers):
    """Wrap a function and rebind every loaded module that references it.

    Args:
        target_module: Dotted path of the module that owns the function.
        target_function: Name of the attribute to patch, or ``None``.
        wrappers: Iterable of callables; each receives the current candidate
            and returns its replacement, applied in order.
    """
    owner, original = parse_module_path(target_module, target_function, False)
    original_id = id(original)

    patched = original
    for wrap in wrappers:
        patched = wrap(patched)

    # Without an attribute name there is nothing to rebind.
    if target_function is None:
        return

    setattr(owner, target_function, patched)

    # Other modules may hold their own reference to the original object
    # (e.g. via `from x import f`); point those at the patched one as well.
    for loaded in sys.modules.copy().values():
        if not hasattr(loaded, target_function):
            continue
        if id(getattr(loaded, target_function)) == original_id:
            setattr(loaded, target_function, patched)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_module_path(module_path, function_name, create_dummy):
    """Resolve a dotted path and, optionally, a named attribute on it.

    The path is walked one component at a time, importing each prefix. When
    a prefix cannot be imported, an attribute of that name on the parent is
    used if present; otherwise a placeholder module is registered in
    ``sys.modules`` (only when ``create_dummy`` is true).

    Args:
        module_path: Dotted path such as ``"pkg.sub.mod"``.
        function_name: Attribute to look up on the resolved target, or
            ``None`` to resolve the target only.
        create_dummy: When true, missing modules and missing functions are
            replaced by placeholders instead of raising.

    Returns:
        A ``(target, function)`` tuple. ``function`` is ``None`` when
        ``function_name`` is ``None``, or when the attribute is missing on an
        importable module and ``create_dummy`` is false (in that case the
        attribute is also set to ``None`` on the module).

    Raises:
        ModuleNotFoundError: A component is missing, the parent has no such
            attribute, and ``create_dummy`` is false.
        AttributeError: A component was resolved through a parent attribute,
            ``function_name`` is missing on it, and ``create_dummy`` is
            false.
    """
    from importlib.machinery import ModuleSpec

    def create_dummy_module(full_path, parent=None):
        """Create and register a placeholder module"""
        dummy = types.ModuleType(full_path)
        dummy.__file__ = "vllm_ascend.dummy_module.py"
        dummy.__spec__ = ModuleSpec(full_path, None)
        sys.modules[full_path] = dummy
        if parent:
            setattr(parent, full_path.split(".")[-1], dummy)
        return dummy

    def create_placeholder_function(func_name):
        """Create dummy function that raises when called"""

        def placeholder(*args, **kwargs):
            raise NotImplementedError(f"Function {func_name} is a placeholder")

        placeholder.__name__ = func_name
        return placeholder

    parts = module_path.split(".")
    current_module = None

    for idx, part in enumerate(parts):
        current_path = ".".join(parts[: idx + 1])
        parent_path = ".".join(parts[:idx]) if idx > 0 else None

        try:
            current_module = importlib.import_module(current_path)
        except ModuleNotFoundError:
            # Handle missing module
            parent = importlib.import_module(parent_path) if parent_path else None
            if parent and hasattr(parent, part):
                # Use existing attribute from parent
                current_module = getattr(parent, part)
                # Check for early function resolution
                if function_name:
                    if hasattr(current_module, function_name):
                        return current_module, getattr(current_module, function_name)
                    if create_dummy:
                        ph_func = create_placeholder_function(function_name)
                        setattr(current_module, function_name, ph_func)
                        return current_module, ph_func
                    raise AttributeError(
                        f"Function {function_name} missing in {current_path}"
                    )
            else:
                if not create_dummy:
                    raise
                # Create and register dummy module
                current_module = create_dummy_module(current_path, parent=parent)

    # Final function handling. A target resolved through a parent attribute
    # (for example a class) is not in sys.modules, so fall back to the
    # object found by the walk instead of raising KeyError.
    final_module = sys.modules.get(module_path, current_module)
    if function_name is not None:
        if not hasattr(final_module, function_name):
            if create_dummy:
                ph_func = create_placeholder_function(function_name)
                setattr(final_module, function_name, ph_func)
            else:
                setattr(final_module, function_name, None)
        return final_module, getattr(final_module, function_name)

    return final_module, None
|
||||||
|
|||||||
Reference in New Issue
Block a user