add deepseekv3 and llama4
This commit is contained in:
@@ -425,9 +425,9 @@ class SparseMoeMlp(nn.Module):
|
|||||||
scatter_idx=torch.zeros((indices.numel(),), dtype=seqs.dtype, device=seqs.device).scatter(0, indices, seqs)
|
scatter_idx=torch.zeros((indices.numel(),), dtype=seqs.dtype, device=seqs.device).scatter(0, indices, seqs)
|
||||||
|
|
||||||
# token_count: [self.num_experts_per_rank]
|
# token_count: [self.num_experts_per_rank]
|
||||||
partial_token_index, partial_token_count = sorted_expert_id.unique(sorted=True, return_counts=True)
|
# Use scatter_add_ instead of torch.unique for MLU graph capture compatibility
|
||||||
zero_token_count = torch.zeros(self.num_total_experts, dtype=partial_token_count.dtype, device=device)
|
token_count = torch.zeros(self.num_total_experts, dtype=sorted_expert_id.dtype, device=device)
|
||||||
token_count = zero_token_count.scatter(dim=0, index=partial_token_index, src=partial_token_count)
|
token_count.scatter_add_(0, sorted_expert_id, torch.ones_like(sorted_expert_id))
|
||||||
# cusum_token_count: [self.num_experts_per_rank + 1]
|
# cusum_token_count: [self.num_experts_per_rank + 1]
|
||||||
cusum_token_count = torch.cat(
|
cusum_token_count = torch.cat(
|
||||||
[torch.tensor([0], dtype=token_count.dtype, device=device),
|
[torch.tensor([0], dtype=token_count.dtype, device=device),
|
||||||
|
|||||||
Reference in New Issue
Block a user