# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
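
# This module monkey-patches vLLM's CompressedTensorsW8A8Int8MoEMethod so
# that fused-MoE weight post-processing and the forward pass dispatch to the
# KLX custom ops registered under torch.ops._C (see the patch at the bottom
# of the file).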

from typing import Callable, Optional, Union

import torch

from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import (
    CompressedTensorsW8A8Int8MoEMethod,
)


def klx_process_weights_after_loading(layer: torch.nn.Module) -> None:
    """Convert per-channel quantization scales back to abs-max values."""
    layer.w13_weight = torch.nn.Parameter(layer.w13_weight, requires_grad=False)
    layer.w2_weight = torch.nn.Parameter(layer.w2_weight, requires_grad=False)
    # With symmetric int8 quantization the checkpoint scale is abs_max / 127,
    # so multiplying by 127 recovers the per-channel abs-max that the KLX
    # MoE ops consume.
    layer.w13_weight_scale = torch.nn.Parameter(
        layer.w13_weight_scale.data * 127, requires_grad=False
    )
    layer.w2_weight_scale = torch.nn.Parameter(
        layer.w2_weight_scale.data * 127, requires_grad=False
    )
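

# Illustrative sketch (hypothetical helper, not used by the kernel path):
# under the convention above, an int8 weight dequantizes as
# w_fp ~= w_q * abs_max / 127, which is the form the ops below consume.
def _dequant_from_absmax(w_q: torch.Tensor, abs_max: torch.Tensor) -> torch.Tensor:
    # Assumes abs_max holds one entry per output channel of an
    # [E, out, in] weight; unsqueeze to broadcast across the input dim.
    return w_q.float() * (abs_max.float().unsqueeze(-1) / 127.0)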


def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    klx_process_weights_after_loading(layer)


def apply(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    renormalize: bool,
    use_grouped_topk: bool = False,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    routed_scaling_factor: float = 1.0,
    e_score_correction_bias: Optional[torch.Tensor] = None,
    apply_router_weight_on_input: bool = False,
    activation: str = "silu",
    enable_eplb: bool = False,
    expert_load_view: Optional[torch.Tensor] = None,
    logical_to_physical_map: Optional[torch.Tensor] = None,
    logical_replica_count: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
    hidden_states = x
    # w13_weight: [E, 2 * intermediate, hidden]; infer the expert count.
    global_num_experts, up_gate_size, _ = layer.w13_weight.shape
    M, N = hidden_states.shape
    hidden_dim = layer.w2_weight.shape[1]
    # Routing outputs: per-token normalized top-k weights and expert ids.
    normed_score = torch.empty(M,
                               top_k,
                               dtype=torch.float32,
                               device=hidden_states.device)
    topk_ids = torch.empty(M,
                           top_k,
                           dtype=torch.int32,
                           device=hidden_states.device)
    num_blocks = 12
    block_statistic = torch.zeros(
        num_blocks, global_num_experts, dtype=torch.int32, device=hidden_states.device
    )

    router_logits = router_logits.float()
    if scoring_func == "softmax":
        torch.ops._C.moe_softmax_topk_norm(
            x=router_logits,
            normed_score=normed_score,
            topk_index=topk_ids,
            block_statistic=None,
            stable=True)
    elif scoring_func == "sigmoid":
        # Grouped top-k routing (DeepSeek-style): sigmoid scores with an
        # expert-score correction bias, restricted to the best topk_group of
        # num_expert_group groups, then scaled by routed_scaling_factor.
        torch.ops._C.moe_sigmoid_group_topk_norm(
            x=router_logits,
            norm_score=normed_score,
            topk_index=topk_ids,
            block_static=block_statistic,
            bias=e_score_correction_bias,
            n_group=num_expert_group,
            topk_group=topk_group,
            scale=routed_scaling_factor,
        )
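
    # For reference, the softmax branch is assumed to match this plain
    # PyTorch computation (renormalized top-k of the softmaxed logits):
    #   probs = torch.softmax(router_logits, dim=-1)
    #   normed_score, topk_ids = torch.topk(probs, top_k, dim=-1)
    #   normed_score = normed_score / normed_score.sum(-1, keepdim=True)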

    # Scratch buffers for scattering token copies into expert-sorted order.
    moe_expand = torch.empty((M * top_k, N), dtype=hidden_states.dtype, device=hidden_states.device)  # [M * top_k, N]
    expert_m = torch.zeros(global_num_experts, dtype=torch.int32, device=hidden_states.device)  # [E]
    sorted_tokens_num_lod = torch.zeros(global_num_experts + 1, dtype=torch.int32, device=hidden_states.device)  # [E + 1]
    sorted_tokens_idx = torch.zeros(M * top_k, dtype=torch.int32, device=hidden_states.device)

    torch.ops._C.gen_block_statistic(topk_ids, block_statistic)

    torch.ops._C.moe_pre_sorted(
        x=hidden_states,
        topk_index=topk_ids,
        block_statistic=block_statistic,
        moe_expand=moe_expand,
        moe_index=sorted_tokens_idx,
        expert_m=expert_m,
        sorted_tokens_num_lod=sorted_tokens_num_lod)
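
    # Assumed layout after moe_pre_sorted: moe_expand holds one copy of each
    # token per selected expert, grouped so that the rows for expert e sit in
    # moe_expand[sorted_tokens_num_lod[e]:sorted_tokens_num_lod[e + 1]], and
    # sorted_tokens_idx records where each (token, k) pair was placed.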

    # Output buffer for the first grouped GEMM (fused gate/up projection).
    y = torch.empty(M, top_k,
                    layer.w13_weight.shape[1],
                    dtype=hidden_states.dtype,
                    device=hidden_states.device)

    moe_expand = moe_expand.view(M * top_k, hidden_dim)

    # Dynamic per-token int8 quantization of the expanded activations.
    x_shape = moe_expand.shape
    x_q = torch.empty(x_shape, dtype=torch.int8, device=moe_expand.device)
    x_scale = torch.empty((x_shape[0], 1), dtype=torch.float32, device=moe_expand.device)
    torch.ops._C.quant2d(moe_expand, x_q, x_scale, force_sdnn=True)
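
    # quant2d is assumed to perform symmetric per-row quantization, roughly:
    #   x_scale = moe_expand.abs().amax(dim=-1, keepdim=True)
    #   x_q = torch.round(moe_expand / x_scale * 127).to(torch.int8)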

    # First grouped int8 GEMM: expert-sorted activations times w13.
    torch.ops._C.moe_fc(
        x=x_q,
        x_perchannel_max=x_scale,
        weight=layer.w13_weight,
        w_perchannel_max=layer.w13_weight_scale,
        sorted_tokens_num_lod=sorted_tokens_num_lod,
        sorted_tokens_idx=sorted_tokens_idx,
        moe_topk=top_k,
        y=y,
        topk_ids=topk_ids,
        # sort_mode=False,
        act=None)
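
    # moe_fc is assumed to run one int8 GEMM per expert over its contiguous
    # row range and to dequantize with the product of the per-token and
    # per-channel abs-max factors, i.e. roughly
    #   y ~= (x_q @ w_q.T) * (x_scale / 127) * (w_perchannel_max / 127)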

    # SwiGLU activation: split y into gate/up halves along the last dim.
    d = y.shape[-1] // 2
    output_shape = y.shape[:-1] + (d, )
    out1 = torch.empty(output_shape, dtype=y.dtype, device=y.device)
    torch.ops._C.silu_and_mul(out1, y)
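
    # vLLM's silu_and_mul computes, in plain PyTorch terms:
    #   out1 = torch.nn.functional.silu(y[..., :d]) * y[..., d:]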

    # Output buffer for the second grouped GEMM (down projection).
    out = torch.empty(M, top_k,
                      layer.w2_weight.shape[1],
                      dtype=hidden_states.dtype,
                      device=hidden_states.device)

    # Re-quantize the activated values for the int8 down projection.
    out1 = out1.reshape(-1, out1.shape[-1])
    x_shape = out1.shape
    x_q = torch.empty(x_shape, dtype=torch.int8, device=moe_expand.device)
    x_scale = torch.empty((x_shape[0], 1), dtype=torch.float32, device=moe_expand.device)
    torch.ops._C.quant2d(out1, x_q, x_scale, force_sdnn=True)

    # Second grouped int8 GEMM: activated values times w2.
    torch.ops._C.moe_fc(
        x=x_q,
        x_perchannel_max=x_scale,
        weight=layer.w2_weight,
        w_perchannel_max=layer.w2_weight_scale,
        sorted_tokens_num_lod=sorted_tokens_num_lod,
        sorted_tokens_idx=sorted_tokens_idx,
        moe_topk=top_k,
        y=out,
        topk_ids=topk_ids,
        # sort_mode=False,
        act=None)

    # Gather expert outputs back to token order and apply routing weights.
    dequant_scale = torch.ones([M, top_k], dtype=torch.float32, device=out.device)
    output = torch.empty([M, N], dtype=hidden_states.dtype, device=hidden_states.device)
    sorted_tokens_idx = sorted_tokens_idx.view(M, top_k)

    torch.ops._C.moe_post(
        x=out,
        moe_index=sorted_tokens_idx,
        normed_scale=normed_score,
        dequant_scale=dequant_scale,
        y=output
    )
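
    # Assumed semantics of moe_post, roughly:
    #   output[m] = sum_k normed_score[m, k] * dequant_scale[m, k]
    #               * out_sorted[sorted_tokens_idx[m, k]]
    # dequant_scale is all ones here, so this is a weighted gather-reduce.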
    return output


# Monkey-patch vLLM's stock int8 MoE method so that weight post-processing
# and the forward pass dispatch to the KLX custom ops defined above.
CompressedTensorsW8A8Int8MoEMethod.process_weights_after_loading = process_weights_after_loading
CompressedTensorsW8A8Int8MoEMethod.apply = apply
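

# ---------------------------------------------------------------------------
# Reference sketch (assumption, for eyeballing numerics only): a pure-PyTorch
# equivalent of the softmax path above. It mirrors the pipeline
# (route -> quantize -> w13 GEMM -> SwiGLU -> quantize -> w2 GEMM -> combine)
# but is never called by vLLM and makes no claim about the exact semantics of
# the fused KLX ops.
# ---------------------------------------------------------------------------
def _reference_int8_moe_forward(
    x: torch.Tensor,              # [M, N] activations
    router_logits: torch.Tensor,  # [M, E]
    w13: torch.Tensor,            # [E, 2 * I, N] int8
    w13_absmax: torch.Tensor,     # [E, 2 * I] per-channel abs-max
    w2: torch.Tensor,             # [E, N, I] int8
    w2_absmax: torch.Tensor,      # [E, N] per-channel abs-max
    top_k: int,
) -> torch.Tensor:

    def quant(t: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        # Symmetric per-token int8 quantization (the assumed quant2d).
        amax = t.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8)
        return torch.round(t / amax * 127).to(torch.int8), amax

    probs = torch.softmax(router_logits.float(), dim=-1)
    score, ids = torch.topk(probs, top_k, dim=-1)
    score = score / score.sum(dim=-1, keepdim=True)  # renormalize

    output = torch.zeros_like(x, dtype=torch.float32)
    for k in range(top_k):
        for e in range(router_logits.shape[-1]):
            mask = ids[:, k] == e
            if not mask.any():
                continue
            x_q, x_amax = quant(x[mask].float())
            # Dequantized int8 GEMM: scales enter as (amax / 127) factors.
            h = (x_q.float() @ w13[e].float().t()) * (x_amax / 127) * (w13_absmax[e].float() / 127)
            d = h.shape[-1] // 2
            a = torch.nn.functional.silu(h[..., :d]) * h[..., d:]
            a_q, a_amax = quant(a)
            o = (a_q.float() @ w2[e].float().t()) * (a_amax / 127) * (w2_absmax[e].float() / 127)
            output[mask] += score[mask, k, None] * o
    return output.to(x.dtype)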