init v0.11.0rc0
This commit is contained in:
@@ -24,10 +24,10 @@ from vllm.config import get_current_vllm_config
|
||||
from vllm.distributed import get_ep_group
|
||||
from vllm.forward_context import get_forward_context
|
||||
|
||||
from vllm_ascend.ascend_forward_context import FusedMoEState
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
||||
from vllm_ascend.ops.fused_moe import unified_fused_experts_eager
|
||||
from vllm_ascend.ops.layers.experts_selector import select_experts
|
||||
from vllm_ascend.ops.moe.experts_selector import select_experts
|
||||
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ
|
||||
|
||||
|
||||
class AscendW4A8DynamicLinearMethod:
|
||||
@@ -133,11 +133,14 @@ class AscendW4A8DynamicFusedMoEMethod:
|
||||
vllm_config = get_current_vllm_config()
|
||||
self.group_size = vllm_config.quant_config.quant_description.get(
|
||||
"group_size", 256)
|
||||
# NOTE: the weights are quantized from bf16 to int4 through a per-channel quantization process
|
||||
self.is_per_channel_weight = self.group_size == 0
|
||||
quant_version = vllm_config.quant_config.quant_description.get(
|
||||
"version", "0")
|
||||
# NOTE: new quantize weights: 2 int4 pack into int8
|
||||
self.new_quant_version = quant_version == "1.0.0"
|
||||
self.tp_size = 1 if vllm_config.parallel_config.enable_expert_parallel else self.ep_group.world_size
|
||||
self.dynamic_eplb = get_ascend_config().dynamic_eplb
|
||||
if self.new_quant_version and self.tp_size > 16:
|
||||
raise ValueError(
|
||||
"The current weight does not support moe part tp>16.")
|
||||
@@ -182,44 +185,44 @@ class AscendW4A8DynamicFusedMoEMethod:
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
1,
|
||||
dtype=params_dtype)
|
||||
dtype=torch.float32)
|
||||
|
||||
param_dict["w13_weight_offset"] = torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
1,
|
||||
dtype=params_dtype)
|
||||
|
||||
param_dict["w13_weight_scale_second"] = torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
hidden_sizes // self.group_size,
|
||||
dtype=params_dtype)
|
||||
|
||||
param_dict["w13_weight_offset_second"] = torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
hidden_sizes // self.group_size,
|
||||
dtype=params_dtype)
|
||||
dtype=torch.float32)
|
||||
|
||||
param_dict["w2_weight_scale"] = torch.empty(num_experts,
|
||||
hidden_sizes,
|
||||
1,
|
||||
dtype=params_dtype)
|
||||
dtype=torch.float32)
|
||||
param_dict["w2_weight_offset"] = torch.empty(num_experts,
|
||||
hidden_sizes,
|
||||
1,
|
||||
dtype=params_dtype)
|
||||
param_dict["w2_weight_scale_second"] = torch.empty(
|
||||
num_experts,
|
||||
hidden_sizes,
|
||||
intermediate_size_per_partition // self.group_size,
|
||||
dtype=params_dtype)
|
||||
param_dict["w2_weight_offset_second"] = torch.empty(
|
||||
num_experts,
|
||||
hidden_sizes,
|
||||
intermediate_size_per_partition // self.group_size,
|
||||
dtype=params_dtype)
|
||||
dtype=torch.float32)
|
||||
if not self.is_per_channel_weight:
|
||||
param_dict["w13_weight_scale_second"] = torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
hidden_sizes // self.group_size,
|
||||
dtype=torch.float32)
|
||||
param_dict["w13_weight_offset_second"] = torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
hidden_sizes // self.group_size,
|
||||
dtype=torch.float32)
|
||||
|
||||
param_dict["w2_weight_scale_second"] = torch.empty(
|
||||
num_experts,
|
||||
hidden_sizes,
|
||||
intermediate_size_per_partition // self.group_size,
|
||||
dtype=torch.float32)
|
||||
param_dict["w2_weight_offset_second"] = torch.empty(
|
||||
num_experts,
|
||||
hidden_sizes,
|
||||
intermediate_size_per_partition // self.group_size,
|
||||
dtype=torch.float32)
|
||||
|
||||
if self.new_quant_version:
|
||||
param_dict["w13_scale_bias"] = torch.empty(
|
||||
@@ -275,14 +278,6 @@ class AscendW4A8DynamicFusedMoEMethod:
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
global_num_experts=global_num_experts)
|
||||
|
||||
fused_moe_state = get_forward_context().fused_moe_state
|
||||
shared_gate_up, shared_dequant_scale = None, None
|
||||
if shared_experts is not None and fused_moe_state == FusedMoEState.MC2:
|
||||
share_up_out, _ = shared_experts.gate_up_proj(
|
||||
(quantized_x_for_share, dynamic_scale_for_share))
|
||||
shared_gate_up, shared_dequant_scale = share_up_out[
|
||||
0], share_up_out[1]
|
||||
|
||||
# this is a naive implementation for experts load balance so as
|
||||
# to avoid accumulating too much tokens on a single rank.
|
||||
# currently it is only activated when doing profile runs.
|
||||
@@ -291,27 +286,36 @@ class AscendW4A8DynamicFusedMoEMethod:
|
||||
|
||||
topk_weights = topk_weights.to(x.dtype)
|
||||
|
||||
return unified_fused_experts_eager(
|
||||
moe_comm_method = get_forward_context().moe_comm_method
|
||||
return moe_comm_method.fused_experts(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
w1_scale=layer.w13_weight_scale_second,
|
||||
w2_scale=layer.w2_weight_scale_second,
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
w1_scale_bias=layer.w13_scale_bias,
|
||||
w2_scale_bias=layer.w2_scale_bias,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
row_idx=row_idx,
|
||||
use_int4_w4a8=True,
|
||||
expert_map=expert_map,
|
||||
log2phy=log2phy,
|
||||
global_redundant_expert_num=global_redundant_expert_num,
|
||||
shared_experts=shared_experts,
|
||||
shared_gate_up=shared_gate_up,
|
||||
shared_dequant_scale=shared_dequant_scale,
|
||||
mc2_mask=kwargs.get("mc2_mask", None),
|
||||
with_quant=True)
|
||||
quantized_x_for_share=quantized_x_for_share,
|
||||
dynamic_scale_for_share=dynamic_scale_for_share,
|
||||
dynamic_eplb=self.dynamic_eplb)
|
||||
|
||||
def process_scale(self, weight: torch.Tensor, scale, per_group_scale):
|
||||
scale = scale.transpose(1, 2).contiguous()
|
||||
if self.is_per_channel_weight:
|
||||
scale_np = scale.cpu().numpy()
|
||||
scale_np.dtype = np.uint32
|
||||
scale_uint64_tensor = torch.from_numpy(scale_np.astype(
|
||||
np.int64)).npu()
|
||||
return scale_uint64_tensor, None
|
||||
per_group_scale = per_group_scale.transpose(1, 2).contiguous()
|
||||
group_num, k, n = weight.shape
|
||||
# the weight of the new version is reduced by half by pack n, so it needs to be restored
|
||||
if self.new_quant_version:
|
||||
@@ -354,13 +358,10 @@ class AscendW4A8DynamicFusedMoEMethod:
|
||||
|
||||
def pack_to_int32(self, weight: torch.Tensor):
|
||||
if self.new_quant_version:
|
||||
group_num, k, n = weight.shape
|
||||
assert n % 4 == 0, "the last dim of weight needs to be divided by 4"
|
||||
packed_n = n // 4
|
||||
# pack 4 int8(int4*2) to int32, because in pytorch, we need to use int32 to represent int4
|
||||
packed_weight = torch.from_numpy(
|
||||
np.frombuffer(weight.cpu().numpy().tobytes(), dtype=np.int32))
|
||||
return packed_weight.reshape(group_num, k, packed_n).npu()
|
||||
assert weight.shape[
|
||||
-1] % 4 == 0, "the last dim of weight needs to be divided by 4"
|
||||
return weight.view(torch.int32).contiguous()
|
||||
else:
|
||||
return torch_npu.npu_quantize(weight.to(torch.float32),
|
||||
torch.tensor([1.]).npu(), None,
|
||||
@@ -372,23 +373,29 @@ class AscendW4A8DynamicFusedMoEMethod:
|
||||
1, 2).contiguous()
|
||||
layer.w2_weight.data = layer.w2_weight.data.transpose(
|
||||
1, 2).contiguous()
|
||||
layer.w13_weight_scale.data = layer.w13_weight_scale.data.transpose(
|
||||
1, 2).contiguous()
|
||||
layer.w2_weight_scale.data = layer.w2_weight_scale.data.transpose(
|
||||
1, 2).contiguous()
|
||||
layer.w13_weight_scale_second.data = layer.w13_weight_scale_second.data.transpose(
|
||||
1, 2).contiguous()
|
||||
layer.w2_weight_scale_second.data = layer.w2_weight_scale_second.data.transpose(
|
||||
1, 2).contiguous()
|
||||
|
||||
layer.w13_weight_scale_second.data, w13_bias = self.process_scale(
|
||||
w13_weight_scale_second = layer.w13_weight_scale_second.data if hasattr(
|
||||
layer, "w13_weight_scale_second") else None
|
||||
w2_weight_scale_second = layer.w2_weight_scale_second.data if hasattr(
|
||||
layer, "w2_weight_scale_second") else None
|
||||
layer.w13_weight_scale.data, w13_bias = self.process_scale(
|
||||
layer.w13_weight, layer.w13_weight_scale.data,
|
||||
layer.w13_weight_scale_second.data)
|
||||
layer.w2_weight_scale_second.data, w2_bias = self.process_scale(
|
||||
w13_weight_scale_second)
|
||||
layer.w2_weight_scale.data, w2_bias = self.process_scale(
|
||||
layer.w2_weight, layer.w2_weight_scale.data,
|
||||
layer.w2_weight_scale_second.data)
|
||||
w2_weight_scale_second)
|
||||
if hasattr(layer, "w13_weight_scale_second"):
|
||||
# scale_second is no longer used, release this part of the memory
|
||||
del layer.w13_weight_scale_second
|
||||
del layer.w2_weight_scale_second
|
||||
del layer.w13_weight_offset_second
|
||||
del layer.w2_weight_offset_second
|
||||
|
||||
self.update_bias(layer, w13_bias, w2_bias)
|
||||
|
||||
layer.w13_weight.data = torch_npu.npu_format_cast(
|
||||
layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ)
|
||||
layer.w2_weight.data = torch_npu.npu_format_cast(
|
||||
layer.w2_weight.data, ACL_FORMAT_FRACTAL_NZ)
|
||||
layer.w13_weight.data = self.pack_to_int32(layer.w13_weight.data)
|
||||
layer.w2_weight.data = self.pack_to_int32(layer.w2_weight.data)
|
||||
|
||||
Reference in New Issue
Block a user