diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index 20b8910..4401ecc 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -85,7 +85,7 @@ def is_enable_nz(dtype: Optional[torch.dtype] = torch.int8, and getattr(vllm_config.speculative_config, 'method', None) in ("eagle", "eagle3")) - if dtype in [torch.float16, torch.bfloat16]: + if dtype in [torch.float16, torch.bfloat16, torch.float32]: return _ENABLE_NZ if _IS_EAGLE_MODE else False return _ENABLE_NZ