[MISC] fix format check error (#654)
This PR makes format.sh work as expected. Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
@@ -29,7 +29,7 @@ using vllm_ascend::AccType;
|
|||||||
using vllm_ascend::local_mem_copy;
|
using vllm_ascend::local_mem_copy;
|
||||||
template <typename scalar_t, bool isNeox> class RotaryEmbedding {
|
template <typename scalar_t, bool isNeox> class RotaryEmbedding {
|
||||||
// NOTE(ganyi): we use 512B as load stride for pipe, need to find another way to
|
// NOTE(ganyi): we use 512B as load stride for pipe, need to find another way to
|
||||||
// retrive this size from runtime for more Soc support
|
// retrieve this size from runtime for more Soc support
|
||||||
static int constexpr loadSize = 512;
|
static int constexpr loadSize = 512;
|
||||||
using dst_t = scalar_t;
|
using dst_t = scalar_t;
|
||||||
using acc_t = typename AccType<scalar_t>::type;
|
using acc_t = typename AccType<scalar_t>::type;
|
||||||
@@ -66,7 +66,7 @@ public:
|
|||||||
pipe_->InitBuffer(inQue_, 1 /* buffer_num */, loadSize /* buffer_size */);
|
pipe_->InitBuffer(inQue_, 1 /* buffer_num */, loadSize /* buffer_size */);
|
||||||
pipe_->InitBuffer(inQueSinCos_, 1 /* buffer_num */, rotDim_ * sizeof(scalar_t) /* buffer_size */);
|
pipe_->InitBuffer(inQueSinCos_, 1 /* buffer_num */, rotDim_ * sizeof(scalar_t) /* buffer_size */);
|
||||||
pipe_->InitBuffer(outQue_, 1 /* buffer_num */, loadSize /* buffer_size */);
|
pipe_->InitBuffer(outQue_, 1 /* buffer_num */, loadSize /* buffer_size */);
|
||||||
// 2 temperary calculation buffer
|
// 2 temporary calculation buffer
|
||||||
calcTmpBufferOffset_ = 0;
|
calcTmpBufferOffset_ = 0;
|
||||||
// 1 upcast buffer for bf16 (headSize)
|
// 1 upcast buffer for bf16 (headSize)
|
||||||
upcastInputBufferOffset_ = calcTmpBufferOffset_ + sizeof(acc_t) * embedDim_ * 2;
|
upcastInputBufferOffset_ = calcTmpBufferOffset_ + sizeof(acc_t) * embedDim_ * 2;
|
||||||
@@ -75,10 +75,10 @@ public:
|
|||||||
// 2 sin cos upcast buffer for bf16
|
// 2 sin cos upcast buffer for bf16
|
||||||
cosSinUpcastBufferOffset_ = upcastTempBufferOffset_ + sizeof(acc_t) * 2 * embedDim_;
|
cosSinUpcastBufferOffset_ = upcastTempBufferOffset_ + sizeof(acc_t) * 2 * embedDim_;
|
||||||
// 2. bf16 path: needs 2 cos sin upcast buffer size
|
// 2. bf16 path: needs 2 cos sin upcast buffer size
|
||||||
// 3. fp16 path: needs 2 temperary calculation buffer size
|
// 3. fp16 path: needs 2 temporary calculation buffer size
|
||||||
tempBufferSize_ = cosSinUpcastBufferOffset_ + 2 * embedDim_ * sizeof(acc_t);
|
tempBufferSize_ = cosSinUpcastBufferOffset_ + 2 * embedDim_ * sizeof(acc_t);
|
||||||
// need to consider upcast the bf16 to fp32, so we might need 4 buffer just in case
|
// need to consider upcast the bf16 to fp32, so we might need 4 buffer just in case
|
||||||
// 2 temperary buffer, 2 input buffer, 1 cos buffer, 1 sin buffer, 2 scale buffer (headSize), 2 zp
|
// 2 temporary buffer, 2 input buffer, 1 cos buffer, 1 sin buffer, 2 scale buffer (headSize), 2 zp
|
||||||
// buffer(headSize int8), 1 dst_temp buffer(headSize, int32)
|
// buffer(headSize int8), 1 dst_temp buffer(headSize, int32)
|
||||||
pipe_->InitBuffer(calcBuf_, tempBufferSize_ /* buffer_size */);
|
pipe_->InitBuffer(calcBuf_, tempBufferSize_ /* buffer_size */);
|
||||||
if constexpr (!std::is_same_v<scalar_t, acc_t>) {
|
if constexpr (!std::is_same_v<scalar_t, acc_t>) {
|
||||||
|
|||||||
@@ -89,7 +89,7 @@ Currently, w8a8 quantization is already supported by vllm-ascend originally on v
|
|||||||
|
|
||||||
Currently, w8a8 DeepSeek is working in process: [support AscendW8A8 quantization](https://github.com/vllm-project/vllm-ascend/pull/511)
|
Currently, w8a8 DeepSeek is working in process: [support AscendW8A8 quantization](https://github.com/vllm-project/vllm-ascend/pull/511)
|
||||||
|
|
||||||
Please run DeepSeek with BF16 now, follwing the [Multi-Node DeepSeek inferencing tutorail](https://vllm-ascend.readthedocs.io/en/main/tutorials/multi_node.html)
|
Please run DeepSeek with BF16 now, following the [Multi-Node DeepSeek inferencing tutorial](https://vllm-ascend.readthedocs.io/en/main/tutorials/multi_node.html)
|
||||||
|
|
||||||
### 12. There is not output in log when loading models using vllm-ascend, How to solve it?
|
### 12. There is not output in log when loading models using vllm-ascend, How to solve it?
|
||||||
|
|
||||||
|
|||||||
@@ -127,7 +127,7 @@ apt update -y
|
|||||||
apt install -y gcc g++ cmake libnuma-dev wget git
|
apt install -y gcc g++ cmake libnuma-dev wget git
|
||||||
```
|
```
|
||||||
|
|
||||||
**[Optinal]** Config the extra-index of `pip` if you are working on a **x86** machine, so that the torch with cpu could be found:
|
**[Optional]** Config the extra-index of `pip` if you are working on a **x86** machine, so that the torch with cpu could be found:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
pip config set global.extra-index-url https://download.pytorch.org/whl/cpu/
|
pip config set global.extra-index-url https://download.pytorch.org/whl/cpu/
|
||||||
|
|||||||
@@ -84,7 +84,7 @@ def run_decode(prefill_done):
|
|||||||
gpu_memory_utilization=0.8,
|
gpu_memory_utilization=0.8,
|
||||||
tensor_parallel_size=2)
|
tensor_parallel_size=2)
|
||||||
|
|
||||||
# Wait for the producer to start the comsumer
|
# Wait for the producer to start the consumer
|
||||||
print("Waiting for prefill node to finish...")
|
print("Waiting for prefill node to finish...")
|
||||||
prefill_done.wait()
|
prefill_done.wait()
|
||||||
|
|
||||||
|
|||||||
@@ -93,7 +93,7 @@ def main(args):
|
|||||||
inputs = {"prompt": prompt, "multi_modal_data": mm_data}
|
inputs = {"prompt": prompt, "multi_modal_data": mm_data}
|
||||||
if args.num_prompts > 1:
|
if args.num_prompts > 1:
|
||||||
# Batch inference
|
# Batch inference
|
||||||
inputs = [inputs] * args.num_prompts
|
inputs = [inputs] * args.num_prompts # type: ignore
|
||||||
|
|
||||||
outputs = llm.generate(inputs, sampling_params=sampling_params)
|
outputs = llm.generate(inputs, sampling_params=sampling_params)
|
||||||
|
|
||||||
|
|||||||
33
format.sh
33
format.sh
@@ -116,6 +116,7 @@ format_all() {
|
|||||||
yapf --in-place "${YAPF_FLAGS[@]}" "${YAPF_EXCLUDES[@]}" .
|
yapf --in-place "${YAPF_FLAGS[@]}" "${YAPF_EXCLUDES[@]}" .
|
||||||
}
|
}
|
||||||
|
|
||||||
|
echo 'vllm-ascend yapf:'
|
||||||
## This flag formats individual files. --files *must* be the first command line
|
## This flag formats individual files. --files *must* be the first command line
|
||||||
## arg to use this option.
|
## arg to use this option.
|
||||||
if [[ "$1" == '--files' ]]; then
|
if [[ "$1" == '--files' ]]; then
|
||||||
@@ -128,12 +129,12 @@ else
|
|||||||
# Format only the files that changed in last commit.
|
# Format only the files that changed in last commit.
|
||||||
format_changed
|
format_changed
|
||||||
fi
|
fi
|
||||||
echo 'vLLM yapf: Done'
|
echo 'vllm-ascend yapf: Done'
|
||||||
|
|
||||||
# Run mypy
|
# Run mypy
|
||||||
echo 'vLLM mypy:'
|
echo 'vllm-ascend mypy:'
|
||||||
tools/mypy.sh
|
tools/mypy.sh
|
||||||
echo 'vLLM mypy: Done'
|
echo 'vllm-ascend mypy: Done'
|
||||||
|
|
||||||
|
|
||||||
# If git diff returns a file that is in the skip list, the file may be checked anyway:
|
# If git diff returns a file that is in the skip list, the file may be checked anyway:
|
||||||
@@ -172,6 +173,7 @@ spell_check_changed() {
|
|||||||
fi
|
fi
|
||||||
}
|
}
|
||||||
|
|
||||||
|
echo 'vllm-ascend codespell:'
|
||||||
# Run Codespell
|
# Run Codespell
|
||||||
## This flag runs spell check of individual files. --files *must* be the first command line
|
## This flag runs spell check of individual files. --files *must* be the first command line
|
||||||
## arg to use this option.
|
## arg to use this option.
|
||||||
@@ -185,7 +187,7 @@ else
|
|||||||
# Check spelling only of the files that changed in last commit.
|
# Check spelling only of the files that changed in last commit.
|
||||||
spell_check_changed
|
spell_check_changed
|
||||||
fi
|
fi
|
||||||
echo 'vLLM codespell: Done'
|
echo 'vllm-ascend codespell: Done'
|
||||||
|
|
||||||
|
|
||||||
# Lint specified files
|
# Lint specified files
|
||||||
@@ -211,6 +213,7 @@ lint_changed() {
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
echo 'vllm-ascend ruff:'
|
||||||
# Run Ruff
|
# Run Ruff
|
||||||
### This flag lints individual files. --files *must* be the first command line
|
### This flag lints individual files. --files *must* be the first command line
|
||||||
### arg to use this option.
|
### arg to use this option.
|
||||||
@@ -224,7 +227,7 @@ else
|
|||||||
# Format only the files that changed in last commit.
|
# Format only the files that changed in last commit.
|
||||||
lint_changed
|
lint_changed
|
||||||
fi
|
fi
|
||||||
echo 'vLLM ruff: Done'
|
echo 'vllm-ascend ruff: Done'
|
||||||
|
|
||||||
# check spelling of specified files
|
# check spelling of specified files
|
||||||
isort_check() {
|
isort_check() {
|
||||||
@@ -251,6 +254,7 @@ isort_check_changed() {
|
|||||||
fi
|
fi
|
||||||
}
|
}
|
||||||
|
|
||||||
|
echo 'vllm-ascend isort:'
|
||||||
# Run Isort
|
# Run Isort
|
||||||
# This flag runs spell check of individual files. --files *must* be the first command line
|
# This flag runs spell check of individual files. --files *must* be the first command line
|
||||||
# arg to use this option.
|
# arg to use this option.
|
||||||
@@ -264,18 +268,13 @@ else
|
|||||||
# Check spelling only of the files that changed in last commit.
|
# Check spelling only of the files that changed in last commit.
|
||||||
isort_check_changed
|
isort_check_changed
|
||||||
fi
|
fi
|
||||||
echo 'vLLM isort: Done'
|
echo 'vllm-ascend isort: Done'
|
||||||
|
|
||||||
# Clang-format section
|
# Clang-format section
|
||||||
# Exclude some files for formatting because they are vendored
|
# Exclude some files for formatting because they are vendored
|
||||||
# NOTE: Keep up to date with .github/workflows/clang-format.yml
|
# NOTE: Keep up to date with .github/workflows/clang-format.yml
|
||||||
CLANG_FORMAT_EXCLUDES=(
|
CLANG_FORMAT_EXCLUDES=(
|
||||||
'csrc/moe/topk_softmax_kernels.cu'
|
'csrc/kernels/pos_encoding_kernels.cpp'
|
||||||
'csrc/quantization/gguf/ggml-common.h'
|
|
||||||
'csrc/quantization/gguf/dequantize.cuh'
|
|
||||||
'csrc/quantization/gguf/vecdotq.cuh'
|
|
||||||
'csrc/quantization/gguf/mmq.cuh'
|
|
||||||
'csrc/quantization/gguf/mmvq.cuh'
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Format specified files with clang-format
|
# Format specified files with clang-format
|
||||||
@@ -315,15 +314,15 @@ elif [[ "$1" == '--all' ]]; then
|
|||||||
else
|
else
|
||||||
clang_format_changed
|
clang_format_changed
|
||||||
fi
|
fi
|
||||||
echo 'vLLM clang-format: Done'
|
echo 'vllm-ascend clang-format: Done'
|
||||||
|
|
||||||
echo 'vLLM actionlint:'
|
echo 'vllm-ascend actionlint:'
|
||||||
tools/actionlint.sh -color
|
tools/actionlint.sh -color
|
||||||
echo 'vLLM actionlint: Done'
|
echo 'vllm-ascend actionlint: Done'
|
||||||
|
|
||||||
echo 'vLLM shellcheck:'
|
echo 'vllm-ascend shellcheck:'
|
||||||
tools/shellcheck.sh
|
tools/shellcheck.sh
|
||||||
echo 'vLLM shellcheck: Done'
|
echo 'vllm-ascend shellcheck: Done'
|
||||||
|
|
||||||
echo 'excalidraw png check:'
|
echo 'excalidraw png check:'
|
||||||
tools/png-lint.sh
|
tools/png-lint.sh
|
||||||
|
|||||||
@@ -236,7 +236,7 @@ def test_medusa_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
|
|||||||
|
|
||||||
# TODO: There is a problem with the preemptive scheduling in the current
|
# TODO: There is a problem with the preemptive scheduling in the current
|
||||||
# version, which makes this case fail. Please release this case after the
|
# version, which makes this case fail. Please release this case after the
|
||||||
# preemptive scheduling preblem is solved.
|
# preemptive scheduling problem is solved.
|
||||||
# @pytest.mark.parametrize(
|
# @pytest.mark.parametrize(
|
||||||
# "common_llm_kwargs",
|
# "common_llm_kwargs",
|
||||||
# [{
|
# [{
|
||||||
|
|||||||
@@ -296,7 +296,7 @@ def test_mlp_e2e_seeded_correctness(vllm_runner, common_llm_kwargs,
|
|||||||
|
|
||||||
# TODO: There is a problem with the preemptive scheduling in the current
|
# TODO: There is a problem with the preemptive scheduling in the current
|
||||||
# version, which makes this case fail. Please release this case after the
|
# version, which makes this case fail. Please release this case after the
|
||||||
# preemptive scheduling preblem is solved.
|
# preemptive scheduling problem is solved.
|
||||||
# @pytest.mark.parametrize(
|
# @pytest.mark.parametrize(
|
||||||
# "common_llm_kwargs",
|
# "common_llm_kwargs",
|
||||||
# [{
|
# [{
|
||||||
@@ -352,7 +352,7 @@ def test_mlp_e2e_seeded_correctness(vllm_runner, common_llm_kwargs,
|
|||||||
|
|
||||||
# TODO: There is a problem with the preemptive scheduling in the current
|
# TODO: There is a problem with the preemptive scheduling in the current
|
||||||
# version, which makes this case fail. Please release this case after the
|
# version, which makes this case fail. Please release this case after the
|
||||||
# preemptive scheduling preblem is solved.
|
# preemptive scheduling problem is solved.
|
||||||
# @pytest.mark.parametrize(
|
# @pytest.mark.parametrize(
|
||||||
# "common_llm_kwargs",
|
# "common_llm_kwargs",
|
||||||
# [{
|
# [{
|
||||||
|
|||||||
@@ -175,7 +175,7 @@ def test_ngram_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
|
|||||||
|
|
||||||
# TODO: There is a problem with the preemptive scheduling in the current
|
# TODO: There is a problem with the preemptive scheduling in the current
|
||||||
# version, which makes this case fail. Please release this case after the
|
# version, which makes this case fail. Please release this case after the
|
||||||
# preemptive scheduling preblem is solved.
|
# preemptive scheduling problem is solved.
|
||||||
# @pytest.mark.parametrize(
|
# @pytest.mark.parametrize(
|
||||||
# "common_llm_kwargs",
|
# "common_llm_kwargs",
|
||||||
# [{
|
# [{
|
||||||
|
|||||||
@@ -18,6 +18,7 @@
|
|||||||
# This file is a part of the vllm-ascend project.
|
# This file is a part of the vllm-ascend project.
|
||||||
# Adapted from https://github.com/vllm-project/vllm/tree/main/tools
|
# Adapted from https://github.com/vllm-project/vllm/tree/main/tools
|
||||||
#
|
#
|
||||||
|
export SHELLCHECK_OPTS="--exclude=SC2046,SC2006"
|
||||||
|
|
||||||
if command -v actionlint &> /dev/null; then
|
if command -v actionlint &> /dev/null; then
|
||||||
actionlint .github/workflows/*.yml .github/workflows/*.yaml
|
actionlint .github/workflows/*.yml .github/workflows/*.yaml
|
||||||
@@ -29,4 +30,4 @@ fi
|
|||||||
|
|
||||||
# download a binary to the current directory - v1.7.3
|
# download a binary to the current directory - v1.7.3
|
||||||
bash <(curl https://raw.githubusercontent.com/rhysd/actionlint/aa0a7be8e566b096e64a5df8ff290ec24fa58fbc/scripts/download-actionlint.bash)
|
bash <(curl https://raw.githubusercontent.com/rhysd/actionlint/aa0a7be8e566b096e64a5df8ff290ec24fa58fbc/scripts/download-actionlint.bash)
|
||||||
./actionlint .github/workflows/*.yml .github/workflows/*.yaml
|
./actionlint .github/workflows/*.yml .github/workflows/*.yaml
|
||||||
|
|||||||
@@ -28,11 +28,7 @@ fi
|
|||||||
|
|
||||||
run_mypy() {
|
run_mypy() {
|
||||||
echo "Running mypy on $1"
|
echo "Running mypy on $1"
|
||||||
if [ "$CI" -eq 1 ] && [ -z "$1" ]; then
|
mypy --check-untyped-defs --follow-imports skip --python-version "${PYTHON_VERSION}" "$@"
|
||||||
mypy --python-version "${PYTHON_VERSION}" "$@"
|
|
||||||
return
|
|
||||||
fi
|
|
||||||
mypy --follow-imports skip --python-version "${PYTHON_VERSION}" "$@"
|
|
||||||
}
|
}
|
||||||
|
|
||||||
run_mypy vllm_ascend
|
run_mypy vllm_ascend
|
||||||
|
|||||||
@@ -1080,7 +1080,7 @@ class AscendMLAAttentionBackendImpl(MLAAttentionImpl):
|
|||||||
if len(kv_cache) > 0 and kv_cache[0].numel(
|
if len(kv_cache) > 0 and kv_cache[0].numel(
|
||||||
) > 0 and attn_metadata.num_prefills > 0:
|
) > 0 and attn_metadata.num_prefills > 0:
|
||||||
slots = attn_metadata.slot_mapping
|
slots = attn_metadata.slot_mapping
|
||||||
# NOTE: Seperate the kv cache in advance to avoid OOM or other issues
|
# NOTE: Separate the kv cache in advance to avoid OOM or other issues
|
||||||
torch_npu._npu_reshape_and_cache(key=kv_c_normed.view(
|
torch_npu._npu_reshape_and_cache(key=kv_c_normed.view(
|
||||||
num_tokens, self.num_kv_heads, -1),
|
num_tokens, self.num_kv_heads, -1),
|
||||||
value=k_pe,
|
value=k_pe,
|
||||||
|
|||||||
@@ -60,7 +60,7 @@ class AscendSchedulerConfig(SchedulerConfig):
|
|||||||
)
|
)
|
||||||
if self.is_multimodal_model:
|
if self.is_multimodal_model:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"currently AscendScheduler only supports LLM modles.")
|
"currently AscendScheduler only supports LLM models.")
|
||||||
if self.num_scheduler_steps > 1:
|
if self.num_scheduler_steps > 1:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"currently AscendScheduler doesn't support multi-step.")
|
"currently AscendScheduler doesn't support multi-step.")
|
||||||
|
|||||||
@@ -57,8 +57,10 @@ def get_device_ips():
|
|||||||
universal_newlines=True)
|
universal_newlines=True)
|
||||||
if npu_info.returncode != 0 or not os.path.exists(HCCN_TOOL_PATH):
|
if npu_info.returncode != 0 or not os.path.exists(HCCN_TOOL_PATH):
|
||||||
raise RuntimeError("No npu-smi/hccn_tool tools provided for NPU.")
|
raise RuntimeError("No npu-smi/hccn_tool tools provided for NPU.")
|
||||||
npu_start_idx = int(
|
re_result = re.match(r'.*\n\t([0-9]+).*', npu_info.stdout)
|
||||||
re.match(r'.*\n\t([0-9]+).*', npu_info.stdout).group(1))
|
if re_result is None:
|
||||||
|
raise RuntimeError("Can't find npu start index")
|
||||||
|
npu_start_idx = int(re_result.group(1))
|
||||||
device_ip_list = []
|
device_ip_list = []
|
||||||
for ip_offset in range(world_size):
|
for ip_offset in range(world_size):
|
||||||
cmd = [
|
cmd = [
|
||||||
@@ -68,7 +70,10 @@ def get_device_ips():
|
|||||||
stdout=subprocess.PIPE,
|
stdout=subprocess.PIPE,
|
||||||
stderr=subprocess.PIPE,
|
stderr=subprocess.PIPE,
|
||||||
universal_newlines=True)
|
universal_newlines=True)
|
||||||
device_ip = re.match(r'ipaddr:(.*)\n', device_ip_info.stdout).group(1)
|
re_result = re.match(r'ipaddr:(.*)\n', device_ip_info.stdout)
|
||||||
|
if re_result is None:
|
||||||
|
raise RuntimeError("Can't find npu ip")
|
||||||
|
device_ip = re_result.group(1)
|
||||||
device_ip_list.append(device_ip)
|
device_ip_list.append(device_ip)
|
||||||
return device_ip_list
|
return device_ip_list
|
||||||
|
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ from vllm.distributed.parallel_state import (GroupCoordinator, get_world_group,
|
|||||||
# vllm-ascend will maintain its own EP GroupCoordinator and ETP GroupCoordinator for
|
# vllm-ascend will maintain its own EP GroupCoordinator and ETP GroupCoordinator for
|
||||||
# customize parallel solution
|
# customize parallel solution
|
||||||
_EP: Optional[GroupCoordinator] = None
|
_EP: Optional[GroupCoordinator] = None
|
||||||
_ETP: Optional[list[GroupCoordinator]] = None
|
_ETP: Optional[GroupCoordinator] = None
|
||||||
|
|
||||||
|
|
||||||
def get_ep_group() -> GroupCoordinator:
|
def get_ep_group() -> GroupCoordinator:
|
||||||
@@ -69,4 +69,4 @@ def destory_ascend_model_parallel():
|
|||||||
global _ETP
|
global _ETP
|
||||||
if _ETP:
|
if _ETP:
|
||||||
_ETP.destroy()
|
_ETP.destroy()
|
||||||
_ETP = None
|
_ETP = None
|
||||||
|
|||||||
@@ -278,7 +278,7 @@ def fused_experts(
|
|||||||
dtype=dtype)
|
dtype=dtype)
|
||||||
|
|
||||||
# TODO: npu_grouped_matmul output random values at [num_valid_tokens:, ...]
|
# TODO: npu_grouped_matmul output random values at [num_valid_tokens:, ...]
|
||||||
# This created multiple NaN and index_add_ will mix them up which harms accracy
|
# This created multiple NaN and index_add_ will mix them up which harms accuracy
|
||||||
# remove this mask and filter after it being fixed
|
# remove this mask and filter after it being fixed
|
||||||
num_valid_tokens = mask.sum()
|
num_valid_tokens = mask.sum()
|
||||||
valid_token_mask = torch.arange(
|
valid_token_mask = torch.arange(
|
||||||
|
|||||||
@@ -227,7 +227,7 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
|
|||||||
persistent=False)
|
persistent=False)
|
||||||
|
|
||||||
|
|
||||||
# TODO: Patch when aclnn ops avaiable
|
# TODO: Patch when aclnn ops available
|
||||||
RotaryEmbedding.forward_oot = rope_forward_oot
|
RotaryEmbedding.forward_oot = rope_forward_oot
|
||||||
DeepseekScalingRotaryEmbedding.forward = native_rope_deepseek_forward
|
DeepseekScalingRotaryEmbedding.forward = native_rope_deepseek_forward
|
||||||
DeepseekScalingRotaryEmbedding._set_cos_sin_cache = _set_cos_sin_cache
|
DeepseekScalingRotaryEmbedding._set_cos_sin_cache = _set_cos_sin_cache
|
||||||
|
|||||||
Reference in New Issue
Block a user