ggml: add ggml_can_fuse_subgraph (#16662)
* ggml: add ggml_can_fuse_subgraph * ggml-cuda: use ggml_can_fuse_subgraph for topk-moe * format * 1. remove inputs from signature as they are transient nodes 2. add check for views: view_src should be part of the subgraph * - combine check into one loop - check all view_src parents - other minor review comments * remove redundant if test * - rename and other minor review comments * add assert about count < 32
This commit is contained in:
@@ -6964,6 +6964,78 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) {
|
||||
GGML_LOG_INFO("========================================\n");
|
||||
}
|
||||
|
||||
static int ggml_node_list_find_tensor(const struct ggml_cgraph * cgraph,
|
||||
const int * idxs,
|
||||
int count,
|
||||
const struct ggml_tensor * tensor) {
|
||||
GGML_ASSERT(cgraph && idxs);
|
||||
for (int i = 0; i < count; ++i) {
|
||||
const int node_idx = idxs[i];
|
||||
|
||||
if (node_idx >= cgraph->n_nodes) {
|
||||
return -1;
|
||||
}
|
||||
if (cgraph->nodes[node_idx] == tensor) {
|
||||
return i;
|
||||
}
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
bool ggml_can_fuse_subgraph_ext(const struct ggml_cgraph * cgraph,
|
||||
const int * node_idxs,
|
||||
int count,
|
||||
const enum ggml_op * ops,
|
||||
const int * outputs,
|
||||
int num_outputs) {
|
||||
GGML_ASSERT(outputs && num_outputs > 0);
|
||||
|
||||
for (int i = 0; i < count; ++i) {
|
||||
if (node_idxs[i] >= cgraph->n_nodes) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const struct ggml_tensor * node = cgraph->nodes[node_idxs[i]];
|
||||
|
||||
if (node->op != ops[i]) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (ggml_node_list_find_tensor(cgraph, outputs, num_outputs, node) != -1) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (node->flags & GGML_TENSOR_FLAG_OUTPUT) {
|
||||
return false;
|
||||
}
|
||||
|
||||
int subgraph_uses = 0;
|
||||
for (int j = i + 1; j < count; ++j) {
|
||||
const struct ggml_tensor * other_node = cgraph->nodes[node_idxs[j]];
|
||||
for (int src_idx = 0; src_idx < GGML_MAX_SRC; src_idx++) {
|
||||
if (other_node->src[src_idx] == node) {
|
||||
subgraph_uses++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (subgraph_uses != ggml_node_get_use_count(cgraph, node_idxs[i])) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// if node is a view, check if the view_src and all it's parent view_srcs are within the subgraph
|
||||
struct ggml_tensor * view_src = node->view_src;
|
||||
while (view_src) {
|
||||
if (ggml_node_list_find_tensor(cgraph, node_idxs, count, view_src) == -1) {
|
||||
return false;
|
||||
}
|
||||
view_src = view_src->view_src;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// check if node is part of the graph
|
||||
static bool ggml_graph_find(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) {
|
||||
if (cgraph == NULL) {
|
||||
|
||||
Reference in New Issue
Block a user