[EPLB][Bugfix] EPLB support fp/bf16 (#5531)

### What this PR does / why we need it?
EPLB support dtype of fp/bf16.
### Does this PR introduce _any_ user-facing change?

### How was this patch tested?
w8a8_dynamic Baseline:
| dataset | version | metric | mode | vllm-api-general-chat |
|----- | ----- | ----- | ----- | -----|
| aime2024 | 604a78 | accuracy | gen | 86.67 |

w8a8_dynamic eplb:
| dataset | version | metric | mode | vllm-api-general-chat |
|----- | ----- | ----- | ----- | -----|
| aime2024 | 604a78 | accuracy | gen | 86.67 |

The fp16 conversation is normal.
The fp16 test is in progress.

Baseline fp16
| dataset | version | metric | mode | vllm-api-general-chat |
|----- | ----- | ----- | ----- | -----|
| aime2024 | 604a78 | accuracy | gen | 86.67 |

eplb fp16
| dataset | version | metric | mode | vllm-api-general-chat |
|----- | ----- | ----- | ----- | -----|
| aime2024 | 604a78 | accuracy | gen | 83.33 |

- vLLM version: v0.13.0
- vLLM main:
45c1ca1ca1

Signed-off-by: shenchuxiaofugui <1311027364@qq.com>
This commit is contained in:
LI SHENGYONG
2026-01-26 14:28:16 +08:00
committed by GitHub
parent 52d4acfa51
commit 611e223b7d
4 changed files with 67 additions and 118 deletions

View File

@@ -31,46 +31,19 @@ class VllmEplbAdaptor(EplbAdaptor):
self.model = model
self.rank_id = dist.get_rank()
self.world_size = dist.get_world_size()
self.param_dict = dict(self.model.named_parameters())
self.num_dense_layers = getattr(self.model.config, "first_k_dense_replace", 0)
self.num_moe_layers = self.model.config.num_hidden_layers - self.num_dense_layers
for i in range(self.num_dense_layers, self.model.config.num_hidden_layers):
self.param_dict["model.layers." + str(i) + ".mlp.experts." + "w13_weight_list"] = self.model.model.layers[
i
].mlp.experts.w13_weight_list
self.param_dict["model.layers." + str(i) + ".mlp.experts." + "w2_weight_list"] = self.model.model.layers[
i
].mlp.experts.w2_weight_list
self.param_dict["model.layers." + str(i) + ".mlp.experts." + "w13_weight_scale_fp32_list"] = (
self.model.model.layers[i].mlp.experts.w13_weight_scale_fp32_list
)
self.param_dict["model.layers." + str(i) + ".mlp.experts." + "w2_weight_scale_list"] = (
self.model.model.layers[i].mlp.experts.w2_weight_scale_list
)
# TODO: init self.expert_weight_names depending on different model types.
# Only deepseek v3 w8a8 and qwen3-moe is supported here
if self.model.quant_config is not None:
self.expert_weight_names = [
"w13_weight_list",
"w2_weight_list",
"w13_weight_scale_fp32_list",
"w13_weight_offset",
"w2_weight_scale_list",
"w2_weight_offset",
]
else:
self.expert_weight_names = ["w13_weight", "w2_weight"]
self.expert_map_per_layer_cpu = dict() # copy of expert map on CPU to avoid device synchronize frequently
num_buffer_tensor = self.model.model.layers[-1].mlp.experts.local_num_experts
self.buffer_tensor_list: list[list[Any]] = [[] for _ in range(num_buffer_tensor)]
self.init_buffer_tensor(num_buffer_tensor)
self.num_local_experts = self.model.model.layers[-1].mlp.experts.local_num_experts
self.expert_param_per_layer = dict()
self.init_expert_param_per_layer()
num_buffer_tensor = self.num_local_experts
self.buffer_tensor_list: list[list[Any]] = [[] for _ in range(num_buffer_tensor)]
self.init_buffer_tensor(num_buffer_tensor)
self.log2phy_map_per_layer = dict()
for layer_idx in range(self.num_moe_layers):
self.log2phy_map_per_layer[self.num_dense_layers + layer_idx] = self.model.get_log2phy_map(
@@ -81,38 +54,34 @@ class VllmEplbAdaptor(EplbAdaptor):
for buffer_id in range(num_buffer_tensor):
for name in self.expert_weight_names:
complete_name = "model.layers." + str(self.num_dense_layers) + ".mlp.experts." + name
if name in ["w13_weight_list", "w2_weight_list", "w13_weight_scale_fp32_list", "w2_weight_scale_list"]:
expert_tensor = self.param_dict[complete_name][0]
expert_tensor = expert_tensor.clone()
else:
expert_tensor = self.param_dict[complete_name][0].data[0]
expert_tensor = self.param_dict[complete_name][0]
buffer_tensor = torch.empty_like(expert_tensor)
self.buffer_tensor_list[buffer_id].append(buffer_tensor)
def init_expert_param_per_layer(self):
key = f"model.layers.{self.num_dense_layers}.mlp.experts.{self.expert_weight_names[0]}"
num_local_expert = len(self.param_dict[key])
for moe_layer_id in range(self.num_moe_layers):
layer_idx = self.num_dense_layers + moe_layer_id
self.param_dict = dict()
if self.model.quant_config is not None:
self.expert_weight_names = [
"w13_weight_list",
"w2_weight_list",
"w13_weight_scale_fp32_list",
"w2_weight_scale_list",
]
else:
self.expert_weight_names = ["w13_weight", "w2_weight"]
for layer_idx in range(self.num_dense_layers, self.model.config.num_hidden_layers):
self.expert_param_per_layer[layer_idx] = list()
for local_expert_id in range(num_local_expert):
for name in self.expert_weight_names:
param_key = f"model.layers.{layer_idx}.mlp.experts.{name}"
param_value = getattr(self.model.model.layers[layer_idx].mlp.experts, name)
self.param_dict[param_key] = param_value
for local_expert_id in range(self.num_local_experts):
per_expert_param = list()
for name in self.expert_weight_names:
if name in [
"w13_weight_list",
"w2_weight_list",
"w13_weight_scale_fp32_list",
"w2_weight_scale_list",
]:
per_expert_param.append(
self.param_dict["model.layers." + str(layer_idx) + ".mlp.experts." + name][local_expert_id]
)
else:
per_expert_param.append(
self.param_dict["model.layers." + str(layer_idx) + ".mlp.experts." + name][0].data[
local_expert_id
]
)
per_expert_param.append(
self.param_dict["model.layers." + str(layer_idx) + ".mlp.experts." + name][local_expert_id]
)
self.expert_param_per_layer[layer_idx].append(per_expert_param)
def get_rank_expert_workload(self) -> torch.Tensor: