[Refactor] Add expert processed token count output for DispatchFFNCombine/DispatchFFNCombineBF16 (#6402)
### What this PR does / why we need it?
Add a new output for the expert token count.
An additional output tensor, `expert_token_nums`, is added to both operators
to track how tokens are distributed among the experts:
Tensor Name: expert_token_nums
Dimension: 1D tensor
Shape: (local_expert_num,)
Data Type: int32
Semantics: the number of tokens actually received by each
expert on the current card (device).
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.14.1
- vLLM main:
dc917cceb8
---------
Signed-off-by: guanguan0308 <1546542263@qq.com>
Signed-off-by: guanguan0308 <162653673+guanguan0308@users.noreply.github.com>
This commit is contained in:
@@ -725,7 +725,7 @@ void batch_matmul_transpose(const at::Tensor &tensor_a, const at::Tensor &tensor
|
||||
return;
|
||||
}
|
||||
|
||||
at::Tensor& dispatch_ffn_combine(
|
||||
std::tuple<at::Tensor&, at::Tensor&> dispatch_ffn_combine(
|
||||
const at::Tensor& x,
|
||||
const at::TensorList& weight1,
|
||||
const at::TensorList& weight2,
|
||||
@@ -735,7 +735,8 @@ at::Tensor& dispatch_ffn_combine(
|
||||
const at::Tensor& probs,
|
||||
c10::string_view group,
|
||||
int64_t max_output_size,
|
||||
at::Tensor& out
|
||||
at::Tensor& out,
|
||||
at::Tensor& expert_token_nums
|
||||
) {
|
||||
char *group_ep_ptr = const_cast<char *>(group.data());
|
||||
bool is_int8 = weight1[0].dtype() == at::kChar;
|
||||
@@ -750,7 +751,8 @@ at::Tensor& dispatch_ffn_combine(
|
||||
probs,
|
||||
group_ep_ptr,
|
||||
max_output_size,
|
||||
out);
|
||||
out,
|
||||
expert_token_nums);
|
||||
} else {
|
||||
EXEC_NPU_CMD(aclnnDispatchFFNCombineBF16,
|
||||
x,
|
||||
@@ -762,9 +764,10 @@ at::Tensor& dispatch_ffn_combine(
|
||||
probs,
|
||||
group_ep_ptr,
|
||||
max_output_size,
|
||||
out);
|
||||
out,
|
||||
expert_token_nums);
|
||||
}
|
||||
return out;
|
||||
return {out, expert_token_nums};
|
||||
}
|
||||
|
||||
at::Tensor npu_lightning_indexer(
|
||||
@@ -1452,7 +1455,7 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
|
||||
ops.def(
|
||||
"dispatch_ffn_combine(Tensor x, Tensor[] weight1, Tensor[] weight2, Tensor expert_idx,"
|
||||
" Tensor[] scale1, Tensor[] scale2, Tensor probs, str group,"
|
||||
" int max_output_size, Tensor! out) -> Tensor"
|
||||
" int max_output_size, Tensor! out, Tensor! expert_token_nums) -> (Tensor out, Tensor expert_token_nums)"
|
||||
);
|
||||
ops.impl("dispatch_ffn_combine", torch::kPrivateUse1, &vllm_ascend::dispatch_ffn_combine);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user