[fix]: fix precision issue in dispatch_ffn_combine_bf16 and remove redundant sync (#7198)
### What this PR does / why we need it?
Fix the precision issue in dispatch_ffn_combine_bf16 operator.
Remove redundant synchronization operations in dispatch_ffn_combine
operator.
- vLLM version: v0.16.0
- vLLM main:
4034c3d32e
---------
Signed-off-by: guanguan0308 <1546542263@qq.com>
This commit is contained in:
@@ -41,6 +41,7 @@ namespace {
|
||||
constexpr uint32_t EXPERTID_INDEX = 3;
|
||||
constexpr uint32_t BLOCK_NUM = 20;
|
||||
constexpr uint32_t SYSTEM_NEED_WORKSPACE = 16 * 1024 * 1024;
|
||||
constexpr uint64_t MB_SIZE = 1024 * 1024UL;
|
||||
}
|
||||
|
||||
namespace optiling {
|
||||
@@ -240,7 +241,8 @@ static ge::graphStatus DispatchFFNCombineBF16TilingFuncImpl(gert::TilingContext
|
||||
info.maxOutputSize * n2 * sizeof(int16_t) +
|
||||
info.maxOutputSize * info.K * sizeof(int16_t) +
|
||||
info.maxOutputSize * k2 * sizeof(int16_t) +
|
||||
info.worldSize * sizeof(int32_t) * 16;
|
||||
info.worldSize * sizeof(int32_t) * 16 +
|
||||
(info.expertPerRank + info.worldSize) * sizeof(int32_t) * 16;
|
||||
// std::max(info.maxOutputSize * info.N * sizeof(int16_t), info.maxOutputSize * n2 * sizeof(int16_t)) +
|
||||
// std::max(info.maxOutputSize * info.K * sizeof(int8_t), info.maxOutputSize * k2 * sizeof(int8_t));
|
||||
|
||||
|
||||
Reference in New Issue
Block a user