Conversation

@mj023 (Collaborator) commented Aug 28, 2025

This PR will give users the option to run PyLCM with multiple devices (GPUs or other accelerators). To accomplish this, the state space needs to be split up between the devices. This can be done by sharding one of the arrays describing a state-space dimension and distributing the shards among the devices. Executing the computation on the right device is then handled by JAX.

One limitation is that the length of the array for the sharded state dimension needs to be a multiple of the number of available devices. There are two ways to handle this:

  1. Inform the user when no dimension is divisible without remainder.
  2. Pad the largest state dimension so that the array length is divisible without remainder.

I decided to implement solution 1, because we would need to add a lot of infrastructure to work with padded arrays, and the user can easily change the number of grid points to fulfill the requirement.
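For illustration, here is a minimal sketch of what this could look like with JAX's sharding API. The helper name and the example grid are made up for this comment; only the `jax.sharding` calls are real API.

```python
# Hypothetical helper: shard one state-dimension array across all available
# devices and fail early with an informative message when the number of grid
# points is not a multiple of the device count (solution 1 above).
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def shard_state_dimension(grid_values):
    n_devices = jax.device_count()
    if len(grid_values) % n_devices != 0:
        raise ValueError(
            f"The sharded state dimension has {len(grid_values)} grid points, "
            f"which is not divisible by the {n_devices} available devices. "
            "Please choose n_points as a multiple of the device count."
        )
    mesh = Mesh(np.array(jax.devices()), axis_names=("states",))
    # Each device receives one contiguous shard; JAX then runs computations
    # that consume this array on the device holding the relevant shard.
    return jax.device_put(grid_values, NamedSharding(mesh, P("states")))


wealth_grid = jnp.linspace(0.0, 10.0, 8)  # illustrative grid
wealth_grid_sharded = shard_state_dimension(wealth_grid)
```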

@hmgaudecker (Member) commented

Thanks!!!

I can't imagine many cases where the user would not have control over the number of GPUs. So I would just leave it as a configuration option which state (if any, see below) to break up. Doing this automatically will be fairly complicated once we move to model blocks, where not all states might be available in all periods, etc.

Why "if any"?

In the old days on CPU clusters, the communication overhead quickly became dominant. During the solution, the entire value function array for period $t+1$ has to be available on all devices in order to calculate the value function in period $t$.
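(As a rough illustration of this constraint, with made-up array names and only the `jax.sharding` calls being real API: the next-period value function is replicated on every device, while the period-$t$ states stay sharded.)

```python
# Illustration only: replicate V_{t+1} on all devices, keep the period-t state
# grid sharded across them.
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), axis_names=("states",))

# An empty PartitionSpec means "replicate on every device".
v_next = jax.device_put(jnp.zeros(8), NamedSharding(mesh, P()))

# The states solved for in period t are split across devices.
states_t = jax.device_put(jnp.linspace(0.0, 10.0, 8), NamedSharding(mesh, P("states")))
```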

Note that in almost any modern application, one would have a discrete set of preference types (probably 2-5). This is the most natural point to split across devices, and I would go beyond it only once the need arises. Or did that already happen for you?

@mj023 (Collaborator, Author) commented Aug 29, 2025

Thanks!!!

I can't imagine many cases where the user would not have control over the number of GPUs. So I would just leave it as a configuration option which state (if any, see below) to break up. Doing this automatically will be fairly complicated once we move to model blocks, where not all states might be available in all periods, etc.

I think you are right; maybe it would be best to specify this in the model config, for example with a new parameter in the Grid constructor:

LinspaceGrid(start=0, stop=10, n_points=10, distributed=True) 

It's probably more for advanced users, who can then come up with the best way to partition the state space depending on their model.
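A sketch of what such a parameter could look like on the grid spec (purely hypothetical; the real `LinspaceGrid` in lcm may be structured differently):

```python
# Hypothetical grid spec carrying the proposed flag; not the existing
# lcm.LinspaceGrid implementation.
from dataclasses import dataclass


@dataclass(frozen=True)
class LinspaceGrid:
    start: float
    stop: float
    n_points: int
    distributed: bool = False  # if True, shard this dimension across devices


wealth = LinspaceGrid(start=0, stop=10, n_points=10, distributed=True)
```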

Why "if any"?

In the old days on CPU clusters, the communication overhead quickly became dominant. During the solution, the entire value function array for period $t+1$ has to be available on all devices in order to calculate the value function in period $t$.

I would hope that this is not a problem, at least for multiple GPUs. If they are on one node of an HPC cluster, the connection between them is much faster than between nodes in CPU clusters, and JAX would hopefully overlap communication and computation. The problem we have is very similar to the one encountered when training neural networks, so JAX should be somewhat optimized for it.

You are right that fixed agent types would be the perfect dimension to split; then no communication is needed. If most models have these now, that is great. The main use case I thought of was more about speeding up computation when one has multiple GPUs on one HPC cluster node, so they aren't just idle.
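To make that concrete (hypothetical shapes and names; only the `jax.sharding` calls are real API): if the discrete type dimension is the one that is sharded, each device holds the full state space for its own types and never needs data from the other devices.

```python
# Hypothetical sketch: split the value array along a leading preference-type
# axis so every device solves its own types independently.
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), axis_names=("types",))

n_types = jax.device_count()  # e.g. one preference type per device
n_wealth = 10

# Shape (types, wealth): only the leading type axis is partitioned.
v = jax.device_put(
    jnp.zeros((n_types, n_wealth)),
    NamedSharding(mesh, P("types", None)),
)
```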

@hmgaudecker (Member) commented

Great ideas and thanks for illuminating me! Sounds very sensible to implement it eventually.

I'd still suggest waiting for the model blocks, keeping this in mind when designing their interface. What do you think?

@mj023 (Collaborator, Author) commented Sep 1, 2025

Yes, waiting until we know what the model blocks will look like is probably better. I think the changes needed for this will still be small then.

@timmens added the on hold label on Sep 2, 2025