[ROCm] Enable MTP (NextN) on AMD GPU (#4631)
This commit is contained in:
@@ -17,7 +17,11 @@
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#include "pytorch_extension_utils.h"
|
||||
#else
|
||||
#include "pytorch_extension_utils_rocm.h"
|
||||
#endif
|
||||
|
||||
// parent_list [bs, topk * (depth - 1) + 1)]
|
||||
// selected_index [bs, draft_token_num - 1]
|
||||
|
||||
Reference in New Issue
Block a user