[Perf][1/N] w8a8c8 support in dsv3.2/glm5 (#7029)
### What this PR does / why we need it?
This PR supports W8A8C8 in dsv3.2/glm5 with lightning_indexer_quant ops
in pd-mix stage mainly.
Because the code for the current PD-disaggregated scenario is still
under refactoring and cleanup, this PR prioritizes ensuring the C8
functionality in the pd-mix scenario.
The next steps are planned in three parts:
① Once the optimized scatter operator is updated, we will replace the
original operator to improve the performance of storing k_scale.
② Once the code logic for the PD-disaggregated scenario becomes stable,
we will carry out more comprehensive validation and make appropriate
adaptations.
③ Because enabling C8 currently introduces several new operators whose
performance still needs improvement, performance may regress in some
scenarios. Therefore, only after all the operators are fully ready can
we ensure that this feature does not cause any performance degradation.
At that point, we will enable this feature by default and remove the
switch in `additional_config`.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
CI passed with new added/existing test.
- vLLM version: v0.16.0
- vLLM main:
4034c3d32e
---------
Signed-off-by: rjg-lyh <1318825571@qq.com>
This commit is contained in:
@@ -529,6 +529,44 @@ std::vector<at::Tensor> moe_grouped_matmul_meta(
|
||||
return y;
|
||||
}
|
||||
|
||||
at::Tensor npu_lightning_indexer_quant_meta(
|
||||
const at::Tensor &query, const at::Tensor &key, const at::Tensor &weights,
|
||||
const at::Tensor &query_dequant_scale, const at::Tensor &key_dequant_scale,
|
||||
const c10::optional<at::Tensor> &actual_seq_lengths_query,
|
||||
const c10::optional<at::Tensor> &actual_seq_lengths_key,
|
||||
const c10::optional<at::Tensor> &block_table, int64_t query_quant_mode, int64_t key_quant_mode,
|
||||
c10::string_view layout_query, c10::string_view layout_key, int64_t sparse_count, int64_t sparse_mode)
|
||||
{
|
||||
std::string query_layout_str = std::string(layout_query);
|
||||
std::string key_layout_str = std::string(layout_key);
|
||||
|
||||
const int SIZE = 8;
|
||||
const int DIM_0 = 0;
|
||||
const int DIM_1 = 1;
|
||||
const int DIM_2 = 2;
|
||||
const int DIM_3 = 3;
|
||||
|
||||
at::SmallVector<int64_t, SIZE> output_size;
|
||||
for (size_t i = 0; i < query.sizes().size(); i++) {
|
||||
TORCH_CHECK(query.size(i) > 0, "All values within query's shape should be greater "
|
||||
"than 0, but shape[", i, "] is ", query.size(i));
|
||||
}
|
||||
for (size_t i = 0; i < key.sizes().size(); i++) {
|
||||
TORCH_CHECK(key.size(i) > 0, "All values within key's shape should be greater "
|
||||
"than 0, but shape[", i, "] is ", key.size(i));
|
||||
}
|
||||
TORCH_CHECK(sparse_count > 0, "sparse count should be greater than 0, but now is ", sparse_count);
|
||||
int64_t keyHeadNum = (key_layout_str == "TND")? key.size(DIM_1) : key.size(DIM_2);
|
||||
if (query_layout_str == "BSND") {
|
||||
output_size = {query.size(DIM_0), query.size(DIM_1), keyHeadNum, sparse_count};
|
||||
} else {
|
||||
output_size = {query.size(DIM_0), keyHeadNum, sparse_count};
|
||||
}
|
||||
at::Tensor lightning_indexer_quant_output = at::empty(output_size, query.options().dtype(at::kInt));
|
||||
|
||||
return lightning_indexer_quant_output;
|
||||
}
|
||||
|
||||
} // namespace meta
|
||||
} // namespace vllm_ascend
|
||||
|
||||
@@ -576,5 +614,7 @@ TORCH_LIBRARY_IMPL_EXPAND(CONCAT(_C, _ascend), Meta, ops) {
|
||||
ops.impl("causal_conv1d_fn", &vllm_ascend::meta::causal_conv1d_fn_meta);
|
||||
// moe_grouped_matmul
|
||||
ops.impl("moe_grouped_matmul", &vllm_ascend::meta::moe_grouped_matmul_meta);
|
||||
// Lightning indexer quant
|
||||
ops.impl("npu_lightning_indexer_quant", &vllm_ascend::meta::npu_lightning_indexer_quant_meta);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user