[moe] optim: reduce memory consumption in fused_moe (#3692)
This commit is contained in:
@@ -933,20 +933,21 @@ def fused_experts_impl(
|
||||
|
||||
config = get_config_func(M)
|
||||
|
||||
intermediate_cache1 = torch.empty(
|
||||
(M, topk_ids.shape[1], N),
|
||||
cache = torch.empty(
|
||||
M * topk_ids.shape[1] * max(N, w2.shape[1]),
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype,
|
||||
)
|
||||
intermediate_cache1 = cache[: M * topk_ids.shape[1] * N].view(
|
||||
(M, topk_ids.shape[1], N),
|
||||
)
|
||||
intermediate_cache2 = torch.empty(
|
||||
(M * topk_ids.shape[1], N // 2),
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype,
|
||||
)
|
||||
intermediate_cache3 = torch.empty(
|
||||
intermediate_cache3 = cache[: M * topk_ids.shape[1] * w2.shape[1]].view(
|
||||
(M, topk_ids.shape[1], w2.shape[1]),
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype,
|
||||
)
|
||||
|
||||
compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16
|
||||
|
||||
Reference in New Issue
Block a user