[moe] optim: reduce memory consumption in fused_moe (#3692)
This commit is contained in:
@@ -933,20 +933,21 @@ def fused_experts_impl(
|
|||||||
|
|
||||||
config = get_config_func(M)
|
config = get_config_func(M)
|
||||||
|
|
||||||
intermediate_cache1 = torch.empty(
|
cache = torch.empty(
|
||||||
(M, topk_ids.shape[1], N),
|
M * topk_ids.shape[1] * max(N, w2.shape[1]),
|
||||||
device=hidden_states.device,
|
device=hidden_states.device,
|
||||||
dtype=hidden_states.dtype,
|
dtype=hidden_states.dtype,
|
||||||
)
|
)
|
||||||
|
intermediate_cache1 = cache[: M * topk_ids.shape[1] * N].view(
|
||||||
|
(M, topk_ids.shape[1], N),
|
||||||
|
)
|
||||||
intermediate_cache2 = torch.empty(
|
intermediate_cache2 = torch.empty(
|
||||||
(M * topk_ids.shape[1], N // 2),
|
(M * topk_ids.shape[1], N // 2),
|
||||||
device=hidden_states.device,
|
device=hidden_states.device,
|
||||||
dtype=hidden_states.dtype,
|
dtype=hidden_states.dtype,
|
||||||
)
|
)
|
||||||
intermediate_cache3 = torch.empty(
|
intermediate_cache3 = cache[: M * topk_ids.shape[1] * w2.shape[1]].view(
|
||||||
(M, topk_ids.shape[1], w2.shape[1]),
|
(M, topk_ids.shape[1], w2.shape[1]),
|
||||||
device=hidden_states.device,
|
|
||||||
dtype=hidden_states.dtype,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16
|
compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16
|
||||||
|
|||||||
Reference in New Issue
Block a user