diff --git a/docs/source/user_guide/feature_guide/layer_sharding.md b/docs/source/user_guide/feature_guide/layer_sharding.md
index 3d7bc160..10844bbb 100644
--- a/docs/source/user_guide/feature_guide/layer_sharding.md
+++ b/docs/source/user_guide/feature_guide/layer_sharding.md
@@ -1,8 +1,6 @@
----
-title: Layer Sharding Guide
----
+# Layer Shard Linear Guide
 
-# Overview
+## Overview
 
 **Layer Shard Linear** is a memory-optimization feature designed for large language model (LLM) inference. It addresses the high memory pressure caused by **repeated linear operators across many layers** that share identical structure but have distinct weights.
 
@@ -19,14 +17,14 @@ This approach **preserves exact computational semantics** while **significantly
 
 ---
 
-## Flowchart
+### Flowchart
 
 ![layer shard](./images/layer_sharding.png)
 
 > **Figure.** Layer Shard Linear workflow: weights are sharded by layer across devices (top), and during forward execution (bottom), asynchronous broadcast pre-fetches the next layer's weight while the current layer computes—enabling zero-overhead weight loading.
 
 ---
 
-# Getting Started
+## Getting Started
 
 To enable **Layer Shard Linear**, specify the target linear layers using the `--additional-config` argument when launching your inference job. For example, to shard the `o_proj` and `q_b_proj` layers, use:
@@ -38,11 +36,11 @@ To enable **Layer Shard Linear**, specify the target linear layers using the `--
 
 ---
 
-# Supported Scenarios
+## Supported Scenarios
 
 This feature can be enabled in any scenario, but delivers the greatest benefit in the following cases:
 
-## FlashComm2-enabled
+### FlashComm2-enabled
 
 When using [FlashComm2](https://github.com/vllm-project/vllm-ascend/pull/4188), the full output projection (`o_proj`) matrix must be resident in memory for each layer. Layer sharding significantly reduces memory pressure by distributing these weights across devices.
 
@@ -57,7 +55,7 @@ vllm serve \
 }'
 ```
 
-## DSA-CP-enabled
+### DSA-CP-enabled
 
 With [DSA-CP](https://github.com/vllm-project/vllm-ascend/pull/4702), both `q_b_proj` and `o_proj` layers require large weight matrices to be stored per layer. Sharding these layers across NPUs helps fit extremely deep models (e.g., 61-layer architectures) into limited device memory.
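
The `vllm serve` command that the Getting Started hunk points to sits between hunks, so the diff above never shows the JSON actually passed to `--additional-config`. Below is a minimal sketch of the shape such a command could take, not the guide's verbatim example: the `layer_shard_linear` key and its list-of-names schema are assumptions for illustration, since only the `--additional-config` flag and the `o_proj`/`q_b_proj` layer names are visible in the diff.

```bash
# Minimal sketch, not the guide's verbatim command: enable Layer Shard Linear
# for o_proj and q_b_proj. The "layer_shard_linear" key is a hypothetical name;
# the real JSON schema lives in the code block elided between the hunks above.
vllm serve /path/to/model \
  --tensor-parallel-size 8 \
  --additional-config '{
    "layer_shard_linear": ["o_proj", "q_b_proj"]
  }'
```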
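
Likewise, the FlashComm2 example survives only as trailing context (`}'` and a closing fence) in the last hunk. A hedged sketch of that section's command, sharding only `o_proj` as its prose describes; whatever switch actually turns FlashComm2 on is omitted here because the diff does not show it:

```bash
# Hypothetical sketch for the FlashComm2 scenario: shard only o_proj, whose
# full projection matrix would otherwise stay resident on every device.
# Combine with the (unshown) FlashComm2 enablement switch from the guide.
vllm serve /path/to/model \
  --tensor-parallel-size 8 \
  --additional-config '{
    "layer_shard_linear": ["o_proj"]
  }'
```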
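
To make the DSA-CP memory claim concrete with illustrative numbers (the 8-NPU count is an assumption; only the 61-layer depth appears in the guide): sharding `q_b_proj` and `o_proj` by layer across 8 NPUs leaves each device permanently holding roughly 61/8 ≈ 8 layers' worth of those weights instead of all 61, with the remaining layers' weights arriving just in time via the asynchronous broadcast shown in the flowchart.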