Scalax is a collection of utilities that help developers easily scale up JAX-based machine learning models. The main idea behind scalax is simple: users write model and training code for a single GPU/TPU, and rely on scalax to automatically scale it up to hundreds of GPUs/TPUs. This is made possible by the JAX JIT compiler, and scalax provides a set of utilities to help users obtain the sharding annotations required by the JIT compiler. Because scalax wraps around the JIT compiler, existing JAX code can be scaled up with minimal changes.
Scalax came out of our experience building EasyLM, a scalable language model training library built on top of JAX.
Scalax is available on PyPI and can be installed using pip:

```bash
pip install scalax
```
More details about scalax can be found on the documentation page.
We run an unofficial Discord community (unaffiliated with Google) for discussions related to training large models in JAX. Follow this link to join the Discord server. We have a dedicated channel for scalax.
We provide a set of well-annotated examples in the examples directory. The most notable ones include:
- MLP with Data Parallelism
- LLaMA with fully sharded data parallelism
- LLaMA with combined fully sharded data parallelism and tensor parallelism
Suppose we have a simple Flax model and a train step function:
```python
class Model(nn.Module):
    ...

def train_step(train_state, batch):
    ...
    return updated_train_state, metrics
```
Typically, we would use `jax.jit` to compile the `train_step` function in order to accelerate the training:
```python
@jax.jit
def train_step(train_state, batch):
    ...
    return updated_train_state, metrics
```
This works fine for a single GPU/TPU, but if we want to scale up to multiple GPUs/TPUs, we need to partition the data or the model in order to parallelize the training across devices. Fortunately, JAX JIT already provides a way to handle these partitions with sharding annotations. For example, if we have sharding annotations for the `train_state` and `batch` pytrees, we can simply JIT compile the `train_step` function with these sharding annotations:
```python
from functools import partial

import jax

@partial(
    jax.jit,
    in_shardings=(train_state_shardings, batch_sharding),  # Shard the train_state and the batch
    out_shardings=(train_state_shardings, None),
)
def train_step(train_state, batch):
    ...
    return updated_train_state, metrics
```
The `train_state_shardings` and `batch_sharding` are pytrees with the same structure as the `train_state` and `batch` pytrees, but with `jax.sharding.Sharding` objects at the leaf nodes. These sharding objects are tied to the physical device mesh and are often difficult to construct, especially for complex models and training code. This is where scalax comes in. Scalax provides a set of utilities to help users automatically obtain these sharding annotations, without having to worry about the underlying pytree structure. Scalax does this by abstracting away the concrete sharding objects and using a `ShardingRule` object instead. A `ShardingRule` object can generate the sharding annotations for any given pytree according to its internal rules.
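
To see why constructing these objects by hand quickly becomes painful, here is a minimal sketch of what the plain-JAX version looks like, without scalax. The `train_state` layout below is a made-up example with just a few leaves; a real optimizer state has many more, each of which would need its own `NamedSharding`.

```python
# A minimal sketch of hand-written sharding annotations in plain JAX (no scalax).
# The train_state layout here is hypothetical.
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.array(jax.devices()), axis_names=('fsdp',))  # 1D mesh over all devices

# One Sharding object per leaf, mirroring the train_state pytree structure.
train_state_shardings = {
    'step': NamedSharding(mesh, PartitionSpec()),  # replicated scalar
    'params': {
        'dense': {
            'kernel': NamedSharding(mesh, PartitionSpec('fsdp', None)),  # shard rows
            'bias': NamedSharding(mesh, PartitionSpec()),                # replicate
        },
    },
}
# Shard the leading (batch) dimension of every array in the batch pytree.
batch_sharding = NamedSharding(mesh, PartitionSpec('fsdp'))
```

A `ShardingRule` lets scalax generate this kind of sharding pytree for you at JIT time, instead of you writing it out leaf by leaf.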
For example, scalax provides an `FSDPShardingRule` object, which can automatically generate sharding annotations for a pytree according to the Fully Sharded Data Parallelism (FSDP) strategy. To apply it to our `train_step` function, we can simply replace the `jax.jit` decorator:
```python
from functools import partial

from scalax.sharding import MeshShardingHelper, PartitionSpec, FSDPShardingRule

mesh = MeshShardingHelper([-1], ['fsdp'])  # Create a 1D mesh with an fsdp axis

@partial(
    mesh.sjit,
    in_shardings=(FSDPShardingRule(), None),  # Shard the train_state using FSDP
    out_shardings=(FSDPShardingRule(), None),
    args_sharding_constraint=(FSDPShardingRule(), PartitionSpec('fsdp')),
)
def train_step(train_state, batch):
    ...
    return updated_train_state, metrics
```
In the previous example, we see that scalax provides a `MeshShardingHelper` object built on a 1D device mesh with an `fsdp` axis. We then use the `sjit` method to compile the `train_step` function with the FSDP sharding rule, without having to worry about the specific underlying pytree structure of the `train_state`. Beyond FSDP, scalax also provides `TreePathShardingRule` and `PolicyShardingRule`, which allow users to easily combine different sharding strategies, such as replicated data parallelism, FSDP, tensor parallelism and sequence parallelism, to best fit their model and training setup. All of this can be done with minimal changes to the original model and training code, which makes it easy to integrate scalax into existing JAX codebases.
Scalax currently supports the following sharding rules:

- `FSDPShardingRule`: a sharding rule that automatically selects an axis for Fully Sharded Data Parallelism (FSDP).
- `TreePathShardingRule`: a regular expression based sharding rule that shards a pytree according to the paths of its leaves.
- `PolicyShardingRule`: a sharding rule that determines the sharding according to a user-defined callable policy.
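
As a rough illustration of how these rules can be combined, the sketch below sets up a hypothetical 2D mesh and uses a `TreePathShardingRule` to mix FSDP with tensor parallelism. The mesh shape, the path regexes, and the `PartitionSpec` layouts are all illustrative assumptions rather than a recipe; the rule list has to match your model's actual parameter paths. The LLaMA examples in the examples directory contain a complete set of such rules for a real transformer.

```python
# Illustrative sketch only: the mesh shape and parameter-path regexes below are
# hypothetical and must be adapted to your model's actual pytree paths.
from functools import partial

from scalax.sharding import MeshShardingHelper, PartitionSpec, TreePathShardingRule

# 2D mesh: infer the fsdp axis size from the device count, fix tensor parallelism to 4.
mesh = MeshShardingHelper([-1, 4], ['fsdp', 'tp'])

train_state_sharding_rule = TreePathShardingRule(
    ('.*/kernel', PartitionSpec('fsdp', 'tp')),  # shard weight matrices along both axes
    ('.*/bias', PartitionSpec()),                # replicate small vectors
    ('.*', PartitionSpec()),                     # catch-all: replicate everything else
)

@partial(
    mesh.sjit,
    in_shardings=(train_state_sharding_rule, None),
    out_shardings=(train_state_sharding_rule, None),
    args_sharding_constraint=(train_state_sharding_rule, PartitionSpec('fsdp')),
)
def train_step(train_state, batch):
    ...
    return updated_train_state, metrics
```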
In the previous example, we saw that scalax's `sjit` can help us easily shard the inputs and outputs of a jitted function. In many cases this is sufficient to scale up the training, as the intermediate tensors are automatically sharded by XLA. However, in some cases XLA might not be able to derive the optimal sharding for the intermediate tensors, and we might want to specify their sharding manually. Similar to JAX, scalax provides a `with_sharding_constraint` function for this purpose. Like `sjit`, `with_sharding_constraint` accepts both `ShardingRule` and `PartitionSpec` objects.
```python
from functools import partial

from scalax.sharding import MeshShardingHelper, PartitionSpec, FSDPShardingRule
from scalax.sharding import with_sharding_constraint

mesh = MeshShardingHelper([-1], ['fsdp'])  # Create a 1D mesh with an fsdp axis

@partial(
    mesh.sjit,
    in_shardings=(FSDPShardingRule(), None),  # Shard the train_state using FSDP
    out_shardings=(FSDPShardingRule(), None),
    args_sharding_constraint=(FSDPShardingRule(), PartitionSpec('fsdp')),
)
def train_step(train_state, batch):
    ...
    intermediate_pytree = ...
    intermediate_pytree = with_sharding_constraint(
        intermediate_pytree, FSDPShardingRule(),
    )
    ...
    return updated_train_state, metrics
```
In the previous example, we apply the `FSDPShardingRule` to the `intermediate_pytree`. However, this way of sharding intermediate tensors is intrusive to the original training code. To make it easier, scalax provides a `with_sharding_annotation` function, which only registers a name for the sharding within the training code, without tying it to a concrete sharding rule. This allows the same model and training code to be sharded differently without changing the code. For example:
```python
from functools import partial

from scalax.sharding import MeshShardingHelper, PartitionSpec, FSDPShardingRule
from scalax.sharding import with_sharding_annotation

mesh = MeshShardingHelper([-1], ['fsdp'])  # Create a 1D mesh with an fsdp axis

@partial(
    mesh.sjit,
    in_shardings=(FSDPShardingRule(), None),  # Shard the train_state using FSDP
    out_shardings=(FSDPShardingRule(), None),
    args_sharding_constraint=(FSDPShardingRule(), PartitionSpec('fsdp')),
    annotation_shardings={
        'weights': FSDPShardingRule(),
        'activations': PartitionSpec('fsdp'),
    },
)
def train_step(train_state, batch):
    ...
    weights_pytree = with_sharding_annotation(
        weights_pytree, 'weights',
    )
    activations = with_sharding_annotation(
        activations, 'activations',
    )
    ...
    return updated_train_state, metrics
```
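
To make the last point concrete, here is a rough sketch of how the same `train_step` body could be re-jitted against a larger mesh purely by swapping the `annotation_shardings` mapping. Everything specific below is an assumption for illustration: the 2D mesh shape, the premise that `FSDPShardingRule` still selects the `fsdp` axis on a multi-axis mesh, the (batch, hidden) activation layout implied by the `PartitionSpec`, and the fact that `train_step` here refers to the plain, undecorated function.

```python
# Hypothetical re-sharding of the same training code: only the mesh and the
# annotation-to-sharding mapping change; the train_step body stays untouched.
mesh = MeshShardingHelper([-1, 4], ['fsdp', 'tp'])  # 2D mesh: fsdp x tensor parallelism

sharded_train_step = mesh.sjit(
    train_step,  # the undecorated function from above
    in_shardings=(FSDPShardingRule(), None),
    out_shardings=(FSDPShardingRule(), None),
    args_sharding_constraint=(FSDPShardingRule(), PartitionSpec('fsdp')),
    annotation_shardings={
        'weights': FSDPShardingRule(),
        'activations': PartitionSpec('fsdp', 'tp'),  # assumes (batch, hidden) activations
    },
)
```

The model code only refers to the annotation names `'weights'` and `'activations'`, which is what allows the sharding strategy to be swapped without touching it.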