Add DeepSeek-V3 and Llama 4

This commit is contained in:
Chranos
2026-02-11 15:58:34 +08:00
parent e0bd67be53
commit 1f77771852

View File

@@ -425,9 +425,9 @@ class SparseMoeMlp(nn.Module):
         scatter_idx = torch.zeros((indices.numel(),), dtype=seqs.dtype, device=seqs.device).scatter(0, indices, seqs)
         # token_count: [self.num_experts_per_rank]
-        partial_token_index, partial_token_count = sorted_expert_id.unique(sorted=True, return_counts=True)
-        zero_token_count = torch.zeros(self.num_total_experts, dtype=partial_token_count.dtype, device=device)
-        token_count = zero_token_count.scatter(dim=0, index=partial_token_index, src=partial_token_count)
+        # Use scatter_add_ instead of torch.unique for MLU graph capture compatibility
+        token_count = torch.zeros(self.num_total_experts, dtype=sorted_expert_id.dtype, device=device)
+        token_count.scatter_add_(0, sorted_expert_id, torch.ones_like(sorted_expert_id))
         # cusum_token_count: [self.num_experts_per_rank + 1]
         cusum_token_count = torch.cat(
             [torch.tensor([0], dtype=token_count.dtype, device=device),