fix: use torch.sum for compatibility (#2161)
This commit is contained in:
@@ -766,9 +766,10 @@ def fused_experts_impl(
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
)
|
||||
|
||||
ops.moe_sum(
|
||||
torch.sum(
|
||||
intermediate_cache3.view(*intermediate_cache3.shape),
|
||||
out_hidden_states[begin_chunk_idx:end_chunk_idx],
|
||||
dim=1,
|
||||
out=out_hidden_states[begin_chunk_idx:end_chunk_idx],
|
||||
)
|
||||
return out_hidden_states
|
||||
|
||||
|
||||
Reference in New Issue
Block a user