Support Blackwell Block Scale FP8 Gemm (#4278)
@@ -141,3 +141,23 @@ __device__ __forceinline__ float blockReduceMax(float max_value) {
   return max_value;
 }
 #endif
+
+// Pads to a multiple of `alignment` rows.
+inline torch::Tensor pad_tensor(const torch::Tensor& tensor, int64_t alignment = 4, bool is_column_major = false) {
+  int64_t rows = tensor.size(0);
+  int64_t cols = tensor.size(1);
+  int64_t pad_rows = (alignment - (rows % alignment)) % alignment; // Compute padding size
+
+  if (pad_rows == 0) {
+    return tensor; // Already aligned
+  }
+
+  torch::Tensor padding = torch::zeros({pad_rows, cols}, tensor.options());
+  torch::Tensor tensor_padded = torch::cat({tensor, padding}, 0); // Pad along rows
+
+  // Ensure column-major layout
+  if (is_column_major) {
+    return tensor_padded.t().contiguous().t();
+  }
+  return tensor_padded;
+}
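For context, here is a minimal usage sketch of the new helper. It is not part of the commit; it assumes `pad_tensor` from the diff above is in scope, and the shape, dtype, and the idea that the padded tensor is a per-block scale input to the Blackwell block-scale FP8 GEMM are illustrative assumptions only.

// Hypothetical usage sketch, not part of this commit.
#include <torch/torch.h>

void pad_tensor_example() {
  // An example scale tensor with 133 rows, which is not a multiple of 4.
  torch::Tensor scales = torch::rand({133, 16});

  // Pad with zero rows up to 136; when the row count is already a multiple
  // of `alignment`, pad_tensor returns the input tensor unchanged.
  torch::Tensor padded = pad_tensor(scales, /*alignment=*/4, /*is_column_major=*/true);

  // The t().contiguous().t() step inside pad_tensor keeps the logical shape
  // {136, 16} but stores the data column-major, for kernels that expect a
  // column-major scale layout.
}

Returning the original tensor on the already-aligned path avoids an extra allocation and copy in the common case where no padding is needed.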