Add Custom Kernels For LoRA Performance (#1884)

### What this PR does / why we need it?
Add two custom kernels (bgmv_shrink and bgmv_expand) to improve the
performance of LoRA.
### Does this PR introduce _any_ user-facing change?
no user-facing change
### How was this patch tested?
We add unit test files to test the custom AscendC kernels. See
vllm-ascend/tests/e2e/singlecard/ops/test_bgmv_shrink.py and
vllm-ascend/tests/e2e/singlecard/ops/test_bgmv_expand.py
Based on an actual test of the Qwen2.5 7B model using vllm-ascend
version v0.9.2.rc1, the TTFT, TPOT and throughput have improved by
about 70%.

- vLLM version: v0.9.2
- vLLM main:
40d86ee412

---------

Signed-off-by: taoxudonghaha <justsheldon@163.com>
This commit is contained in:
taoxudonghaha
2025-07-29 19:27:50 +08:00
committed by GitHub
parent 2da281ec5a
commit 540336edc9
8 changed files with 946 additions and 3 deletions

View File

@@ -199,6 +199,90 @@ std::tuple<at::Tensor, at::Tensor> get_masked_input_and_mask(
cmd.Run();
return {masked_input, mask};
}
// LoRA "shrink" step on NPU: for each token b, projects the input down to
// the LoRA rank, y[b] = scale * (x[b] @ weight[indices[b]]), where
// y.size(1) is the (small) LoRA rank.
//
// x:       [batch_size, hidden_in] input activations (half or bf16)
// weight:  [num_loras, hidden_out, hidden_in] or
//          [num_loras, 1, hidden_out, hidden_in] LoRA "A" matrices
// indices: [batch_size] per-token LoRA index selecting the weight slice
// y:       [batch_size, hidden_out] output; hidden_out is the LoRA rank
// scale:   scaling factor applied to the product
//
// The kernel launch is deferred via OpCommand::SetCustomHandler so it runs
// on the current NPU stream when the op command is executed.
void bgmv_shrink(at::Tensor &x, at::Tensor &weight, at::Tensor &indices, at::Tensor &y, double scale)
{
    at::ScalarType scalar_type = x.scalar_type();
    TORCH_CHECK(scalar_type == torch::kHalf || scalar_type == torch::kBFloat16, "only support half and bf16");
    TORCH_CHECK(x.dim() == 2, "x should be [batch_size, hidden_in]");
    TORCH_CHECK(weight.dim() == 3 || weight.dim() == 4,
                "weight should be [num_loras, hidden_out, hidden_in] or [num_loras, 1, hidden_out, hidden_in]");
    TORCH_CHECK(y.dim() == 2, "y should be [batch_size, hidden_out]");
    TORCH_CHECK(indices.dim() == 1, "indices should be [batch_size]");
    TORCH_CHECK(x.size(0) == y.size(0) && x.size(0) == indices.size(0),
                "the first dimension of x, y, indices should be same");
    // Shrink projects to a strictly smaller dimension (the LoRA rank).
    TORCH_CHECK(x.size(1) > y.size(1), "hidden in should be greater than hidden out");
    void *x_ptr = x.data_ptr();
    void *weight_ptr = weight.data_ptr();
    void *indices_ptr = indices.data_ptr();
    void *y_ptr = y.data_ptr();
    int batch_size = x.size(0);
    int input_hidden_token = x.size(1);
    uint32_t lora_rank = y.size(1);
    float scale_f = static_cast<float>(scale);
    aclrtStream stream = c10_npu::getCurrentNPUStream().stream();
    at_npu::native::OpCommand cmd;
    cmd.Name("bgmv_shrink");
    cmd.SetCustomHandler([scalar_type, stream, x_ptr, weight_ptr, indices_ptr, y_ptr, batch_size, input_hidden_token,
                          lora_rank, scale_f]() -> int {
        auto dtype = get_dtype_from_torch(scalar_type);
        int device_id = 0;
        int64_t aiv_num = 0;
        // Query the number of AI vector cores; keep the side-effecting call
        // out of the TORCH_CHECK condition for clarity.
        aclError ret = aclGetDeviceCapability(device_id, ACL_DEVICE_INFO_VECTOR_CORE_NUM, &aiv_num);
        TORCH_CHECK(ret == ACL_SUCCESS, "aclGetDeviceCapability failed");
        // Guard the division below; a zero core count would be a driver bug.
        TORCH_CHECK(aiv_num > 0, "vector core number should be greater than 0");
        // Ceil-divide the batch across the available vector cores.
        int num_tokens_per_core = (batch_size + aiv_num - 1) / aiv_num;
        // Bug fix: the original checked the string literal
        // "num_tokens_per_core != 0" (always true); check the value itself.
        TORCH_CHECK(num_tokens_per_core != 0, "num_tokens_per_core should not be 0");
        bgmv_shrink_impl(dtype, stream, x_ptr, weight_ptr, indices_ptr, y_ptr, batch_size, num_tokens_per_core,
                         input_hidden_token, lora_rank, scale_f);
        return 0;
    });
    cmd.Run();
    return;
}
// LoRA "expand" step on NPU: for each token b, projects the rank-sized
// intermediate back up and accumulates it into a slice of y:
// y[b, slice_offset : slice_offset + slice_size] += x[b] @ weight[indices[b]].
//
// x:            [batch_size, lora_rank] shrink-step output
// weight:       [num_loras, hidden_out, hidden_in] or
//               [num_loras, 1, hidden_out, hidden_in] LoRA "B" matrices
// indices:      [batch_size] per-token LoRA index
// y:            [batch_size, hidden_out] tensor updated in place
// slice_offset: start column of the output slice within y
// slice_size:   width of the output slice
//
// Returns y (aliasing the input tensor; the op schema declares Tensor! y).
at::Tensor bgmv_expand(at::Tensor &x, at::Tensor &weight, at::Tensor &indices, at::Tensor &y,
                       int64_t slice_offset, int64_t slice_size)
{
    // Dtype is taken from y: x may differ, but the output dtype governs.
    at::ScalarType scalar_type = y.scalar_type();
    TORCH_CHECK(scalar_type == torch::kHalf || scalar_type == torch::kBFloat16, "only support half and bf16");
    TORCH_CHECK(x.dim() == 2, "x should be [batch_size, hidden_in]");
    TORCH_CHECK(weight.dim() == 3 || weight.dim() == 4,
                "weight should be [num_loras, hidden_out, hidden_in] or [num_loras, 1, hidden_out, hidden_in]");
    TORCH_CHECK(y.dim() == 2, "y should be [batch_size, hidden_out]");
    TORCH_CHECK(indices.dim() == 1, "indices should be [batch_size]");
    TORCH_CHECK(x.size(0) == y.size(0) && x.size(0) == indices.size(0),
                "the first dimension of x, y, indices should be same");
    // Expand goes from the (small) LoRA rank up to the slice width.
    TORCH_CHECK(x.size(1) <= slice_size, "hidden in should be smaller than hidden out");
    TORCH_CHECK(slice_offset >= 0, "slice offset should be no smaller than 0");
    // Bug fix: the original statement was missing its terminating semicolon.
    TORCH_CHECK((slice_size + slice_offset) <= y.size(1),
                "slice_size + slice_offset should be smaller than the second dimension of y");
    at::Tensor y_out = y;
    void *x_ptr = x.data_ptr();
    void *weight_ptr = weight.data_ptr();
    void *indices_ptr = indices.data_ptr();
    void *y_ptr = y.data_ptr();
    void *y_out_ptr = y_out.data_ptr();
    int batch_size = x.size(0);
    int lora_rank = x.size(1);
    int output_full_dim = y.size(1);
    aclrtStream stream = c10_npu::getCurrentNPUStream().stream();
    at_npu::native::OpCommand cmd;
    cmd.Name("bgmv_expand");
    cmd.SetCustomHandler([scalar_type, stream, x_ptr, weight_ptr, indices_ptr, y_ptr, y_out_ptr, batch_size, lora_rank,
                          slice_offset, slice_size, output_full_dim]() -> int {
        auto dtype = get_dtype_from_torch(scalar_type);
        int device_id = 0;
        int64_t aiv_num = 0;
        // Query the number of AI vector cores; keep the side-effecting call
        // out of the TORCH_CHECK condition for clarity.
        aclError ret = aclGetDeviceCapability(device_id, ACL_DEVICE_INFO_VECTOR_CORE_NUM, &aiv_num);
        TORCH_CHECK(ret == ACL_SUCCESS, "aclGetDeviceCapability failed");
        // Guard the division below; a zero core count would be a driver bug.
        TORCH_CHECK(aiv_num > 0, "vector core number should be greater than 0");
        // Ceil-divide the batch across the available vector cores.
        int num_tokens_per_core = (batch_size + aiv_num - 1) / aiv_num;
        // Bug fix: the original checked the string literal
        // "num_tokens_per_core != 0" (always true); check the value itself.
        TORCH_CHECK(num_tokens_per_core != 0, "num_tokens_per_core should not be 0");
        bgmv_expand_impl(dtype, stream, x_ptr, weight_ptr, indices_ptr, y_ptr, y_out_ptr, batch_size,
                         num_tokens_per_core, lora_rank, slice_size, slice_offset, output_full_dim);
        return 0;
    });
    cmd.Run();
    return y_out;
}
} // namespace vllm_ascend
TORCH_LIBRARY_EXPAND(_C, ops)
@@ -223,6 +307,14 @@ TORCH_LIBRARY_EXPAND(_C, ops)
" int added_vocab_start_index, "
" int added_vocab_end_index) -> (Tensor masked_input, Tensor mask)");
ops.impl("get_masked_input_and_mask", torch::kPrivateUse1, &vllm_ascend::get_masked_input_and_mask);
ops.def("bgmv_shrink(Tensor! x, Tensor! weight, Tensor! indices, Tensor! y, float scale) -> ()");
ops.impl("bgmv_shrink", torch::kPrivateUse1, &vllm_ascend::bgmv_shrink);
ops.def(
"bgmv_expand(Tensor! x, Tensor! weight, Tensor! indices, Tensor! y,"
" int slice_offset, int slice_size) -> Tensor");
ops.impl("bgmv_expand", torch::kPrivateUse1, &vllm_ascend::bgmv_expand);
}
REGISTER_EXTENSION(_C)