This project demonstrates how to implement pipeline parallelism for large language models using MLX. It includes tools for sharding a model, serving shards across multiple machines, and generating text using the distributed model. Additionally, it features an OpenAI API-compatible server for easier integration and usage.
To see the distributed inference in action, check out our demo video:
Sharding DeepSeek-Coder-V2-Lite-Instruct Demo
Install the package using pip:
pip install mlx-sharding
- For the shard node:

  mlx-sharding-server --model mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx --start-layer 14 --end-layer 27
- For the primary node:

  mlx-sharding-api --model mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx --start-layer 0 --end-layer 14 --llm-shard-addresses <your shard node address>

  Replace `<your shard node address>` with the actual address of your shard node (e.g., `localhost:50051`).
This repository is designed for educational purposes to illustrate how pipeline parallelism can be implemented in MLX. It provides a basic framework for:
- Sharding a large language model
- Distributing model shards across multiple machines
- Implementing a simple pipeline for text generation
- Serving the model through an OpenAI API-compatible interface
While not optimized for production use, this demo serves as a starting point for understanding and experimenting with pipeline parallelism in machine learning workflows.
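In concrete terms, pipeline parallelism here means each node holds a contiguous block of the model's layers, and only the intermediate hidden states are passed from one node to the next at each generation step. The snippet below is a minimal NumPy sketch of that data flow, not the project's actual gRPC code; the `Shard` class and the stand-in layers are invented purely for illustration.

```python
# Minimal illustration of the pipeline idea: two "shards" each own a slice of
# layers, and only hidden states flow between them (here via a local call; in
# this project the hand-off happens over gRPC between machines).
import numpy as np


class Shard:
    """Runs a contiguous slice of the model's layers in order."""

    def __init__(self, layers):
        self.layers = layers

    def forward(self, hidden):
        for layer in self.layers:
            hidden = layer(hidden)
        return hidden


rng = np.random.default_rng(0)
# Stand-in "layers": each is just a small random linear map on the hidden states.
all_layers = [lambda h, W=rng.standard_normal((8, 8)) * 0.1: h @ W for _ in range(27)]

shard_0 = Shard(all_layers[0:14])   # layers 0-13, e.g. the primary node
shard_1 = Shard(all_layers[14:27])  # layers 14-26, e.g. the remote shard node

hidden = rng.standard_normal((4, 8))   # fake embeddings for a 4-token prompt
hidden = shard_0.forward(hidden)       # runs locally
hidden = shard_1.forward(hidden)       # in the real system this hop is a gRPC call
print(hidden.shape)                    # final hidden states -> logits -> next token
```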
You have two main options for preparing and using the model:
Option A: If you prefer to pre-shard the model, use `sharding_weight.py`:
python sharding_weight.py --model "mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx" --output_dir shard_0 --start_layer 0 --end_layer 14 --total_layers 27
python sharding_weight.py --model "mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx" --output_dir shard_1 --start_layer 14 --end_layer 27 --total_layers 27
# Repeat for additional shards as needed
Option B: You can let the system dynamically load and shard the weights when starting the server. This option doesn't require pre-sharding.
If you've pre-sharded the model, copy the shard directories to their respective machines. Skip this step for Option B.
Start server instances based on your chosen approach:
For a pre-sharded model (Option A), start a server instance on each machine with a shard. For example:
python -m shard.main --model mzbac/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx-shard-1
For dynamic sharding (Option B), start the server with specific layer ranges:
python -m shard.main --model "mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx" --start-layer 0 --end-layer 14
Note the IP address and port printed by each server.
For a dynamically sharded setup:
python generate.py --model "mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx" --start_layer 0 --end_layer 14 --server_address <remote_ip1>:<port1>,<remote_ip2>:<port2> --prompt "Your prompt here" --max_tokens 512
For a pre-sharded setup:
python generate.py --model mzbac/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx-shard-0 --server_address <remote_ip1>:<port1>,<remote_ip2>:<port2> --prompt "Your prompt here" --max_tokens 512
- Start the server:

  For dynamic sharding:

  python -m shard.openai_api --model "mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx" --llm-shard-addresses localhost:50051,<remote_ip1>:<port1>,<remote_ip2>:<port2> --start-layer 0 --end-layer 14

  For a pre-sharded model:

  python -m shard.openai_api --model mzbac/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx-shard-0 --llm-shard-addresses localhost:50051,<remote_ip1>:<port1>,<remote_ip2>:<port2>
- Use the API endpoints:

  - `/v1/completions`: Text completion endpoint
  - `/v1/chat/completions`: Chat completion endpoint

  Example usage:

  curl localhost:8080/v1/chat/completions \
    -H "Content-Type: application/json" \
    -d '{
      "messages": [{"role": "user", "content": "Say this is a test!"}],
      "temperature": 0.7
    }'
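The same request can also be made from Python. Below is a sketch using the `requests` library; it assumes the server returns the standard OpenAI chat-completion response shape, which is what the endpoint is designed to mimic.

```python
# Python equivalent of the curl example above.
import requests

response = requests.post(
    "http://localhost:8080/v1/chat/completions",
    json={
        "messages": [{"role": "user", "content": "Say this is a test!"}],
        "temperature": 0.7,
    },
    timeout=300,  # distributed generation can take a while over the network
)
response.raise_for_status()

# Assumes the standard OpenAI-style response schema.
print(response.json()["choices"][0]["message"]["content"])
```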
This project now includes a web-based user interface for easy interaction with the model. To use the UI:
- Ensure the OpenAI API-compatible server is running (as described above).
- Navigate to `http://localhost:8080` (or the appropriate host and port if you've configured it differently) in your web browser.
- Use the interface to input prompts, adjust parameters, and view the model's responses.
The UI provides a user-friendly way to interact with the model, making it easier to experiment with different inputs and settings without needing to use command-line tools or write code.
- Network Dependency: The performance of this pipeline parallelism implementation depends heavily on the network speed and latency between machines.
- Error Handling: The current implementation has only basic error handling; a production deployment would need more robust error handling and recovery mechanisms.
- Security: This demo uses insecure gRPC channels. Implement proper security measures for any real-world application.
- Shard Configuration: When using multiple shards, make sure the layer ranges cover the entire model without gaps or overlap (see the sketch after this list).
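Judging from the examples above, the layer flags behave like half-open ranges: each shard's end layer equals the next shard's start layer, and the last end layer equals the model's total layer count. The helper below is a hypothetical sketch (not part of this package) that sanity-checks a shard plan against those rules.

```python
# Hypothetical helper (not part of mlx-sharding): checks that shard layer
# ranges are contiguous, non-overlapping, and cover every layer exactly once.
def validate_shard_plan(ranges, total_layers):
    """ranges: list of (start_layer, end_layer) tuples, one per shard."""
    ordered = sorted(ranges)
    if ordered[0][0] != 0:
        raise ValueError("first shard must start at layer 0")
    for (_, prev_end), (next_start, _) in zip(ordered, ordered[1:]):
        if next_start != prev_end:
            raise ValueError(f"gap or overlap between layers {prev_end} and {next_start}")
    if ordered[-1][1] != total_layers:
        raise ValueError("last shard must end at the total layer count")


# Matches the two-shard DeepSeek-Coder-V2-Lite example above (27 layers total).
validate_shard_plan([(0, 14), (14, 27)], total_layers=27)
```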
To extend the system for more shards:
- If pre-sharding, create additional shards using `sharding_weight.py`.
- Set up more server instances, one for each new shard.
- In `generate.py` or when using the OpenAI API server, include all shard addresses.
- Adjust the layer ranges accordingly when using dynamic sharding (see the sketch below).
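One simple way to pick layer ranges for more shards is to split the total layer count into contiguous, roughly equal blocks and feed each pair to `sharding_weight.py` or the server flags. The helper below is an illustrative sketch, not part of this repository.

```python
# Illustrative helper (not part of mlx-sharding): split total_layers into
# num_shards contiguous (start_layer, end_layer) ranges, as evenly as possible.
def plan_layer_ranges(total_layers, num_shards):
    base, extra = divmod(total_layers, num_shards)
    ranges, start = [], 0
    for i in range(num_shards):
        end = start + base + (1 if i < extra else 0)
        ranges.append((start, end))
        start = end
    return ranges


# Three shards over the 27-layer example model: [(0, 9), (9, 18), (18, 27)]
print(plan_layer_ranges(27, 3))
```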
- Python 3.x
- MLX library
- gRPC and related dependencies
- NumPy
- Transformers library
- Sufficient RAM on each machine to load and process its model shard
- MLX team for providing the framework
- Exo (https://github.com/exo-explore/exo), whose implementation of pipeline parallelism heavily inspired this project