Support Blackwell Block Scale FP8 Gemm (#4278)

This commit is contained in:
Elfie Guo
2025-03-12 14:17:11 -07:00
committed by GitHub
parent 10b544ae9b
commit 7c86671131
3 changed files with 207 additions and 2 deletions

View File

@@ -141,3 +141,23 @@ __device__ __forceinline__ float blockReduceMax(float max_value) {
return max_value;
}
#endif
// Pads `tensor` with zero rows so its row count is a multiple of `alignment`.
//
// Args:
//   tensor: a 2-D tensor; padding is appended along dim 0 (rows).
//   alignment: row-count multiple to pad up to (must be > 0; default 4).
//   is_column_major: if true, the returned tensor is converted to
//     column-major layout (strides of a transposed-contiguous tensor).
//
// Returns the input tensor unchanged when it is already aligned and no
// layout conversion is requested; otherwise a new tensor.
inline torch::Tensor pad_tensor(const torch::Tensor& tensor, int64_t alignment = 4, bool is_column_major = false) {
  TORCH_CHECK(alignment > 0, "pad_tensor: alignment must be positive, got ", alignment);
  TORCH_CHECK(tensor.dim() == 2, "pad_tensor: expected a 2-D tensor, got ", tensor.dim(), "-D");
  const int64_t rows = tensor.size(0);
  const int64_t cols = tensor.size(1);
  // Zero rows needed to reach the next multiple of `alignment` (0 if aligned).
  const int64_t pad_rows = (alignment - (rows % alignment)) % alignment;
  torch::Tensor padded = tensor;
  if (pad_rows > 0) {
    torch::Tensor padding = torch::zeros({pad_rows, cols}, tensor.options());
    padded = torch::cat({tensor, padding}, 0);  // Append zero rows.
  }
  // Fix: previously the early return on pad_rows == 0 skipped this
  // conversion, so an already-aligned tensor was returned row-major even
  // when the caller requested column-major layout.
  if (is_column_major) {
    return padded.t().contiguous().t();
  }
  return padded;
}