[Bugfix][LoRA][Operator] Fix LoRA custom operators accuracy issue (#2672)
### What this PR does / why we need it?
Fix the LoRA accuracy issue introduced by the custom AscendC operators
"bgmv_shrink, sgmv_shrink, bgmv_expand, sgmv_expand".
The bug details are:
- In a kernel function, if you want to call the GlobalTensor.GetSize
method, you must first pass the second parameter, bufferSize, when you
call GlobalTensor.SetGlobalBuffer.
- Otherwise, the GlobalTensor.GetSize method returns a random value.
- You can refer to [this
doc](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/81RC1alpha002/apiref/ascendcopapi/atlasascendc_api_07_00024.html).
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
pytest -sv tests/e2e/singlecard/test_ilama_lora.py
pytest -sv tests/e2e/multicard/test_ilama_lora_tp2.py
- vLLM version: v0.10.1.1
- vLLM main:
a344a5aa0a
---------
Signed-off-by: paulyu12 <paulyu0307@gmail.com>
Signed-off-by: paulyu12 <507435917@qq.com>
Co-authored-by: paulyu12 <paulyu0307@gmail.com>
This commit is contained in:
@@ -226,6 +226,7 @@ void bgmv_shrink(at::Tensor &x, at::Tensor &weight, at::Tensor &indices, at::Ten
|
||||
void* x_ptr = x.data_ptr();
|
||||
void* weight_ptr = weight.data_ptr();
|
||||
void* indices_ptr = indices.data_ptr();
|
||||
int indices_size = indices.size(0);
|
||||
void* y_ptr = y.data_ptr();
|
||||
int batch_size = x.size(0);
|
||||
int input_hidden_token = x.size(1);
|
||||
@@ -234,7 +235,7 @@ void bgmv_shrink(at::Tensor &x, at::Tensor &weight, at::Tensor &indices, at::Ten
|
||||
aclrtStream stream = c10_npu::getCurrentNPUStream().stream();
|
||||
at_npu::native::OpCommand cmd;
|
||||
cmd.Name("bgmv_shrink");
|
||||
cmd.SetCustomHandler([scalar_type, stream, x_ptr, weight_ptr, indices_ptr, y_ptr, batch_size, input_hidden_token,
|
||||
cmd.SetCustomHandler([scalar_type, stream, x_ptr, weight_ptr, indices_ptr, indices_size, y_ptr, batch_size, input_hidden_token,
|
||||
lora_rank, scale_f]() -> int {
|
||||
auto dtype = get_dtype_from_torch(scalar_type);
|
||||
int device_id = 0;
|
||||
@@ -242,7 +243,7 @@ void bgmv_shrink(at::Tensor &x, at::Tensor &weight, at::Tensor &indices, at::Ten
|
||||
TORCH_CHECK(aclGetDeviceCapability(device_id, ACL_DEVICE_INFO_VECTOR_CORE_NUM, &aiv_num) == ACL_SUCCESS);
|
||||
int num_tokens_per_core = (batch_size + aiv_num - 1) / aiv_num;
|
||||
TORCH_CHECK("num_tokens_per_core != 0", "num_tokens_per_core should not be 0");
|
||||
bgmv_shrink_impl(dtype, stream, x_ptr, weight_ptr, indices_ptr, y_ptr, batch_size, num_tokens_per_core,
|
||||
bgmv_shrink_impl(dtype, stream, x_ptr, weight_ptr, indices_ptr, indices_size, y_ptr, batch_size, num_tokens_per_core,
|
||||
input_hidden_token, lora_rank, scale_f);
|
||||
return 0;
|
||||
});
|
||||
@@ -271,6 +272,7 @@ at::Tensor bgmv_expand(at::Tensor &x, at::Tensor &weight, at::Tensor &indices, a
|
||||
void* x_ptr = x.data_ptr();
|
||||
void* weight_ptr = weight.data_ptr();
|
||||
void* indices_ptr = indices.data_ptr();
|
||||
int indices_size = indices.size(0);
|
||||
void* y_ptr = y.data_ptr();
|
||||
void* y_out_ptr = y_out.data_ptr();
|
||||
int batch_size = x.size(0);
|
||||
@@ -279,7 +281,7 @@ at::Tensor bgmv_expand(at::Tensor &x, at::Tensor &weight, at::Tensor &indices, a
|
||||
aclrtStream stream = c10_npu::getCurrentNPUStream().stream();
|
||||
at_npu::native::OpCommand cmd;
|
||||
cmd.Name("bgmv_expand");
|
||||
cmd.SetCustomHandler([scalar_type, stream, x_ptr, weight_ptr, indices_ptr, y_ptr, y_out_ptr, batch_size, lora_rank,
|
||||
cmd.SetCustomHandler([scalar_type, stream, x_ptr, weight_ptr, indices_ptr, indices_size, y_ptr, y_out_ptr, batch_size, lora_rank,
|
||||
slice_offset, slice_size, output_full_dim]() -> int {
|
||||
auto dtype = get_dtype_from_torch(scalar_type);
|
||||
int device_id = 0;
|
||||
@@ -287,7 +289,7 @@ at::Tensor bgmv_expand(at::Tensor &x, at::Tensor &weight, at::Tensor &indices, a
|
||||
TORCH_CHECK(aclGetDeviceCapability(device_id, ACL_DEVICE_INFO_VECTOR_CORE_NUM, &aiv_num) == ACL_SUCCESS);
|
||||
int num_tokens_per_core = (batch_size + aiv_num - 1) / aiv_num;
|
||||
TORCH_CHECK("num_tokens_per_core != 0", "num_tokens_per_core should not be 0");
|
||||
bgmv_expand_impl(dtype, stream, x_ptr, weight_ptr, indices_ptr, y_ptr, y_out_ptr, batch_size,
|
||||
bgmv_expand_impl(dtype, stream, x_ptr, weight_ptr, indices_ptr, indices_size, y_ptr, y_out_ptr, batch_size,
|
||||
num_tokens_per_core, lora_rank, slice_size, slice_offset, output_full_dim);
|
||||
return 0;
|
||||
});
|
||||
@@ -309,6 +311,8 @@ void sgmv_shrink(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_indices, at
|
||||
void* weight_ptr = weight.data_ptr();
|
||||
void* lora_indices_ptr = lora_indices.data_ptr();
|
||||
void* seq_len_ptr = seq_len.data_ptr();
|
||||
int lora_indices_size = lora_indices.size(0);
|
||||
int seq_len_size = seq_len.size(0);
|
||||
void* y_ptr = y.data_ptr();
|
||||
int batch_size = x.size(0);
|
||||
int input_hidden_token = x.size(1);
|
||||
@@ -317,7 +321,8 @@ void sgmv_shrink(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_indices, at
|
||||
aclrtStream stream = c10_npu::getCurrentNPUStream().stream();
|
||||
at_npu::native::OpCommand cmd;
|
||||
cmd.Name("sgmv_shrink");
|
||||
cmd.SetCustomHandler([scalar_type, stream, x_ptr, weight_ptr, lora_indices_ptr, seq_len_ptr, y_ptr,
|
||||
cmd.SetCustomHandler([scalar_type, stream, x_ptr, weight_ptr, lora_indices_ptr, lora_indices_size,
|
||||
seq_len_ptr, seq_len_size, y_ptr,
|
||||
batch_size, input_hidden_token, lora_rank, scale_f]() -> int {
|
||||
auto dtype = get_dtype_from_torch(scalar_type);
|
||||
int device_id = 0;
|
||||
@@ -325,7 +330,8 @@ void sgmv_shrink(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_indices, at
|
||||
TORCH_CHECK(aclGetDeviceCapability(device_id, ACL_DEVICE_INFO_VECTOR_CORE_NUM, &aiv_num) == ACL_SUCCESS);
|
||||
int num_tokens_per_core = (batch_size + aiv_num - 1) / aiv_num;
|
||||
TORCH_CHECK("num_tokens_per_core != 0", "num_tokens_per_core should not be 0");
|
||||
sgmv_shrink_impl(dtype, stream, x_ptr, weight_ptr, lora_indices_ptr, seq_len_ptr, y_ptr, batch_size,
|
||||
sgmv_shrink_impl(dtype, stream, x_ptr, weight_ptr, lora_indices_ptr, lora_indices_size, seq_len_ptr, seq_len_size,
|
||||
y_ptr, batch_size,
|
||||
num_tokens_per_core, input_hidden_token, lora_rank, scale_f);
|
||||
return 0;
|
||||
});
|
||||
@@ -352,6 +358,8 @@ at::Tensor sgmv_expand(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_indic
|
||||
void* weight_ptr = weight.data_ptr();
|
||||
void* lora_indices_ptr = lora_indices.data_ptr();
|
||||
void* seq_len_ptr = seq_len.data_ptr();
|
||||
int lora_indices_size = lora_indices.size(0);
|
||||
int seq_len_size = seq_len.size(0);
|
||||
void* y_ptr = y.data_ptr();
|
||||
void* y_out_ptr = y_out.data_ptr();
|
||||
int batch_size = x.size(0);
|
||||
@@ -360,7 +368,7 @@ at::Tensor sgmv_expand(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_indic
|
||||
aclrtStream stream = c10_npu::getCurrentNPUStream().stream();
|
||||
at_npu::native::OpCommand cmd;
|
||||
cmd.Name("sgmv_expand");
|
||||
cmd.SetCustomHandler([scalar_type, stream, x_ptr, weight_ptr, lora_indices_ptr, seq_len_ptr, y_ptr, y_out_ptr,
|
||||
cmd.SetCustomHandler([scalar_type, stream, x_ptr, weight_ptr, lora_indices_ptr, lora_indices_size, seq_len_ptr, seq_len_size, y_ptr, y_out_ptr,
|
||||
batch_size, lora_rank, slice_offset, slice_size, output_full_dim]() -> int {
|
||||
auto dtype = get_dtype_from_torch(scalar_type);
|
||||
int device_id = 0;
|
||||
@@ -368,7 +376,7 @@ at::Tensor sgmv_expand(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_indic
|
||||
TORCH_CHECK(aclGetDeviceCapability(device_id, ACL_DEVICE_INFO_VECTOR_CORE_NUM, &aiv_num) == ACL_SUCCESS);
|
||||
int num_tokens_per_core = (batch_size + aiv_num - 1) / aiv_num;
|
||||
TORCH_CHECK("num_tokens_per_core != 0", "num_tokens_per_core should not be 0");
|
||||
sgmv_expand_impl(dtype, stream, x_ptr, weight_ptr, lora_indices_ptr, seq_len_ptr, y_ptr, y_out_ptr,
|
||||
sgmv_expand_impl(dtype, stream, x_ptr, weight_ptr, lora_indices_ptr, lora_indices_size, seq_len_ptr, seq_len_size, y_ptr, y_out_ptr,
|
||||
batch_size, num_tokens_per_core, lora_rank, slice_size, slice_offset, output_full_dim);
|
||||
return 0;
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user