supporting pytorch/jax optimizer #64
Replies: 9 comments
-
Hi, thanks for opening this issue.

On using PyTorch or JAX

In theory, the package could be updated to use PyTorch or JAX optimizers. However, this would likely require significant changes to the structure of the optimization code. For instance, I expect the current NumPy-based computations would need to be adapted to work with torch/JAX tensors so that automatic differentiation can propagate gradients through the ray trace. While this is possible, it's non-trivial and would require a closer look at the code to assess the exact modifications needed. It might also be feasible to refactor only specific parts of the optimization framework, depending on the functionality you're aiming to use. Another potential approach would be to make the backend configurable (e.g., allowing users to choose between NumPy, torch, or JAX).

On using RL

There are many options here:
I can elaborate on these if you want more info on what the code might look like. I plan to include a simple tutorial demonstrating option 1 in the learning guide in the coming months.

Happy to discuss further if you have any other questions.

Regards,

[1] Tong Yang, Dewen Cheng, and Yongtian Wang, "Designing freeform imaging systems based on reinforcement learning," Opt. Express 28, 30309-30323 (2020).
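One way to picture the configurable-backend idea is a small dispatch layer: math helpers call functions through a module handle instead of importing NumPy directly, so the same code can run on NumPy arrays or on torch tensors (where autograd can then track gradients). The sketch below is illustrative only; `set_backend`, `_backend`, and `even_asphere_sag` are hypothetical names, not Optiland's API, and the coefficient indexing (`coefficients[i]` multiplies r^(2(i+1))) is an assumed convention.

```python
import importlib

import numpy as np

# Hypothetical module-level handle to the active array library.
_backend = np


def set_backend(name: str) -> None:
    """Switch the array library used by the math helpers (hypothetical API)."""
    global _backend
    if name not in ("numpy", "torch"):
        raise ValueError(f"unsupported backend: {name!r}")
    _backend = importlib.import_module(name)


def even_asphere_sag(r, c, k, coefficients):
    """Sag of an even asphere at radial height r.

    c is the curvature (1 / radius), k the conic constant, and coefficients[i]
    multiplies r**(2 * (i + 1)) (assumed indexing convention).
    """
    be = _backend
    r2 = r * r
    # Conic base term; be.sqrt resolves to numpy.sqrt or torch.sqrt.
    z = c * r2 / (1 + be.sqrt(1 - (1 + k) * c * c * r2))
    # Polynomial aspheric terms in even powers of r.
    for i, a in enumerate(coefficients):
        z = z + a * r2 ** (i + 1)
    return z
```

With torch installed, calling `set_backend("torch")` and passing `r` or `c` as a tensor created with `requires_grad=True` would let `loss.backward()` propagate gradients to that parameter, which is what a torch optimizer such as `torch.optim.Adam` needs. Under the torch backend, at least one of `r` or `c` must be a tensor, since `torch.sqrt` does not accept plain Python floats.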
-
Thanks for the detailed explanations. Option 1 is exactly what I'm planning to pursue, and I have been following the recommended paper. If there is any template or rough tutorial for this, it would be really helpful. Alternatively, any guidance on what the code might look like, or on how to encapsulate the Optiland functions in RL environments, would be great too. Thanks!
-
I hope to write the tutorial for this topic by January. I'll start with something quite simple as a proof of concept, perhaps an aspheric singlet. If all goes well, I'll move on to more complex designs. The following is my thought process for the problem so far.

Overall goal

Approximate outline for generating this:

Outline for a single training step (first thoughts)
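A single training step generally follows this pattern: bound the agent's action, apply it as an increment to the design variables, re-evaluate the merit function, and return a reward. The toy sketch below shows that pattern using NumPy only. A real environment would call the ray tracer (e.g., an RMS spot size) where this sketch uses `surrogate_rms`, a hypothetical distance-to-`TARGET` stand-in. The names, `step_scale`, `tol`, and the choice of "improvement in merit" as the reward are all assumptions for illustration.

```python
import numpy as np

# Hypothetical stand-in for the optical merit function. In a real environment
# this would be the RMS spot radius computed by the ray tracer.
TARGET = np.array([0.5, -0.2, 0.1])


def surrogate_rms(params: np.ndarray) -> float:
    """Toy merit: Euclidean distance from a made-up optimum."""
    return float(np.linalg.norm(params - TARGET))


def training_step(params, action, step_scale=0.1, tol=1e-3):
    """One environment transition: apply action, score the design, compute reward.

    Returns (new_params, reward, done). The reward is the improvement in the
    merit value, so moves that shrink the (surrogate) spot are rewarded.
    """
    action = np.clip(action, -1.0, 1.0)        # actions are bounded increments
    new_params = params + step_scale * action  # perturb the design variables
    before = surrogate_rms(params)
    after = surrogate_rms(new_params)
    reward = before - after                    # positive if the design improved
    done = after < tol                         # stop once the merit is small enough
    return new_params, reward, done
```

Rewarding the per-step improvement, rather than the raw negative spot size, is one common shaping choice; it ties each step's reward to that step's own contribution.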
Sample code for a configurable aspheric singlet

```python
import numpy as np
from optiland import materials, optic


class SingletConfigurable(optic.Optic):
    """A configurable aspheric singlet.

    Args:
        n (float): refractive index
        radius (float): radius of curvature of the asphere
        t1 (float): thickness of the first surface (lens thickness)
        t2 (float): thickness of the second surface (thickness to image plane)
        coefficients (list): coefficients of the asphere
    """

    def __init__(self, n, radius, t1, t2, coefficients):
        super().__init__()

        # define the material for the singlet
        mat = materials.IdealMaterial(n=n, k=0)

        # add surfaces: object, aspheric front surface (stop), back surface, image
        self.add_surface(index=0, radius=np.inf, thickness=np.inf)
        self.add_surface(index=1, thickness=t1, radius=radius, is_stop=True, material=mat,
                         surface_type='even_asphere', conic=0.0, coefficients=coefficients)
        self.add_surface(index=2, thickness=t2)
        self.add_surface(index=3)

        # add aperture
        self.set_aperture(aperture_type='EPD', value=25)

        # add field
        self.set_field_type(field_type='angle')
        self.add_field(y=0)

        # add wavelength
        self.add_wavelength(value=0.55, is_primary=True)
```

Pseudo-code for the environment

```python
import numpy as np
from gymnasium import Env
from optiland import analysis


class LensDesignEnv(Env):
    def __init__(self):
        # define the action and observation space
        # configure other parameters
        pass

    def step(self, action):
        # see example procedure above
        # should return (observation, reward, terminated, truncated, info)
        pass

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)  # seeds self.np_random
        # create a random aspheric singlet (once per episode)
        # define random parameters here
        # define lens
        self.lens = SingletConfigurable(...)
        return self._get_ob(), {}

    def get_reward(self, done=False):
        pass

    def _get_ob(self):
        # get observation
        pass

    def _get_rms_spot_size(self):
        # helper function to calculate the RMS spot size
        spot = analysis.SpotDiagram(self.lens)
        size = spot.rms_spot_radius()[0][0]  # first field, first wavelength
        if np.isnan(size):
            return 1e3  # penalize failed ray traces with a large spot size
        return size
```

To fill out this code, take a look at the tutorials. For example, the optimization tutorials show how to set up and run an optimization. The end state will also depend a bit on what you want to do/optimize. I would be very interested to see if you can get something working. Please follow up if so.

Thanks,

Edit: clarify creation of singlet occurs once per episode
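For the `reset` step in the environment pseudo-code ("define random parameters here"), one option is a small sampler that draws a fresh starting singlet for each episode. The sketch below assumes millimeter units and uses made-up parameter ranges; `sample_singlet_params` is a hypothetical helper, and the ranges should be adapted to the actual design space.

```python
import numpy as np


def sample_singlet_params(rng: np.random.Generator) -> dict:
    """Draw random starting parameters for one episode (hypothetical ranges, mm)."""
    return {
        "n": float(rng.uniform(1.45, 1.9)),         # refractive index
        "radius": float(rng.uniform(20.0, 200.0)),  # front radius of curvature
        "t1": float(rng.uniform(2.0, 10.0)),        # lens center thickness
        "t2": float(rng.uniform(20.0, 150.0)),      # distance to image plane
        # small random aspheric terms around zero
        "coefficients": [float(a) for a in rng.normal(0.0, 1e-6, size=3)],
    }
```

Inside `reset`, after `super().reset(seed=seed)`, Gymnasium's `self.np_random` generator could be passed in, e.g. `self.lens = SingletConfigurable(**sample_singlet_params(self.np_random))`, so that episodes are reproducible for a given seed.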
-
Thanks a lot for the pseudo-code! I will give it a shot and keep you posted on how it works out.
-
Just following up here. I created a Jupyter notebook showing a simple example of optimizing an aspheric singlet via RL. I ended up using a slightly different approach from what I wrote above, but the concept is similar. It works, but there's a lot of room for improvement.

Notebook is here: RL_aspheric_singlet.ipynb

I'd like to experiment with extending this approach to slightly more complex designs down the road. Let me know if you try anything similar or have suggestions.

-Kramer
-
Hi Kramer, I was actually able to implement an RL pipeline, but with a different objective: I ended up using RL to generate starting points for lens designs. It worked fairly well, though training took a long time without tuning.

I have another question about how best to use the repo. If I want to design a system that doesn't propagate strictly left to right, for example one with mirrors that split the rays and reflect them vertically, does the codebase support this?
-
Hi, glad to hear you got something working. I also continued with the RL work and began writing a small, modular package for the various components. I also ran into the issue that training was slow on my machine; I will come back to this in the future when I have some time. That code isn't published, but I could put it in a repo if it's of interest.

Also, related: I've begun adding a configurable backend to Optiland so it can use torch or NumPy (see branch feat/torch). I still need to finalize the testing, which is taking some time.

To answer your question: yes, but you might need to write some custom code. First, the codebase is currently set up only for sequential ray tracing, so if rays split into two paths at some point, you'd need to trace rays twice, once for each path. Second, by default, rays are generated assuming typical left-to-right propagation, with object space on the left side. Rays are constructed based on system properties like field of view and aperture. However, there's no reason why you couldn't manually define a …

Not quite the same, but here's an example in the docs with some mirrors and lenses.

-Kramer
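For a fold mirror that redirects rays vertically, the change in ray direction follows the law of reflection, d' = d - 2(d·n)n, with n the unit surface normal. The NumPy sketch below is a generic illustration of that geometry, not Optiland's internal implementation; `reflect` is a hypothetical helper. A beam splitter would still require tracing each branch as a separate sequential path, as described above.

```python
import numpy as np


def reflect(direction: np.ndarray, normal: np.ndarray) -> np.ndarray:
    """Reflect ray direction cosines off a mirror with the given surface normal.

    Uses the law of reflection d' = d - 2 (d . n) n, with n normalized first.
    Accepts a single ray (shape (3,)) or a batch of rays (shape (N, 3)).
    """
    n = normal / np.linalg.norm(normal)
    d = np.atleast_2d(direction)           # treat a single ray as a batch of one
    dots = d @ n                           # d . n for every ray, shape (N,)
    out = d - 2.0 * dots[:, None] * n
    return out.reshape(np.shape(direction))
```

Flipping the sign of the normal gives the same result, since n appears twice in the formula. For example, a ray traveling along +z that hits a mirror tilted 45 degrees about the x-axis (normal along (0, 1, -1)) leaves along +y.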
-
Hi, thank you for sharing Optiland: it is a very comprehensive and accessible open-source tool for optical design and ray tracing! Looking forward to trying it for my own projects and seeing its future developments!

You might already be aware of some of these tools (see links below), but I wanted to mention them because I think they are great examples illustrating not only the usual benefits of using JAX/Torch as computational backends (JIT, AD, multi-device CPU/GPU) but also the broader capabilities they unlock, e.g., optimization, automatic and end-to-end lens design, tighter ML/NN coupling, forward modeling, better integration with other JAX/Torch-based packages, and more. LensAI could benefit directly from some of these. The hardest part may already be done, since you've built all the core machinery!

JAX may be more challenging due to its functional programming model, but it's becoming increasingly mature and widely adopted. One example is Chromatix (https://github.com/chromatix-team/chromatix) for wave optics, which uses Flax to achieve a more Pythonic, object-oriented style in JAX. Equinox is another popular framework for this.
-
Moving this topic into discussions, as the scope has expanded beyond the original issue.

Hi @CAClaveau, thanks a lot for sharing these references! I'm familiar with DeepLens, but I hadn't seen the others you mentioned. I'll consider using some of these for examples in the LensAI repository. These tools have some impressive features, so I'm interested in exploring them more, or maybe even interfacing them with Optiland in some way. If you have suggestions for other examples or features to add, feel free to make a proposal (either here in Optiland or in LensAI) by opening an issue/discussion or submitting a PR.

I'll look into these tools further and experiment with them a bit. I'm not very familiar with JAX yet, so I'm curious to see what can be done with it. Chromatix looks like a solid example; I'll definitely follow that one.

-Kramer
-
Hi, I'm looking to implement RL to optimize optics. I see there are currently a couple of optimizers implemented, but I'm unsure whether I can use PyTorch or JAX optimizer classes instead. If more work is needed, can you provide some pointers to where I could modify the code to do this? Also, a rough idea of how to use this codebase for RL would be much appreciated. Thanks.