[ROCm] Enable MTP (NextN) on AMD GPU (#4631)

This commit is contained in:
Alex Sun
2025-03-24 13:58:05 +08:00
committed by GitHub
parent 93cf7fc5cd
commit af6535e7aa
7 changed files with 43 additions and 4 deletions

View File

@@ -17,7 +17,11 @@
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#ifndef USE_ROCM
#include "pytorch_extension_utils.h"
#else
#include "pytorch_extension_utils_rocm.h"
#endif
// parent_list [bs, topk * (depth - 1) + 1]
// selected_index [bs, draft_token_num - 1]