fix: use torch.sum for compatible (#2161)
This commit is contained in:
@@ -766,9 +766,10 @@ def fused_experts_impl(
|
|||||||
use_int8_w8a16=use_int8_w8a16,
|
use_int8_w8a16=use_int8_w8a16,
|
||||||
)
|
)
|
||||||
|
|
||||||
ops.moe_sum(
|
torch.sum(
|
||||||
intermediate_cache3.view(*intermediate_cache3.shape),
|
intermediate_cache3.view(*intermediate_cache3.shape),
|
||||||
out_hidden_states[begin_chunk_idx:end_chunk_idx],
|
dim=1,
|
||||||
|
out=out_hidden_states[begin_chunk_idx:end_chunk_idx],
|
||||||
)
|
)
|
||||||
return out_hidden_states
|
return out_hidden_states
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user