diff --git a/docs/source/tutorials/hardwares/310p.md b/docs/source/tutorials/hardwares/310p.md
index 429474d7..c96d2c5e 100644
--- a/docs/source/tutorials/hardwares/310p.md
+++ b/docs/source/tutorials/hardwares/310p.md
@@ -1,21 +1,58 @@
-# Atlas 300I
+# Atlas 300I DUO
 
-```{note}
-1. This Atlas 300I series is currently experimental. In future versions, there may be behavioral changes related to model coverage and performance improvement.
-2. Currently, the Atlas 300I series only supports eager mode and the float16 data type.
+## Running vLLM on Atlas 300I DUO
+
+### Notes
+
+* The current release supports `FULL_DECODE_ONLY` graph mode on Atlas 300I DUO devices, but the following limitations apply due to hardware event-id resource constraints:
+
+  * When multiple Tensor Parallel (TP) ranks are enabled, the number of capturable graphs is limited and depends on the model depth. For example, Qwen3-32B can capture and replay 2 graphs.
+  * There is no such limitation when TP=1.
+  * We have reached out to the relevant experts for a solution. A software-based fix is considered feasible, but full support will take additional time. Thank you for your understanding.
+
+* Atlas 300I DUO does not support `triton` or `triton-ascend`.
+
+* If installing from source, `vllm` and `vllm-ascend` will automatically pull in `triton` and `triton-ascend` as dependencies, which may cause unexpected issues on Atlas 300I DUO. Please run:
+
+```bash
+pip uninstall -y triton triton-ascend
+# If you still encounter errors mentioning triton, manually remove the remaining triton directory in site-packages,
+# as uninstalling triton may leave residual files behind.
+# For example: rm -rf /usr/local/python3.11.10/lib/python3.11/site-packages/triton
 ```
 
-## Run vLLM on Atlas 300I Series
+### Deployment
 
-Run docker container:
+```{warning}
+For Atlas 300I DUO (310P), do not rely on automatic `max-model-len` detection
+(that is, do not omit the `--max-model-len` argument), or OOM may occur.
+ +Reason (current 310P attention path): +- `AscendAttentionMetadataBuilder310` passes `model_config.max_model_len` + to `AttentionMaskBuilder310`. +- `AttentionMaskBuilder310` builds a full float16 causal mask with shape + `[max_model_len, max_model_len]`, + and then converts it to FRACTAL_NZ format. +- In the 310P `attention_v1` prefill/chunked-prefill path + (`_npu_flash_attention` / `_npu_paged_attention_splitfuse`), + this explicit mask tensor is used directly, and there is currently + no compressed-mask path. + +If automatic parsing resolves to a large context length, allocating this mask +(`O(max_model_len^2)`) may exceed NPU memory and trigger OOM. +Be sure to set an explicit and conservative value, such as `--max-model-len 16384`. +``` + +Run the Docker container: ```{code-block} bash :substitutions: -# Update the vllm-ascend image -export IMAGE=quay.io/ascend/vllm-ascend:v0.10.0rc1-310p + +# Use the vllm-ascend image +export IMAGE=quay.io/ascend/vllm-ascend:v0.18.0rc1-310p docker run --rm \ --name vllm-ascend \ ---shm-size=1g \ +--shm-size=10g \ --device /dev/davinci0 \ --device /dev/davinci1 \ --device /dev/davinci2 \ @@ -37,282 +74,136 @@ docker run --rm \ -it $IMAGE bash ``` -Set up environment variables: +Run the following steps to start the vLLM service on NPU for the Qwen3 Dense series: -```bash -# Load model from ModelScope to speed up download -export VLLM_USE_MODELSCOPE=True +* Prepare the environment -# Set `max_split_size_mb` to reduce memory fragmentation and avoid out of memory -export PYTORCH_NPU_ALLOC_CONF=max_split_size_mb:256 -``` + * Obtain model weights + (`W8A8SC` weights will be uploaded to the Eco-Tech official ModelScope repository later.) -### Online Inference on NPU + * This guide requires `W8A8SC` quantized weights for the Qwen3 Dense `8B/14B/32B` models. You need to generate the SC-compressed weights yourself. 
+ * First, prepare the `W8A8S` weights: -```{warning} -For Atlas 300I (310P), do not rely on `max-model-len` auto detection -(omit `--max-model-len`), because it may cause OOM. + * Qwen3-8B-w8a8s-310: [https://modelers.cn/models/Eco-Tech/Qwen3-8B-w8a8s-310](https://modelers.cn/models/Eco-Tech/Qwen3-8B-w8a8s-310) + * Qwen3-14B-w8a8s-310: [https://modelers.cn/models/Eco-Tech/Qwen3-14B-w8a8s-310](https://modelers.cn/models/Eco-Tech/Qwen3-14B-w8a8s-310) + * Qwen3-32B-w8a8s-310: [https://modelers.cn/models/Eco-Tech/Qwen3-32B-w8a8s-310](https://modelers.cn/models/Eco-Tech/Qwen3-32B-w8a8s-310) -Reason (current 310P attention path): -- `AscendAttentionMetadataBuilder310` passes `model_config.max_model_len` - to `AttentionMaskBuilder310`. -- `AttentionMaskBuilder310` builds a full causal mask with shape - `[max_model_len, max_model_len]` in float16, then casts it to FRACTAL_NZ. -- In 310P `attention_v1` prefill/chunked-prefill - (`_npu_flash_attention` / `_npu_paged_attention_splitfuse`), - this explicit mask tensor is consumed directly, and there is no - compressed-mask path. + Note: if you want to validate directly with `w8a8s` weights instead of `w8a8sc` weights, the following example shows the serving command for `Qwen3-8B-w8a8s-310`. Performance is slightly lower than with compressed `w8a8sc` weights. Detailed `w8a8sc` testing is covered in the following sections. -So if auto resolves to a large context length, the mask allocation -(`O(max_model_len^2)`) can exceed NPU memory and trigger OOM. -Always set a conservative explicit value, for example `--max-model-len 4096`. 
-``` + ```bash + vllm serve Eco-Tech/Qwen3-8B-w8a8s-310 --host 127.0.0.1 --port 8080 \ + --tensor-parallel-size 1 --gpu_memory_utilization 0.90 \ + --served_model_name qwen --dtype float16 \ + --additional-config '{"ascend_compilation_config": {"fuse_norm_quant": false}}' \ + --compilation-config '{"cudagraph_mode": "FULL_DECODE_ONLY", "cudagraph_capture_sizes": [1,2,4,8,16,32]}' \ + --quantization ascend --max_model_len 16384 + # `--load_format` is required only for the W8A8SC quantized weight format. + # + ``` -Run the following script to start the vLLM server on NPU (Qwen3-0.6B:1 card, Qwen2.5-7B-Instruct:2 cards, Pangu-Pro-MoE-72B: 8 cards): + * Compress the weights -:::::{tab-set} -:sync-group: inference + * Uninstall triton (unsupported on 310P): -::::{tab-item} Qwen3-0.6B -:selected: -:sync: qwen0.6 + ```bash + pip uninstall triton + pip uninstall triton-ascend + ``` -Run the following command to start the vLLM server: + * Get the compression script: -```{code-block} bash - :substitutions: -vllm serve Qwen/Qwen3-0.6B \ - --tensor-parallel-size 1 \ - --max-model-len 4096 \ - --enforce-eager \ - --dtype float16 -``` + * [https://github.com/vllm-project/vllm-ascend/blob/main/examples/save_sharded_state_310.py](https://github.com/vllm-project/vllm-ascend/blob/main/examples/save_sharded_state_310.py) -Once your server is started, you can query the model with input prompts. 
+ * Install the compression tool -```bash -curl http://localhost:8000/v1/completions \ - -H "Content-Type: application/json" \ - -d '{ - "prompt": "The future of AI is", - "max_completion_tokens": 64, - "top_p": 0.95, - "top_k": 50, - "temperature": 0.6 - }' -``` + * Repository: [https://gitcode.com/Ascend/msit.git](https://gitcode.com/Ascend/msit.git) + * Installation guide: [https://gitcode.com/Ascend/msit/blob/master/msmodelslim/docs/安装指南.md#基于atlas-300i-duo-系列产品安装](https://gitcode.com/Ascend/msit/blob/master/msmodelslim/docs/安装指南.md#基于atlas-300i-duo-系列产品安装) -:::: + * Compression command -::::{tab-item} Qwen2.5-7B-Instruct -:sync: qwen7b + ```bash + export PYTORCH_NPU_ALLOC_CONF=max_split_size_mb:256 + export LD_LIBRARY_PATH=/usr/local/python3.11.10/lib/:$LD_LIBRARY_PATH -Run the following command to start the vLLM server: + python save_sharded_state_310.py \ + --model /your-load-path/w8a8s-weight \ + --tensor-parallel-size 1 \ + --output /your-save-path/w8a8sc-weight \ + --enable-compress \ + --compress-process-num 4 \ + --enforce-eager \ + --dtype float16 \ + --quantization ascend \ + --max_model_len 10240 + ``` -```{code-block} bash - :substitutions: -vllm serve Qwen/Qwen2.5-7B-Instruct \ - --tensor-parallel-size 2 \ - --max-model-len 4096 \ - --enforce-eager \ - --dtype float16 -``` + Argument notes: `--tensor-parallel-size`: `W8A8SC` quantized weights are tightly coupled to the TP size, so you must specify the TP size you plan to use at serving time when running compression. `--model` is the path to the input `w8a8s` weights, and `--output` is the output path for the compressed `w8a8sc` weights. -Once your server is started, you can query the model with input prompts. 
+ * Additional notes -```bash -curl http://localhost:8000/v1/completions \ - -H "Content-Type: application/json" \ - -d '{ - "prompt": "The future of AI is", - "max_completion_tokens": 64, - "top_p": 0.95, - "top_k": 50, - "temperature": 0.6 - }' -``` + * The Qwen3-8B model has fewer parameters, so some layers need fallback handling during quantization. It is recommended to download the `qwen3-8B-w8a8sc` weights directly from the Eco-Tech official ModelScope repository once available. -:::: +* Examples -::::{tab-item} Qwen2.5-VL-3B-Instruct -:sync: qwen-vl-2.5-3b + * Qwen3-8B-w8a8sc example -Run the following command to start the vLLM server: + ```bash + vllm serve /your-save-path/Qwen3-8B-w8a8sc-310-vllm/TP1/Qwen3-8B-w8a8sc-310-vllm-tp1/ \ + --host 127.0.0.1 \ + --port 8080 \ + --tensor-parallel-size 1 \ + --gpu_memory_utilization 0.90 \ + --max_num_seqs 32 \ + --served_model_name qwen \ + --dtype float16 \ + --additional-config '{"ascend_compilation_config": {"fuse_norm_quant": false}}' \ + --compilation-config '{"cudagraph_mode": "FULL_DECODE_ONLY", "cudagraph_capture_sizes": [1,2,4,8,16,32]}' \ + --quantization ascend \ + --max_model_len 16384 \ + --no-enable-prefix-caching \ + --load_format="sharded_state" + ``` -```{code-block} bash - :substitutions: -vllm serve Qwen/Qwen2.5-VL-3B-Instruct \ - --tensor-parallel-size 1 \ - --max-model-len 4096 \ - --enforce-eager \ - --dtype float16 -``` + * Qwen3-14B-w8a8sc example -Once your server is started, you can query the model with input prompts. 
+ ```bash + vllm serve /your-save-path/Qwen3-14B-w8a8sc-310-vllm/TP1/Qwen3-14B-w8a8sc-310-vllm-tp1/ \ + --host 127.0.0.1 \ + --port 8080 \ + --tensor-parallel-size 1 \ + --gpu_memory_utilization 0.90 \ + --max_num_seqs 16 \ + --served_model_name qwen \ + --dtype float16 \ + --additional-config '{"ascend_compilation_config": {"fuse_norm_quant": false}}' \ + --compilation-config '{"cudagraph_mode": "FULL_DECODE_ONLY", "cudagraph_capture_sizes": [1,2,4,8,16]}' \ + --quantization ascend \ + --max_model_len 16384 \ + --no-enable-prefix-caching \ + --load_format="sharded_state" + ``` -```bash -curl http://localhost:8000/v1/completions \ - -H "Content-Type: application/json" \ - -d '{ - "prompt": "The future of AI is", - "max_completion_tokens": 64, - "top_p": 0.95, - "top_k": 50, - "temperature": 0.6 - }' -``` + * Qwen3-32B-w8a8sc example -:::: -::::: + ```bash + export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3 -If you run this script successfully, you can see the results. + vllm serve /save-path/Qwen3-32B-w8a8sc-310-vllm/TP4/Qwen3-32B-w8a8sc-310-vllm-tp4/ \ + --host 127.0.0.1 \ + --port 8080 \ + --tensor-parallel-size 4 \ + --gpu_memory_utilization 0.90 \ + --max_num_seqs 32 \ + --served_model_name qwen \ + --dtype float16 \ + --additional-config '{"ascend_compilation_config": {"fuse_norm_quant": false}}' \ + --compilation-config '{"cudagraph_mode": "FULL_DECODE_ONLY", "cudagraph_capture_sizes": [16,32]}' \ + --quantization ascend \ + --max_model_len 20480 \ + --no-enable-prefix-caching \ + --load_format="sharded_state" + ``` -### Offline Inference +* Closing notes -Run the following script (`example.py`) to execute offline inference on NPU: - -:::::{tab-set} -:sync-group: inference - -::::{tab-item} Qwen3-0.6B -:selected: -:sync: qwen0.6 - -```{code-block} python - :substitutions: -import gc -import torch -from vllm import LLM, SamplingParams -from vllm.distributed.parallel_state import (destroy_distributed_environment, - destroy_model_parallel) - -def clean_up(): - 
destroy_model_parallel() - destroy_distributed_environment() - gc.collect() - torch.npu.empty_cache() -prompts = [ - "Hello, my name is", - "The future of AI is", -] -# Create a sampling params object. -sampling_params = SamplingParams(max_completion_tokens=100, temperature=0.0) -# Create an LLM. -llm = LLM( - model="Qwen/Qwen3-0.6B", - tensor_parallel_size=1, - max_model_len=4096, - enforce_eager=True, # For 300I series, only eager mode is supported. - dtype="float16", # IMPORTANT: Some ATB ops do not support bf16 on the 300I series. -) -# Generate texts from the prompts. -outputs = llm.generate(prompts, sampling_params) -for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") -del llm -clean_up() -``` - -:::: - -::::{tab-item} Qwen2.5-7B-Instruct -:sync: qwen7b - -```{code-block} python - :substitutions: -from vllm import LLM, SamplingParams -import gc -import torch -from vllm import LLM, SamplingParams -from vllm.distributed.parallel_state import (destroy_distributed_environment, - destroy_model_parallel) - -def clean_up(): - destroy_model_parallel() - destroy_distributed_environment() - gc.collect() - torch.npu.empty_cache() -prompts = [ - "Hello, my name is", - "The future of AI is", -] -# Create a sampling params object. -sampling_params = SamplingParams(max_completion_tokens=100, temperature=0.0) -# Create an LLM. -llm = LLM( - model="Qwen/Qwen2.5-7B-Instruct", - tensor_parallel_size=2, - max_model_len=4096, - enforce_eager=True, # For 300I series, only eager mode is supported. - dtype="float16", # IMPORTANT: Some ATB ops do not support bf16 on the 300I series. -) -# Generate texts from the prompts. 
-outputs = llm.generate(prompts, sampling_params) -for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") -del llm -clean_up() -``` - -:::: - -::::{tab-item} Qwen2.5-VL-3B-Instruct -:sync: qwen-vl-2.5-3b - -```{code-block} python - :substitutions: -from vllm import LLM, SamplingParams -import gc -import torch -from vllm import LLM, SamplingParams -from vllm.distributed.parallel_state import (destroy_distributed_environment, - destroy_model_parallel) - -def clean_up(): - destroy_model_parallel() - destroy_distributed_environment() - gc.collect() - torch.npu.empty_cache() -prompts = [ - "Hello, my name is", - "The future of AI is", -] -# Create a sampling params object. -sampling_params = SamplingParams(max_completion_tokens=100, top_p=0.95, top_k=50, temperature=0.6) -# Create an LLM. -llm = LLM( - model="Qwen/Qwen2.5-VL-3B-Instruct", - tensor_parallel_size=1, - max_model_len=4096, - enforce_eager=True, # For 300I series, only eager mode is supported. - dtype="float16", # IMPORTANT: Some ATB ops do not support bf16 on the 300I series. -) -# Generate texts from the prompts. -outputs = llm.generate(prompts, sampling_params) -for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") -del llm -clean_up() -``` - -:::: -::::: - -Run script: - -```bash -python example.py -``` - -If you run this script successfully, you can see the info shown below: - -```bash -Prompt: 'Hello, my name is', Generated text: " Lina. I'm a 22-year-old student from China. I'm interested in studying in the US. I'm looking for a job in the US. I want to know if there are any opportunities in the US for me to work. I'm also interested in the culture and lifestyle in the US. I want to know if there are any opportunities for me to work in the US. 
I'm also interested in the culture and lifestyle in the US. I'm interested in the culture" -Prompt: 'The future of AI is', Generated text: " not just about the technology itself, but about how we use it to solve real-world problems. As AI continues to evolve, it's important to consider the ethical implications of its use. AI has the potential to bring about significant changes in society, but it also has the power to create new challenges. Therefore, it's crucial to develop a comprehensive approach to AI that takes into account both the benefits and the risks associated with its use. This includes addressing issues such as bias, privacy, and accountability." -``` + For early access to Qwen3-MoE, Qwen3-VL, and preview support for Qwen3.5 and Qwen3.6 with performance acceleration, follow #7394 for updated deployment guidance.
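
The `O(max_model_len^2)` mask warning above can be sanity-checked with a quick back-of-envelope estimate. This is a minimal sketch (the helper `mask_bytes` is illustrative, not part of vllm-ascend); the 2-bytes-per-element figure comes from the float16 mask described in the warning:

```python
# Estimate the NPU memory consumed by the full [max_model_len, max_model_len]
# float16 causal mask that the 310P attention path allocates up front.

def mask_bytes(max_model_len: int) -> int:
    # float16 -> 2 bytes per element, mask is square in max_model_len
    return max_model_len * max_model_len * 2

for n in (4096, 16384, 131072):
    print(f"max_model_len={n:>6}: {mask_bytes(n) / 2**30:.3f} GiB")
# max_model_len=  4096: 0.031 GiB
# max_model_len= 16384: 0.500 GiB   (the conservative value suggested above)
# max_model_len=131072: 32.000 GiB  (an auto-detected long context -> OOM)
```

Because the cost is quadratic, doubling `max_model_len` quadruples the mask allocation, which is why an explicit, conservative `--max-model-len` is required on these devices.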