minor: update header and use pytest (#3054)
This commit is contained in:
@@ -19,7 +19,7 @@ clean:
|
|||||||
@rm -rf build dist *.egg-info
|
@rm -rf build dist *.egg-info
|
||||||
|
|
||||||
test:
|
test:
|
||||||
@find tests -name "test_*.py" | xargs -n 1 python3
|
@find tests -name "test_*.py" | xargs -n 1 python3 && pytest tests/test_norm.py && pytest tests/test_activation.py
|
||||||
|
|
||||||
format:
|
format:
|
||||||
@find src tests -name '*.cc' -o -name '*.cu' -o -name '*.cuh' -o -name '*.h' -o -name '*.hpp' | xargs clang-format -i && find src tests -name '*.py' | xargs isort && find src tests -name '*.py' | xargs black
|
@find src tests -name '*.cc' -o -name '*.cu' -o -name '*.cuh' -o -name '*.h' -o -name '*.hpp' | xargs clang-format -i && find src tests -name '*.py' | xargs isort && find src tests -name '*.py' | xargs black
|
||||||
|
|||||||
@@ -16,7 +16,7 @@
|
|||||||
#include "cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h"
|
#include "cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h"
|
||||||
#include "cutlass_extensions/gemm/gemm_universal_base_compat.h"
|
#include "cutlass_extensions/gemm/gemm_universal_base_compat.h"
|
||||||
#include "cutlass_extensions/gemm/gemm_with_epilogue_visitor.h"
|
#include "cutlass_extensions/gemm/gemm_with_epilogue_visitor.h"
|
||||||
#include "utils.hpp"
|
#include "utils.h"
|
||||||
|
|
||||||
using namespace cute;
|
using namespace cute;
|
||||||
|
|
||||||
|
|||||||
@@ -6,7 +6,7 @@
|
|||||||
|
|
||||||
#include <THC/THCAtomics.cuh>
|
#include <THC/THCAtomics.cuh>
|
||||||
|
|
||||||
#include "utils.hpp"
|
#include "utils.h"
|
||||||
|
|
||||||
#ifdef USE_ROCM
|
#ifdef USE_ROCM
|
||||||
#include <hip/hip_runtime.h>
|
#include <hip/hip_runtime.h>
|
||||||
|
|||||||
@@ -4,7 +4,7 @@
|
|||||||
|
|
||||||
#include <THC/THCAtomics.cuh>
|
#include <THC/THCAtomics.cuh>
|
||||||
|
|
||||||
#include "utils.hpp"
|
#include "utils.h"
|
||||||
#include "vectorization.cuh"
|
#include "vectorization.cuh"
|
||||||
|
|
||||||
template <typename scalar_t>
|
template <typename scalar_t>
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "utils.hpp"
|
#include "utils.h"
|
||||||
|
|
||||||
// trt_reduce
|
// trt_reduce
|
||||||
using fptr_t = int64_t;
|
using fptr_t = int64_t;
|
||||||
|
|||||||
@@ -21,7 +21,7 @@
|
|||||||
#include <stdint.h>
|
#include <stdint.h>
|
||||||
#include <torch/all.h>
|
#include <torch/all.h>
|
||||||
|
|
||||||
#include "utils.hpp"
|
#include "utils.h"
|
||||||
|
|
||||||
namespace trt_llm {
|
namespace trt_llm {
|
||||||
constexpr size_t WARP_SIZE = 32;
|
constexpr size_t WARP_SIZE = 32;
|
||||||
|
|||||||
Reference in New Issue
Block a user