Add DeepSeek-V3 and Llama 4 support
This commit is contained in:
@@ -425,9 +425,9 @@ class SparseMoeMlp(nn.Module):
|
||||
scatter_idx=torch.zeros((indices.numel(),), dtype=seqs.dtype, device=seqs.device).scatter(0, indices, seqs)
|
||||
|
||||
# token_count: [self.num_experts_per_rank]
|
||||
partial_token_index, partial_token_count = sorted_expert_id.unique(sorted=True, return_counts=True)
|
||||
zero_token_count = torch.zeros(self.num_total_experts, dtype=partial_token_count.dtype, device=device)
|
||||
token_count = zero_token_count.scatter(dim=0, index=partial_token_index, src=partial_token_count)
|
||||
# Use scatter_add_ instead of torch.unique for MLU graph capture compatibility
|
||||
token_count = torch.zeros(self.num_total_experts, dtype=sorted_expert_id.dtype, device=device)
|
||||
token_count.scatter_add_(0, sorted_expert_id, torch.ones_like(sorted_expert_id))
|
||||
# cusum_token_count: [self.num_experts_per_rank + 1]
|
||||
cusum_token_count = torch.cat(
|
||||
[torch.tensor([0], dtype=token_count.dtype, device=device),
|
||||
|
||||
Reference in New Issue
Block a user