diff --git a/csrc/dispatch_ffn_combine/op_host/dispatch_ffn_combine_tiling.cpp b/csrc/dispatch_ffn_combine/op_host/dispatch_ffn_combine_tiling.cpp
index 90b3c45e..8b16f0b9 100644
--- a/csrc/dispatch_ffn_combine/op_host/dispatch_ffn_combine_tiling.cpp
+++ b/csrc/dispatch_ffn_combine/op_host/dispatch_ffn_combine_tiling.cpp
@@ -26,6 +26,8 @@ using namespace AscendC;
 using namespace ge;
 
 
+#define HCCL_BUFFSIZE "HCCL_BUFFSIZE"
+
 namespace {
     // 1. Constant definitions
     const char *K_INNER_DEBUG = "DispatchFFNCombine Tiling Debug";
@@ -42,6 +44,7 @@ namespace {
     constexpr uint32_t EXPERTID_INDEX = 3;
     constexpr uint32_t BLOCK_NUM = 20;
     constexpr uint32_t SYSTEM_NEED_WORKSPACE = 16 * 1024 * 1024;
+    constexpr uint64_t MB_SIZE = 1024 * 1024UL;
 }
 
 namespace optiling {
@@ -54,6 +57,39 @@ static int32_t CeilDev(int32_t num, int32_t div)
     return (num + div - 1) / div;
 }
 
+// Resolve the HCCL window size limit in bytes from the HCCL_BUFFSIZE environment variable.
+// The variable is read as a whole number of MB; 200 MB is used when it is unset, is not a
+// plain number, or does not fit in uint16_t.
+static uint64_t GetMaxWindowSize()
+{
+    uint16_t windowSizeMb = 200;
+    const char *hcclBuffSizeEnv = getenv(HCCL_BUFFSIZE);
+    if (hcclBuffSizeEnv == nullptr) {
+        OP_LOGD(K_INNER_DEBUG, "Env HCCL_BUFFSIZE not set");
+    } else {
+        try {
+            const std::string envStr(hcclBuffSizeEnv);
+            std::size_t parsedLen = 0;
+            const unsigned long val = std::stoul(envStr, &parsedLen);
+            if (parsedLen != envStr.size()) {
+                // std::stoul stops at the first non-digit, so "200abc" would otherwise pass as 200.
+                OP_LOGW(K_INNER_DEBUG, "HCCL_BUFFSIZE value %s is not a valid number, using default.",
+                    hcclBuffSizeEnv);
+            } else if (val > std::numeric_limits<uint16_t>::max()) {
+                OP_LOGW(K_INNER_DEBUG, "HCCL_BUFFSIZE value %lu is out of range, using default.", val);
+            } else {
+                windowSizeMb = static_cast<uint16_t>(val);
+            }
+        } catch (const std::exception &e) {
+            // std::stoul throws when no number can be parsed or it overflows; keep the default.
+            OP_LOGE(K_INNER_DEBUG, "Exception encountered when parsing env HCCL_BUFFSIZE: %s", e.what());
+        }
+    }
+    const uint64_t maxWindowSize = static_cast<uint64_t>(windowSizeMb) * MB_SIZE;
+    OP_LOGD(K_INNER_DEBUG, "Get maxWindowSize is %lu", maxWindowSize);
+    return maxWindowSize;
+}
+
 // Parse and validate rankId, group, worldSize, and isTransB attributes
 static ge::graphStatus DispatchFFNCombineCheckAttrAndSetTiling(gert::TilingContext *context, DispatchFFNCombineInfo& info)
 {
@@ -232,6 +268,17 @@ static ge::graphStatus DispatchFFNCombineTilingFuncImpl(gert::TilingContext *con
     tilingData->cocTiling.moeInitRoutingQuantV2TilingData.gatherOutComputeParamsOp = moeInitRoutingQuantV2TilingBase.quantTilingData.gatherOutComputeParamsOp;
     tilingData->cocTiling.initRoutingQuantTilingKey = initRoutingQuantTilingKey;
 
+    // Required window bytes: M * topK * K int8 elements, times 3, plus 10 MB. Tiling fails when
+    // this exceeds the limit derived from HCCL_BUFFSIZE.
+    const uint64_t maxWindowSize = GetMaxWindowSize();
+    const uint64_t actualSize =
+        static_cast<uint64_t>(info.M) * info.topK * info.K * sizeof(int8_t) * 3 + 10 * MB_SIZE;
+    OP_TILING_CHECK((actualSize > maxWindowSize),
+        OP_LOGE(nodeName, "HCCL_BUFFSIZE is too SMALL, m = %lu, k = %lu, topK = %lu"
+            " expected HCCL_BUFFSIZE is ((m * k * topK * sizeof(int8_t)) * 3 + 10MB)= %luMB, HCCL_BUFFSIZE=%luMB.",
+            static_cast<uint64_t>(info.M), static_cast<uint64_t>(info.K), static_cast<uint64_t>(info.topK),
+            (actualSize + MB_SIZE - 1) / MB_SIZE, maxWindowSize / MB_SIZE),
+        return ge::GRAPH_FAILED);
+
     // 4. workspace
     size_t *workSpaces = context->GetWorkspaceSizes(1);
     OP_TILING_CHECK(workSpaces == nullptr, OP_LOGE(nodeName, "workSpaces is nullptr."),