diff --git a/requirements-dev.txt b/requirements-dev.txt
index d4a5acd..cbd851e 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -6,7 +6,6 @@ pytest >= 6.0
 pytest-asyncio
 pytest-mock
 lm-eval
-ray
 types-jsonschema
 xgrammar
 zmq
@@ -14,3 +13,5 @@ types-psutil
 pytest-cov
 regex
 sentence_transformers
+ray>=2.47.1
+protobuf==4.25.6
diff --git a/tests/e2e/multicard/test_pipeline_parallel.py b/tests/e2e/multicard/test_pipeline_parallel.py
index c0c757a..612744e 100644
--- a/tests/e2e/multicard/test_pipeline_parallel.py
+++ b/tests/e2e/multicard/test_pipeline_parallel.py
@@ -24,6 +24,7 @@ MODELS = [
 
 TENSOR_PARALLELS = [2]
 PIPELINE_PARALLELS = [2]
+DIST_EXECUTOR_BACKEND = ["mp", "ray"]
 
 prompts = [
     "Hello, my name is",
@@ -34,10 +35,13 @@ prompts = [
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("tp_size", TENSOR_PARALLELS)
 @pytest.mark.parametrize("pp_size", PIPELINE_PARALLELS)
-def test_models(model: str, tp_size: int, pp_size: int) -> None:
+@pytest.mark.parametrize("distributed_executor_backend", DIST_EXECUTOR_BACKEND)
+def test_models(model: str, tp_size: int, pp_size: int,
+                distributed_executor_backend: str) -> None:
     with VllmRunner(model,
                     tensor_parallel_size=tp_size,
                     pipeline_parallel_size=pp_size,
+                    distributed_executor_backend=distributed_executor_backend,
                     enforce_eager=True,
                     gpu_memory_utilization=0.7) as vllm_model:
         vllm_model.generate_greedy(prompts, 64)
diff --git a/tests/ut/attention/test_attention_v1.py b/tests/ut/attention/test_attention_v1.py
index 51fbae2..65e271c 100644
--- a/tests/ut/attention/test_attention_v1.py
+++ b/tests/ut/attention/test_attention_v1.py
@@ -400,19 +400,13 @@ class TestAscendAttentionBackendImpl(TestBase):
         layer = self.layer_no_quant
         mock_vanilla_prefill.return_value = MagicMock()
 
-        def mock_tensor(data, device=None, **kwargs):
-            if device == "npu":
-                return metadata.attn_mask
-            return torch.tensor(data, **kwargs)
-
-        with patch("torch.tensor", side_effect=mock_tensor):
-            output = self.impl_192.forward(layer,
-                                           query,
-                                           key,
-                                           value,
-                                           kv_cache,
-                                           metadata,
-                                           trace_flag=False)
+        output = self.impl_192.forward(layer,
+                                       query,
+                                       key,
+                                       value,
+                                       kv_cache,
+                                       metadata,
+                                       trace_flag=False)
 
         mock_vanilla_prefill.assert_called_once()
         assert output.shape == (10, 8 * 192)
diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
index b0e9f3b..9f2bd9b 100644
--- a/vllm_ascend/attention/attention_v1.py
+++ b/vllm_ascend/attention/attention_v1.py
@@ -396,8 +396,10 @@ class AscendAttentionBackendImpl(AttentionImpl):
             if self.head_size == 192:
                 cu_seqlen_q = [0] + attn_metadata.query_lens.tolist()
                 cu_seqlen_k = [0] + attn_metadata.seq_lens.tolist()
-                cu_seqlen_q = torch.tensor(cu_seqlen_q, device="npu")
-                cu_seqlen_k = torch.tensor(cu_seqlen_k, device="npu")
+                cu_seqlen_q = torch.tensor(cu_seqlen_q,
+                                           device=query.device)
+                cu_seqlen_k = torch.tensor(cu_seqlen_k,
+                                           device=query.device)
                 cu_seqlen_q = torch.cumsum(cu_seqlen_q, dim=0)
                 cu_seqlen_k = torch.cumsum(cu_seqlen_k, dim=0)
                 max_seqlen_q = torch.max(attn_metadata.query_lens)
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index a3db5fd..8212c36 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -233,7 +233,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
            self.spec_attn_mask = torch.triu(torch.ones(2048,
                                                        2048,
                                                        dtype=torch.bool),
-                                            diagonal=1).to("npu")
+                                            diagonal=1).to(self.device)
            if get_pp_group().is_last_rank:
                if self.speculative_config.method == "ngram":
                    self.drafter = NgramProposer(self.vllm_config)
@@ -1120,6 +1120,19 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         input_ids = self.input_ids[:padded_batch_size]
         positions = self.positions[:padded_batch_size]
 
+        if get_pp_group().is_first_rank:
+            intermediate_tensors = None
+        else:
+            assert intermediate_tensors is not None
+            assert self.intermediate_tensors is not None
+            for k, v in intermediate_tensors.items():
+                self.intermediate_tensors[k][:num_input_tokens].copy_(
+                    v[:num_input_tokens], non_blocking=True)
+            intermediate_tensors = IntermediateTensors({
+                k: v[:num_input_tokens]
+                for k, v in self.intermediate_tensors.items()
+            })
+
         # Run forward pass
         with set_forward_context(attn_metadata,
                                  self.vllm_config,