add llama4

This commit is contained in:
Chranos
2026-02-11 16:08:37 +08:00
parent 16d41a8fc1
commit 7b4f7d74c3

View File

@@ -284,34 +284,28 @@ class SparseMoeMlp(nn.Module):
# NOTE(review): this span is a unified-diff hunk rendered WITHOUT its +/- markers
# (header: @@ -284,34 +284,28 @@), so it interleaves the removed sparse
# gather/scatter implementation with the added dense implementation of the same
# method. It is NOT one runnable function as shown — annotate, don't execute.
# Computes the sparse-MoE MLP output: routes each token to its top-k experts
# and combines the expert outputs weighted by the softmaxed router logits.
def forward_experts_nofused(self, hidden_states, expert_logits):
# assumes hidden_states is 2-D (num_tokens, hidden_size) — the unpack below requires it.
hidden_states_shape = hidden_states.shape
# Dense approach: each expert processes ALL tokens, then mask by routing
# weights. This avoids data-dependent control flow (variable-size slicing,
# conditional branches, torch.unique, torch.tensor creation) that is
# incompatible with MLU graph capture.
num_tokens, hidden_size = hidden_states.shape
# topk_values: per-token routing weights; topk_indices: global expert ids,
# both shaped (num_tokens, k) — presumably, given the dim=-1 reduction below; verify against topk_softmax.
topk_values, topk_indices = self.topk_softmax(expert_logits)
# ---- BEGIN old (removed) sparse path: tokens are gathered per expert, ----
# ---- run through only their routed experts, then scattered back.      ----
expand_gather_idx, scatter_idx, expand_token_count, cusum_token_count = self.generate_gather_idx(
topk_indices)
# no expert is routed, then expand_gather_idx, expand_scatter_idx has no item,
# expand_token_count and expand_cusum_token_count has item but the value is all zero
# so this rank should only return final_hidden_states with zero value
if expand_gather_idx.numel() == 0:
# Early exit: no token on this rank routes to a local expert — return zeros.
final_hidden_states = torch.zeros_like(hidden_states,
dtype=hidden_states.dtype,
device=hidden_states.device)
return final_hidden_states
# Duplicate each token once per routed expert, grouped by expert id.
expand_hidden_states = self.expand_input(hidden_states, expand_gather_idx)
final_hidden_states = torch.zeros(
num_tokens, hidden_size,
dtype=hidden_states.dtype, device=hidden_states.device)
expand_output_list = []
# Rebase the cumulative token counts so index 0 is this rank's first local expert.
expand_cusum_token_count = cusum_token_count[self.start_expert_id:self.end_expert_id +
1] - cusum_token_count[self.start_expert_id]
for expert_idx, num_tokens_per_expert in enumerate(expand_token_count):
if num_tokens_per_expert > 0:
# Contiguous slice of tokens routed to this expert (cusum gives the bounds).
expert_hidden_states = expand_hidden_states[
expand_cusum_token_count[expert_idx]:expand_cusum_token_count[expert_idx + 1]]
expert_output = self.experts[expert_idx](expert_hidden_states)
# Some expert modules return (output, ...) tuples — unwrap to the tensor.
expert_output = expert_output[0] if isinstance(expert_output, (tuple, list)) else expert_output
expand_output_list.append(expert_output)
expand_output = torch.cat(expand_output_list, dim=0)
# Scatter expert outputs back to token order and apply routing weights.
final_hidden_states = self.combine_moe(expand_output, scatter_idx, cusum_token_count, hidden_states_shape,
topk_values)
# ---- END old sparse path. BEGIN new (added) dense path: every local ----
# ---- expert processes ALL tokens; routing is applied as a weight mask. ----
# NOTE(review): in the real post-commit method, final_hidden_states is
# presumably zero-initialized before this loop — confirm against the repo.
for expert_idx in range(self.num_experts_per_rank):
# Map local expert slot to its global expert id for comparison with topk_indices.
global_expert_idx = self.start_expert_id + expert_idx
expert_output = self.experts[expert_idx](hidden_states)
expert_output = expert_output[0] if isinstance(
expert_output, (tuple, list)) else expert_output
# Routing weight per token for this expert
# Tokens not routed to this expert get weight 0, so their contribution vanishes.
expert_mask = (topk_indices == global_expert_idx).to(topk_values.dtype)
expert_weights = (topk_values * expert_mask).sum(dim=-1, keepdim=True)
final_hidden_states = final_hidden_states + expert_output * expert_weights
return final_hidden_states