add dispatch_ffn_combine_bf16 (#5866)
### What this PR does / why we need it?
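Add a BF16 path to `dispatch_ffn_combine`: when `weight1[0]` is int8 (`at::kChar`) the op still calls `aclnnDispatchFFNCombine`; otherwise it now calls `aclnnDispatchFFNCombineBF16` with the same arguments.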
- vLLM version: v0.13.0
- vLLM main: bde38c11df
---------
Signed-off-by: guanguan0308 <1546542263@qq.com>
This commit is contained in:
@@ -738,7 +738,9 @@ at::Tensor& dispatch_ffn_combine(
|
||||
at::Tensor& out
|
||||
) {
|
||||
char *group_ep_ptr = const_cast<char *>(group.data());
|
||||
EXEC_NPU_CMD(aclnnDispatchFFNCombine,
|
||||
bool is_int8 = weight1[0].dtype() == at::kChar;
|
||||
if (is_int8) {
|
||||
EXEC_NPU_CMD(aclnnDispatchFFNCombine,
|
||||
x,
|
||||
weight1,
|
||||
weight2,
|
||||
@@ -749,6 +751,19 @@ at::Tensor& dispatch_ffn_combine(
|
||||
group_ep_ptr,
|
||||
max_output_size,
|
||||
out);
|
||||
} else {
|
||||
EXEC_NPU_CMD(aclnnDispatchFFNCombineBF16,
|
||||
x,
|
||||
weight1,
|
||||
weight2,
|
||||
expert_idx,
|
||||
scale1,
|
||||
scale2,
|
||||
probs,
|
||||
group_ep_ptr,
|
||||
max_output_size,
|
||||
out);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user