
Optimizer offloading through weight-only offload #867

Open
wants to merge 4 commits into main
Conversation

hanzhi713 (Member) commented:

This PR requires jax >= 0.4.34 and != 0.4.35: it works on jax 0.4.34, is broken on jax 0.4.35 due to a libtpu bug, and works on the nightly jax 0.4.36 as of 10/30.

This PR enables optimizer offloading. The approach used here is weight-only offloading, which is built on the same building blocks as activation offloading (a.k.a. remat offload). When offloading is enabled, optimizer states are stored in CPU pinned memory. Before the optimizer is applied to calculate updates, its states are moved from CPU memory to HBM via jax.device_put; the new optimizer states are then moved back from HBM to CPU.
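For illustration, here is a minimal sketch of this round-trip pattern (not the PR's exact code), assuming an optax-style GradientTransformation; the parameter names mirror the snippet quoted later in this review, and the import location of TransferToMemoryKind may vary across jax versions:

import jax
from jax._src.sharding_impls import TransferToMemoryKind  # Import path may differ across jax versions.

def offloaded_update(updates, state, params, *, optimizer,
                     offload_src="device", offload_dst="pinned_host"):
    # Between steps, optimizer states live in `offload_dst` (CPU pinned memory).
    # Bring them back to `offload_src` (HBM) so the optimizer math runs on device.
    state = jax.device_put(state, TransferToMemoryKind(offload_src))
    updates, state = optimizer.update(updates, state, params)
    # Move the new states back to CPU pinned memory, donating the HBM buffers.
    state = jax.device_put(state, TransferToMemoryKind(offload_dst), donate=True)
    return updates, state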

An alternative approach is host computation, in which the optimizer transformations are computed on the CPU: before the computation starts, gradients and weights are transferred to CPU, and after it finishes, their new values are transferred back to HBM. This method has a lower HBM footprint, but it is 2x~3x slower due to slow CPU computation, and it is currently quite buggy.
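By contrast, a rough sketch of the host-computation alternative (which this PR does not take) could look like the following, assuming jax.experimental.compute_on is available in the pinned jax version; the annotation form, the set of ops supported on host, and the optax.adamw stand-in are all illustrative:

import jax
import optax
from jax.experimental.compute_on import compute_on  # Availability/location may vary across jax versions.

optimizer = optax.adamw(1e-3)  # Illustrative stand-in for the real optimizer.

@jax.jit
def host_update(grads, opt_state, params):
    # Annotate the optimizer math to run on the host; XLA inserts the device<->host
    # transfers for gradients, weights, and optimizer states around this region.
    with compute_on("device_host"):
        updates, new_opt_state = optimizer.update(grads, opt_state, params)
    return updates, new_opt_state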

TLDR: to be merged after upgrading jax to 0.4.36.

ruomingp (Contributor) left a comment:

Nice!

axlearn/common/optimizer_base.py (resolved)
axlearn/common/optimizers.py (outdated, resolved)
axlearn/common/optimizers.py (outdated, resolved)
Comment on lines 2048 to 2050
Only wrap the optimizer that you actually want to offload with this function to avoid
unnecessary overhead. This is usually the optimizer that occupies the most HBM. For example,
when you have chained optimizers:
ruomingp (Contributor):

Where does the overhead come from? Is it from the states of clip_by_global_norm being offloaded? If so, could we use regular expressions to specify which states to offload?

Comment on lines 2092 to 2094
state = jax.device_put(state, TransferToMemoryKind(offload_src))
updates, state = optimizer.update(updates, state, params)
state = jax.device_put(state, TransferToMemoryKind(offload_dst), donate=True)
ruomingp (Contributor):

Do we need explicit device_put calls here? Is it enough to specify the partition spec with the right memory_kind?

hanzhi713 (Member Author):

I haven't tested with_sharding_constraint, if that's what you are referring to. However, specifying the full sharding requires us to store the sharding when the partition fn is called, which John prefers to avoid (see internal comments).

ruomingp (Contributor):

I'm not sure whether we need with_sharding_constraint, since jit/pjit should already apply shardings returned by partition_fn to the states. What happens without these device_put calls?

If necessary, we can always invoke partition_fn here to compute the sharding on the fly (instead of storing them) and apply with_sharding_constraint.

axlearn/common/optimizers.py (outdated, resolved)
hanzhi713 requested a review from ruomingp on December 4, 2024, 21:11.
ruomingp (Contributor) left a comment:

A question about device_put...

axlearn/common/optimizers.py (outdated, resolved)
axlearn/common/optimizers.py (resolved)
axlearn/common/optimizers.py (outdated, resolved)

hanzhi713 (Member Author), replying to "A question about device_put...":

Before the optimizer can be invoked, the offloaded optimizer states need to be transferred to the device memory space. If we remove these device_put calls, we get errors like "xxx is not supported on pinned_host memory space", where xxx is some XLA primitive operation such as add (I forget the exact error message, but it is something like this).

ruomingp (Contributor) left a comment:

Thanks for the clarification on the device_put calls. Could you add a comment on why they are necessary? Also, two suggestions...

Comment on lines 143 to 148
def copy_partition(
param_specs: Nested[ParameterSpec],
*,
pattern: Union[None, str, re.Pattern] = None,
memory_kind: Optional[MemoryKind] = None,
) -> Nested[OptStateSpec]:
ruomingp (Contributor):

Instead of coupling creation of OptStateSpec and setting of memory_kind, how about having a separate function for setting memory kind?

def set_memory_kind(opt_state_spec: Nested[OptStateSpec], *, pattern, memory_kind):

This allows set_memory_kind to be called multiple times, possibly for different memory kinds. WDYT?

hanzhi713 (Member Author):

I do not see how set_memory_kind would be different from copy_partition; the signature and implementation would be the same.

ruomingp (Contributor):

Imagine in the future we have many types of memory kinds, e.g., "remote_host". Then we can do:

opt_state_specs = copy_partition(...)
opt_state_specs = set_memory_kind(..., "pinned_host")
opt_state_specs = set_memory_kind(..., "remote_host")

hanzhi713 (Member Author), Dec 5, 2024:

Wouldn't it be the same as

opt_state_specs = copy_partition(...)
opt_state_specs = copy_partition(..., "pinned_host")
opt_state_specs = copy_partition(..., "remote_host")

Do you mean that using a separate function is slightly better for readability?

ruomingp (Contributor):

The difference is that copy_partition also performs the type conversion from Nested[ParameterSpec] to Nested[OptStateSpec].

hanzhi713 (Member Author):

I see. I can change the type of param_specs in copy_partition to Nested[OptStateSpec] since ParameterSpec is a subclass of OptStateSpec and copy_partition doesn't use any new fields from ParameterSpec. Does this sound good?

ruomingp (Contributor):

Thanks. SG.

hanzhi713 (Member Author):

Done.
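For reference, a rough sketch of what copy_partition could look like after this change (illustrative only, not the merged code; it assumes OptStateSpec is a dataclass with a memory_kind field and uses jax.tree_util path strings, whose format differs from axlearn's own path helpers):

import dataclasses
import re
from typing import Optional, Union

import jax

def copy_partition(specs, *, pattern: Union[None, str, re.Pattern] = None,
                   memory_kind: Optional[str] = None):
    # Returns `specs` with `memory_kind` replaced on every spec whose tree path matches `pattern`.
    def rewrite(key_path, spec):
        path = jax.tree_util.keystr(key_path)
        if pattern is not None and re.fullmatch(pattern, path):
            # Assumes the spec type is a dataclass with a `memory_kind` field.
            return dataclasses.replace(spec, memory_kind=memory_kind)
        return spec
    return jax.tree_util.tree_map_with_path(rewrite, specs)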

axlearn/common/optimizers.py (resolved)
hanzhi713 requested a review from ruomingp on December 5, 2024, 19:16.
axlearn/common/optimizers.py (resolved)
pattern: Regex to match the full path of each spec. Matched specs will have their memory
kind replaced with `memory_kind`.
memory_kind: New memory kind. Defaults to None.
Returns:
ruomingp (Contributor) suggested a change to the `Returns:` line; the rendered before/after lines are identical, so only formatting appears to change.
