from unittest.mock import MagicMock, patch

import torch
from vllm.attention.layer import Attention
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
from vllm.model_executor.layers.linear import LinearBase

from tests.ut.base import TestBase
from vllm_ascend.ops.linear import AscendUnquantizedLinearMethod
from vllm_ascend.quantization.quant_config import (AscendKVCacheMethod,
                                                   AscendQuantConfig)
from vllm_ascend.utils import ASCEND_QUANTIZATION_METHOD

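# Unit tests for the Ascend quantization config plumbing: AscendQuantConfig
# (quant description parsing, layer skipping, and quant-method dispatch for
# linear, attention, and fused-MoE layers) and AscendKVCacheMethod delegation.
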
class TestAscendQuantConfig(TestBase):

    def setUp(self):
        self.sample_config = {
            "weight": "INT8",
            "fa_quant_type": "C8",
            "kv_quant_type": "C8",
            "layer1.weight": "INT8",
            "layer2.weight": "FLOAT",
            "fused_layer.weight": "FLOAT",
            "fused_layer.shard1.weight": "FLOAT",
            "fused_layer.shard2.weight": "FLOAT",
            "shard1.weight": "FLOAT",
            "shard2.weight": "FLOAT",
        }
        self.ascend_config = AscendQuantConfig(self.sample_config)
        self.ascend_config.packed_modules_mapping = None

    def test_init(self):
        self.assertEqual(self.ascend_config.quant_description,
                         self.sample_config)

    def test_repr(self):
        repr_str = repr(self.ascend_config)
        self.assertTrue(repr_str.startswith("AscendQuantConfig:\n"))

    def test_get_name(self):
        self.assertEqual(AscendQuantConfig.get_name(),
                         ASCEND_QUANTIZATION_METHOD)

    def test_get_supported_act_dtypes(self):
        supported_dtypes = AscendQuantConfig.get_supported_act_dtypes()
        self.assertEqual(len(supported_dtypes), 3)

    def test_get_min_capability(self):
        with self.assertRaises(NotImplementedError):
            AscendQuantConfig.get_min_capability()

    def test_get_config_filenames(self):
        filenames = AscendQuantConfig.get_config_filenames()
        self.assertEqual(filenames, ["quant_model_description.json"])

    def test_from_config(self):
        config = AscendQuantConfig.from_config(self.sample_config)
        self.assertIsInstance(config, AscendQuantConfig)
        self.assertEqual(config.quant_description, self.sample_config)

    @patch('torch.npu.is_available')
    def test_override_quantization_method(self, mock_is_available):
        # Test when NPU is available
        mock_is_available.return_value = True
        result = AscendQuantConfig.override_quantization_method(None, None)
        self.assertIsNone(result)

        hf_quant_cfg = {"quant_method": ""}
        result = AscendQuantConfig.override_quantization_method(
            hf_quant_cfg, None)
        self.assertEqual(result, "ascend")

        # Test when NPU is not available
        mock_is_available.return_value = False
        result = AscendQuantConfig.override_quantization_method(None, None)
        self.assertIsNone(result)

        hf_quant_cfg = {"quant_method": ""}
        result = AscendQuantConfig.override_quantization_method(
            hf_quant_cfg, None)
        self.assertIsNone(result)

    def test_get_quant_method_for_linear(self):
        mock_config = MagicMock()
        mock_config.model_config.hf_config.model_type = None
        linear_layer = MagicMock(spec=LinearBase)

        # Test skipped layer
        with patch("vllm_ascend.quantization.quant_config.get_current_vllm_config", return_value=mock_config), \
                patch.object(self.ascend_config,
                             'is_layer_skipped_ascend',
                             return_value=True):
            method = self.ascend_config.get_quant_method(linear_layer, ".attn")
            self.assertIsInstance(method, AscendUnquantizedLinearMethod)

        # Test quantized layer
        with patch.object(self.ascend_config, 'is_layer_skipped_ascend', return_value=False), \
                patch("vllm_ascend.quantization.quant_config.get_current_vllm_config", return_value=mock_config), \
                patch('vllm_ascend.quantization.quant_config.AscendLinearMethod', return_value=MagicMock()) as mock_ascend_linear:
            method = self.ascend_config.get_quant_method(linear_layer, ".attn")
            self.assertIs(method, mock_ascend_linear.return_value)
            mock_ascend_linear.assert_called_once_with(
                self.ascend_config, ".attn",
                self.ascend_config.packed_modules_mapping, linear_layer)

    def test_get_quant_method_for_attention(self):
        attention_layer = MagicMock(spec=Attention)
        mock_config = MagicMock()
        mock_config.model_config.hf_config.model_type = None

        with patch("vllm_ascend.quantization.quant_config.get_current_vllm_config", return_value=mock_config), \
                patch('vllm_ascend.quantization.quant_config.AscendKVCacheMethod',
                      return_value=MagicMock()) as mock_ascend_kvcache:
            # Test with fa_quant_type
            method = self.ascend_config.get_quant_method(
                attention_layer, ".attn")
            self.assertIs(method, mock_ascend_kvcache.return_value)

        with patch("vllm_ascend.quantization.quant_config.get_current_vllm_config", return_value=mock_config), \
                patch('vllm_ascend.quantization.quant_config.AscendKVCacheMethod',
                      return_value=MagicMock()) as mock_ascend_kvcache:
            # Test with kv_quant_type
            modified_config = {"kv_quant_type": "C8"}
            config = AscendQuantConfig(modified_config)
            config.packed_modules_mapping = None
            method = config.get_quant_method(attention_layer, "attn")
            self.assertIs(method, mock_ascend_kvcache.return_value)

    def test_get_quant_method_for_fused_moe(self):
        fused_moe_layer = MagicMock(spec=FusedMoE)
        fused_moe_layer.moe = MagicMock(spec=FusedMoEConfig)
        fused_moe_layer.moe_config = MagicMock(spec=FusedMoEConfig)
        mock_config = MagicMock()
        mock_config.model_config.hf_config.model_type = None

        # Test skipped layer
        with patch.object(self.ascend_config, 'is_layer_skipped_ascend', return_value=True), \
                patch("vllm_ascend.quantization.quant_config.get_current_vllm_config", return_value=mock_config), \
                patch('vllm_ascend.quantization.quant_config.AscendUnquantizedFusedMoEMethod', return_value=MagicMock()) as mock_ascend_moe:
            method = self.ascend_config.get_quant_method(
                fused_moe_layer, "moe_layer")
            self.assertIs(method, mock_ascend_moe.return_value)

        # Test quantized layer
        with patch.object(self.ascend_config, 'is_layer_skipped_ascend', return_value=False), \
                patch("vllm_ascend.quantization.quant_config.get_current_vllm_config", return_value=mock_config), \
                patch('vllm_ascend.quantization.quant_config.AscendFusedMoEMethod', return_value=MagicMock()) as mock_ascend_moe:
            method = self.ascend_config.get_quant_method(
                fused_moe_layer, "moe_layer")
            self.assertIs(method, mock_ascend_moe.return_value)

    def test_is_layer_skipped_ascend(self):
        # Test non-fused layer that should be quantized
        self.assertFalse(self.ascend_config.is_layer_skipped_ascend("layer1"))

        # Test non-fused layer that should be skipped
        self.assertTrue(self.ascend_config.is_layer_skipped_ascend("layer2"))

        # Test fused layer
        fused_mapping = {"fused_layer": ["shard1", "shard2"]}
        self.assertTrue(
            self.ascend_config.is_layer_skipped_ascend("fused_layer",
                                                       fused_mapping))

        # Test inconsistent fused layer shards
        bad_config = {"shard1.weight": "FLOAT", "shard2.weight": "INT8"}
        config = AscendQuantConfig(bad_config)
        with self.assertRaises(ValueError):
            config.is_layer_skipped_ascend("fused_layer", fused_mapping)

    def test_get_scaled_act_names(self):
        self.assertEqual(self.ascend_config.get_scaled_act_names(), [])


class TestAscendKVCacheMethod(TestBase):

    def setUp(self):
        # Setup common test fixtures
        self.mock_quant_config = MagicMock(spec=AscendQuantConfig)
        self.mock_quant_config.quant_description = {"kv_quant_type": "C8"}
        self.prefix = "layer.attn"

        # Mock quant_method
        self.mock_quant_method = MagicMock()
        self.patcher = patch(
            'vllm_ascend.quantization.quant_config.get_quant_method')
        self.mock_get_quant_method = self.patcher.start()
        self.mock_get_quant_method.return_value = self.mock_quant_method

        # Create instance
        self.kv_cache_method = AscendKVCacheMethod(self.mock_quant_config,
                                                   self.prefix)

    def tearDown(self):
        self.patcher.stop()

    def test_create_weights(self):
        """Test create_weights delegates to quant_method."""
        mock_layer = MagicMock()
        self.kv_cache_method.create_weights(mock_layer)
        self.mock_quant_method.create_weights.assert_called_once_with(
            mock_layer)

    def test_process_weights_after_loading_with_method(self):
        """Test process_weights when quant_method has the method."""
        mock_layer = MagicMock()
        self.kv_cache_method.process_weights_after_loading(mock_layer)
        self.mock_quant_method.process_weights_after_loading.assert_called_once_with(
            mock_layer)

    def test_process_weights_after_loading_without_method(self):
        """Test process_weights when quant_method lacks the method."""
        # Reset mock to remove the method
        del self.mock_quant_method.process_weights_after_loading
        mock_layer = MagicMock()

        # Should not raise exception
        self.kv_cache_method.process_weights_after_loading(mock_layer)

    def test_apply_delegation(self):
        """Test apply properly delegates to quant_method."""
        mock_layer = MagicMock()
        mock_query = torch.randn(1, 32, 128)
        mock_key = torch.randn(1, 32, 128)
        mock_value = torch.randn(1, 32, 128)
        mock_kv_cache = MagicMock()
        mock_attn_metadata = MagicMock()
        mock_scale = 1.0
        mock_output = torch.zeros(1, 32, 128)
        mock_attn_type = MagicMock()
        expected_result = torch.randn(1, 32, 128)
        self.mock_quant_method.apply.return_value = expected_result

        result = self.kv_cache_method.apply(mock_layer, mock_query, mock_key,
                                            mock_value, mock_kv_cache,
                                            mock_attn_metadata, mock_attn_type,
                                            mock_scale, mock_output)

        self.mock_quant_method.apply.assert_called_once_with(
            mock_layer, mock_query, mock_key, mock_value, mock_kv_cache,
            mock_attn_metadata, mock_attn_type, mock_scale, mock_output)
        self.assertTrue(torch.equal(result, expected_result))
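

# Minimal convenience entry point (an addition, not part of the original file):
# it lets this module be executed directly with the standard unittest runner;
# the suite is normally collected and run via pytest.
if __name__ == "__main__":
    import unittest

    unittest.main()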