Update comment doc (#4731)
### What this PR does / why we need it?
Translate remaining Chinese comments in the `dispatch_ffn_combine` code
to English and update the installation guide to remind users to
initialize submodules when building from source.
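(In practice this means cloning with `git clone --recurse-submodules`, or running `git submodule update --init --recursive` in an existing checkout.)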
- vLLM version: v0.12.0
- vLLM main: ad32e3e19c
---------
Signed-off-by: mojave2 <chenchen145@huawei.com>
Signed-off-by: Chen Chen <0109chenchen@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
@@ -22,19 +22,22 @@ extern "C" {
 #endif

 /**
- * 算子功能:实现分布式MoE从InitRouting到Unpermute全部算子的融合
- * @brief aclnnDispatchFFNCombine的第一段接口,根据具体的计算流程,计算workspace大小。
+ * Operator function: fuse all distributed MoE ops from InitRouting through Unpermute.
+ * @brief First-stage interface of aclnnDispatchFFNCombine that calculates workspace size based on the specific compute flow.
  * @domain aclnn_ops_infer
- * @param [in] a: matmul左矩阵,数据类型支持:float16, bf16。
- * @param [in] b: matmul右矩阵,数据类型支持:float16, bf16。
- * @param [in] bias: 偏置,数据类型支持:float16, bf16。
- * @param [in] group: 标识通信域名称的字符串。
- * @param [in] worldsize: 通信域size,支持2/4/8卡。
- * @param [in] epRankId: ep本卡Id。取值范围[0, worldSize),各卡的rankId不能重复
- * @param [out] c: 计算+通信的结果,数据类型:同输入。
- * @param [out] workspaceSize: 返回需要在npu device侧申请的workspace大小。
- * @param [out] executor: 返回op执行器,包含了算子计算流程。
- * @return aclnnStatus: 返回状态码
+ * @param [in] x: The input tensor.
+ * @param [in] weight1: The first weight tensor.
+ * @param [in] weight2: The second weight tensor.
+ * @param [in] expertId: The expert ID tensor.
+ * @param [in] scale1: The first scale tensor.
+ * @param [in] scale2: The second scale tensor.
+ * @param [in] probs: The probabilities tensor.
+ * @param [in] group: string identifying the communication domain name.
+ * @param [in] maxOutputSize: The maximum output size.
+ * @param [out] out: result of computation + communication; same dtype as input.
+ * @param [out] workspaceSize: workspace size to allocate on the NPU device side.
+ * @param [out] executor: op executor containing the operator compute flow.
+ * @return aclnnStatus: status code.
  */
 __attribute__((visibility("default"))) aclnnStatus aclnnDispatchFFNCombineGetWorkspaceSize(const aclTensor* x, const aclTensor* weight1, const aclTensor* weight2,
     const aclTensor* expertId, const aclTensor* scale1, const aclTensor* scale2,
@@ -44,12 +47,12 @@ __attribute__((visibility("default"))) aclnnStatus aclnnDispatchFFNCombineGetWor
 uint64_t* workspaceSize, aclOpExecutor** executor);

 /**
- * @brief aclnnDispatchGmmCombine的第二段接口,用于执行计算。
- * @param [in] workspace: 在npu device侧申请的workspace内存起址。
- * @param [in] workspace_size: 在npu device侧申请的workspace大小,由第一段接口aclnnDispatchFFNCombineGetWorkspaceSize获取。
- * @param [in] exector: op执行器,包含了算子计算流程。
- * @param [in] stream: acl stream流。
- * @return aclnnStatus: 返回状态码
+ * @brief Second-stage interface of aclnnDispatchFFNCombine to execute computation.
+ * @param [in] workspace: workspace memory address allocated on the NPU device side.
+ * @param [in] workspace_size: workspace size allocated on the NPU device side, obtained from aclnnDispatchFFNCombineGetWorkspaceSize.
+ * @param [in] executor: op executor containing the operator compute flow.
+ * @param [in] stream: acl stream.
+ * @return aclnnStatus: status code.
  */
 __attribute__((visibility("default"))) aclnnStatus aclnnDispatchFFNCombine(void* workspace, uint64_t workspaceSize, aclOpExecutor* executor,
     aclrtStream stream);
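The comments above describe the standard two-stage aclnn calling convention. Below is a hedged usage sketch; the full signature of `aclnnDispatchFFNCombineGetWorkspaceSize` is split across hunks here, so the parameter order, the `const char* group` / `int64_t maxOutputSize` types, and the header name are assumptions inferred from the doc comment:

```cpp
#include <cstdint>
#include "aclnn_dispatch_ffn_combine.h"  // header name assumed for this diff's file

// Hedged usage sketch of the two-stage pattern documented above; parameter
// order and some types are inferred from the doc comment, not the full header.
aclnnStatus RunDispatchFFNCombine(const aclTensor* x, const aclTensor* weight1,
                                  const aclTensor* weight2, const aclTensor* expertId,
                                  const aclTensor* scale1, const aclTensor* scale2,
                                  const aclTensor* probs, const char* group,
                                  int64_t maxOutputSize, aclTensor* out,
                                  aclrtStream stream)
{
    uint64_t workspaceSize = 0;
    aclOpExecutor* executor = nullptr;
    // Stage 1: query the workspace size and build the executor.
    aclnnStatus ret = aclnnDispatchFFNCombineGetWorkspaceSize(
        x, weight1, weight2, expertId, scale1, scale2, probs,
        group, maxOutputSize, out, &workspaceSize, &executor);
    if (ret != ACL_SUCCESS) {
        return ret;
    }
    // Stage 2: allocate the device-side workspace (if any) and launch on the stream.
    void* workspace = nullptr;
    if (workspaceSize > 0) {
        aclrtMalloc(&workspace, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST);
    }
    ret = aclnnDispatchFFNCombine(workspace, workspaceSize, executor, stream);
    aclrtSynchronizeStream(stream);  // wait for compute + communication to finish
    if (workspace != nullptr) {
        aclrtFree(workspace);
    }
    return ret;
}
```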
@@ -58,4 +61,4 @@ __attribute__((visibility("default"))) aclnnStatus aclnnDispatchFFNCombine(void*
 }
 #endif

-#endif // OP_API_INC_GMM_ALLTOALLV_
+#endif // OP_API_INC_DISPATCH_FFN_COMBINE_
@@ -56,7 +56,7 @@ class DispatchFFNCombine : public OpDef {
         .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
         .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});

-        // 输出
+        // Output
         this->Output("out")
             .ParamType(REQUIRED)
             .DataType({ge::DT_FLOAT16, ge::DT_BF16, ge::DT_BF16})
@@ -27,7 +27,7 @@ using namespace AscendC;
 using namespace ge;

 namespace {
-// 1. 常量定义
+// 1. Constant definitions
 const char *K_INNER_DEBUG = "DispatchFFNCombine Tiling Debug";
 constexpr uint32_t ATTR_GROUP_INDEX = 0;
 constexpr uint32_t ATTR_MAX_OUTPUT_SIZE_INDEX = 1;
@@ -54,13 +54,13 @@ static int32_t CeilDev(int32_t num, int32_t div)
     return (num + div - 1) / div;
 }

-// 解析并校验 rankId, group, worldSize, isTransB 属性值
+// Parse and validate rankId, group, worldSize, and isTransB attributes
 static ge::graphStatus DispatchFFNCombineCheckAttrAndSetTiling(gert::TilingContext *context, DispatchFFNCombineInfo& info)
 {
     auto attrs = context->GetAttrs();
     OP_TILING_CHECK(attrs == nullptr, OP_LOGE(K_INNER_DEBUG, "attrs is null."), return ge::GRAPH_FAILED);

-    // todo:Attr相关tilingdata的设置、校验、打印
+    // TODO: set, validate, and print tiling data related to attributes
     auto groupPtr = attrs->GetAttrPointer<char>(static_cast<int>(ATTR_GROUP_INDEX));
     auto maxOutputSizePtr = attrs->GetAttrPointer<int>(ATTR_MAX_OUTPUT_SIZE_INDEX);
     auto is_trans_b = attrs->GetAttrPointer<bool>(ATTR_IS_TRANS_B);
@@ -87,7 +87,7 @@ static ge::graphStatus DispatchFFNCombineCheckAttrAndSetTiling(gert::TilingConte
     return ge::GRAPH_SUCCESS;
 }

-// 提取输入张量 A 和 B 的形状,计算出 M、K、N 值
+// Extract shapes of input tensors A and B to compute M, K, N
 static ge::graphStatus DispatchFFNCombineCheckShapeAndSetTiling(gert::TilingContext *context, DispatchFFNCombineInfo &info)
 {
     const char *nodeName = context->GetNodeName();
@@ -116,7 +116,7 @@ static ge::graphStatus DispatchFFNCombineCheckShapeAndSetTiling(gert::TilingCont
     return ge::GRAPH_SUCCESS;
 }

-// 获取当前芯片平台的 AI Core 数目、UB 容量等硬件信息。
+// Get hardware info such as AI Core count and UB capacity for the current chip platform.
 static ge::graphStatus DispatchFFNCombineGetPlatformInfoAndSetTiling(gert::TilingContext *context, DispatchFFNCombineInfo& info)
 {
     auto ascendcPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo());
@@ -146,9 +146,9 @@ void SetTilingData(CoCTiling &cocTilingData, DispatchFFNCombineInfo &info)
     cocTilingData.lenPerLoop = cocTilingData.m0 * cocTilingData.n0 / 2;
 }

-// 主调度函数:
-// 获取 tilingData ➝ 检查 Attr ➝ 检查 Shape ➝ 获取平台信息
-// ➝ 调用 SetTilingData(根据rank数目) ➝ 设置 blockDim ➝ 设置 tilingKey ➝ 设置 workspace ➝ 配置通信参数
+// Main scheduling function:
+// Get tilingData ➝ check Attr ➝ check Shape ➝ get platform info
+// ➝ call SetTilingData (based on rank count) ➝ set blockDim ➝ set tilingKey ➝ set workspace ➝ configure communication parameters

 static ge::graphStatus DispatchFFNCombineTilingFuncImpl(gert::TilingContext *context)
 {
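The flow in the comment above maps onto the helper functions named elsewhere in this diff. A schematic sketch of that sequence (illustrative only; the real `DispatchFFNCombineTilingFuncImpl` body is elided here):

```cpp
// Schematic sketch of the scheduling flow described in the comment above.
// Helper names are taken from this diff; the blockDim/tilingKey/workspace
// steps are summarized in comments because those hunks are not shown.
static ge::graphStatus DispatchFFNCombineTilingFuncSketch(gert::TilingContext *context)
{
    DispatchFFNCombineInfo info{};
    if (DispatchFFNCombineCheckAttrAndSetTiling(context, info) != ge::GRAPH_SUCCESS ||
        DispatchFFNCombineCheckShapeAndSetTiling(context, info) != ge::GRAPH_SUCCESS ||
        DispatchFFNCombineGetPlatformInfoAndSetTiling(context, info) != ge::GRAPH_SUCCESS) {
        return ge::GRAPH_FAILED;
    }
    CoCTiling cocTilingData;
    SetTilingData(cocTilingData, info);  // tiling parameters chosen per rank count
    // ... then set blockDim, tilingKey, workspace size, and communication
    // parameters on the context, as the comment enumerates.
    return ge::GRAPH_SUCCESS;
}
```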
@@ -262,4 +262,4 @@ ge::graphStatus TilingParseForDispatchFFNCombine(gert::TilingParseContext *conte
 IMPL_OP_OPTILING(DispatchFFNCombine)
     .Tiling(DispatchFFNCombineTilingFunc)
     .TilingParse<DispatchFFNCombineCompileInfo>(TilingParseForDispatchFFNCombine);
-} // namespace optiling
+} // namespace optiling
@@ -64,8 +64,8 @@ class HcomTopoInfo {
     ~HcomTopoInfo() = default;
     std::unordered_map<std::string, TopoInfo> rank_info_;
     std::mutex mutex_;
-    std::unordered_map<std::string, void*> group_to_ordered_stream_; // 通信域保序流
-    std::unordered_map<int32_t, std::unordered_map<std::string, void*>> device_id_to_group_to_ordered_stream_; // 通信域保序流
+    std::unordered_map<std::string, void*> group_to_ordered_stream_; // Ordered stream for the communication domain
+    std::unordered_map<int32_t, std::unordered_map<std::string, void*>> device_id_to_group_to_ordered_stream_; // Ordered stream for the communication domain
 };

 }
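For illustration, the nested map above keys ordered streams first by device id and then by communication-group name. A minimal self-contained sketch of that lookup shape (illustrative only; the real `HcomTopoInfo` has more members, and its accessors are not shown in this diff):

```cpp
#include <cstdint>
#include <mutex>
#include <string>
#include <unordered_map>

// Minimal sketch mirroring the two maps declared above; member names follow
// the diff, but this struct and its Get() helper are hypothetical.
struct OrderedStreamRegistry {
    std::mutex mutex_;
    std::unordered_map<std::string, void*> group_to_ordered_stream_;
    std::unordered_map<int32_t, std::unordered_map<std::string, void*>>
        device_id_to_group_to_ordered_stream_;

    // Look up the ordered stream registered for (deviceId, group), or nullptr.
    void* Get(int32_t deviceId, const std::string& group) {
        std::lock_guard<std::mutex> lock(mutex_);
        auto devIt = device_id_to_group_to_ordered_stream_.find(deviceId);
        if (devIt == device_id_to_group_to_ordered_stream_.end()) {
            return nullptr;
        }
        auto it = devIt->second.find(group);
        return it == devIt->second.end() ? nullptr : it->second;
    }
};
```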