Fix VocabParallelEmbedding UT (#2722)
### What this PR does / why we need it?
Fix VocabParallelEmbedding UT
### How was this patch tested?
CI passed with new added/existing test.
- vLLM version: main
- vLLM main:
f592b3174b
---------
Signed-off-by: Icey <1790571317@qq.com>
This commit is contained in:
2
.github/workflows/vllm_ascend_test.yaml
vendored
2
.github/workflows/vllm_ascend_test.yaml
vendored
@@ -118,7 +118,7 @@ jobs:
|
|||||||
TORCH_DEVICE_BACKEND_AUTOLOAD: 0
|
TORCH_DEVICE_BACKEND_AUTOLOAD: 0
|
||||||
run: |
|
run: |
|
||||||
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/Ascend/ascend-toolkit/latest/x86_64-linux/devlib
|
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/Ascend/ascend-toolkit/latest/x86_64-linux/devlib
|
||||||
pytest -sv --cov --cov-report=xml:unittests-coverage.xml tests/ut --ignore=tests/ut/test_platform.py --ignore=tests/ut/ops/test_vocab_parallel_embedding.py
|
pytest -sv --cov --cov-report=xml:unittests-coverage.xml tests/ut --ignore=tests/ut/test_platform.py
|
||||||
|
|
||||||
- name: Upload coverage to Codecov
|
- name: Upload coverage to Codecov
|
||||||
if: ${{ matrix.vllm_version == 'main' }}
|
if: ${{ matrix.vllm_version == 'main' }}
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ from unittest.mock import MagicMock, patch
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from vllm_ascend.ascend_config import init_ascend_config
|
||||||
from vllm_ascend.ops.vocab_parallel_embedding import (
|
from vllm_ascend.ops.vocab_parallel_embedding import (
|
||||||
AscendLogitsProcessor, AscendParallelLMHead, AscendVocabParallelEmbedding)
|
AscendLogitsProcessor, AscendParallelLMHead, AscendVocabParallelEmbedding)
|
||||||
|
|
||||||
@@ -31,6 +32,9 @@ class TestCustomVocabParallelEmbedding(unittest.TestCase):
|
|||||||
self.embedding_dim = 10
|
self.embedding_dim = 10
|
||||||
self.org_num_embeddings = 40
|
self.org_num_embeddings = 40
|
||||||
self.padding_size = 8
|
self.padding_size = 8
|
||||||
|
mock_vllm_config = MagicMock()
|
||||||
|
mock_vllm_config.additional_config = {}
|
||||||
|
init_ascend_config(mock_vllm_config)
|
||||||
|
|
||||||
def _create_layer(self):
|
def _create_layer(self):
|
||||||
# Patch methods and dependencies for VocabParallelEmbedding
|
# Patch methods and dependencies for VocabParallelEmbedding
|
||||||
|
|||||||
Reference in New Issue
Block a user