From a58779a9a3e5a35c86a918007c4a104665ba5fb8 Mon Sep 17 00:00:00 2001 From: Nihat Engin Toklu Date: Mon, 22 Jul 2024 17:55:27 +0200 Subject: [PATCH 1/6] Introduce functional genetic algorithm operators Alternative implementations for genetic algorithm operators are introduced. These alternative implementations conform to the paradigm of functional programming, and they are batchable (via simply adding a leftmost dimension to the population, or via `torch.func.vmap`). A custom genetic algorithm can be implemented by combining these operators. --- examples/notebooks/Functional_API/README.md | 1 + .../Functional_API/functional_ops.ipynb | 279 ++++ src/evotorch/operators/__init__.py | 6 +- src/evotorch/operators/functional.py | 1146 +++++++++++++++++ 4 files changed, 1431 insertions(+), 1 deletion(-) create mode 100644 examples/notebooks/Functional_API/functional_ops.ipynb create mode 100644 src/evotorch/operators/functional.py diff --git a/examples/notebooks/Functional_API/README.md b/examples/notebooks/Functional_API/README.md index 810e8ce6..911338f9 100644 --- a/examples/notebooks/Functional_API/README.md +++ b/examples/notebooks/Functional_API/README.md @@ -5,5 +5,6 @@ As an alternative to its object-oriented stateful API, EvoTorch provides an API Here are the examples demonstrating various features of this functional API: - **[Maintaining a batch of populations using the functional EvoTorch API](batched_searches.ipynb)**: This notebook shows how one can efficiently run multiple searches simultaneously, each with its own population and hyperparameter configuration, by maintaining a batch of populations. +- **[Functional genetic algorithm operators](functional_ops.ipynb)**: This notebook shows how one can implement a custom genetic algorithm by combining the genetic algorithm operator implementations provided by the functional API of EvoTorch. - **[Solving constrained optimization problems](constrained.ipynb)**: EvoTorch provides batching-friendly constraint penalization functions that can be used with both the object-oriented API and the functional API. In addition, these constraint penalization functions can be used with gradient-based optimization. This notebook demonstrates these features. - **[Solving reinforcement learning tasks using functional evolutionary algorithms](problem.ipynb)**: The functional evolutionary algorithm implementations of EvoTorch can be used to solve problems that are expressed using the object-oriented core API of EvoTorch. To demonstrate this, this notebook instantiates a `GymNE` problem for the reinforcement learning task "CartPole-v1", and solves it using the functional `pgpe` implementation. diff --git a/examples/notebooks/Functional_API/functional_ops.ipynb b/examples/notebooks/Functional_API/functional_ops.ipynb new file mode 100644 index 00000000..c7973583 --- /dev/null +++ b/examples/notebooks/Functional_API/functional_ops.ipynb @@ -0,0 +1,279 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "05a384ac-fa4d-434e-a8c0-367350dec224", + "metadata": {}, + "source": [ + "# Genetic algorithm with the help of functional operators\n", + "\n", + "In this notebook, we demonstrate how one can design a genetic algorithm by combining operators implemented within the namespace `evotorch.operators.functional`.\n", + "\n", + "## Introduction\n", + "\n", + "EvoTorch provides the namespace `evotorch.operators.functional` which contains genetic algorithm operators that can be called directly on PyTorch tensors. 
These operators are implemented in conformance with the functional programming paradigm, meaning that they do not mutate the tensors they receive as arguments (`*`).\n", + "\n", + "A genetic algorithm can be designed by simply calling these genetic algorithm operators in an evolution loop. This way of implementing a genetic algorithm grants the user complete flexibility regarding how and when the operators are to be called, and what extra procedures are to be followed between these operator calls.\n", + "\n", + "## Use cases\n", + "\n", + "**Batched optimization.**\n", + "These operators are defined in such a way that, if they receive a population tensor with 3 or more dimensions (instead of 2 dimensions), the extra leftmost dimensions are interpreted as batch dimensions, and the steps of the operators are broadcast to those batch dimensions. This means that they can work not just on a population, but on a batch of populations, in a vectorized manner.\n", + "\n", + "This feature could be helpful when one has multiple populations (each initialized around different values and/or using different initialization methods), and one wishes to run an evolutionary search on all these populations efficiently.\n", + "\n", + "**Nested optimization.**\n", + "It could be the case that the optimization problem at hand has an inner (nested) optimization problem that needs to be addressed within its fitness function. In such cases, one has to run an inner evolutionary search while evaluating each solution. This inner evolutionary search could be implemented with the help of these functional operators. Considering that each solution of the outer problem ends up with its own inner optimization problem, this way of tackling the inner problem could result in an efficient and vectorized implementation (vectorization would happen across multiple inner optimization problems induced by multiple solutions of the outer problem; see the use case \"Batched optimization\").\n", + "\n", + "---\n", + "\n", + "`(*)` It is to be noted, however, that they _do_ mutate the global random state of PyTorch, because of how they use PyTorch functions such as `torch.randn(...)`, etc." + ] + }, + { + "cell_type": "markdown", + "id": "0f9f42f4-43ad-4766-84ee-8a4fe2a9fe2b", + "metadata": {}, + "source": [ + "---\n", + "\n", + "## Example\n", + "\n", + "We now show how to use the functional operators to design a genetic algorithm to solve the Rastrigin problem. To keep the example simple, we do not consider the use cases of batched/nested optimization.\n", + "\n", + "We begin with the necessary imports." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3dd28c51-e08b-43b3-8cf9-efecdce49203", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from evotorch.operators import functional as func_ops\n", + "from evotorch.decorators import rowwise\n", + "from datetime import datetime" + ] + }, + { + "cell_type": "markdown", + "id": "8379cbbb-084c-461e-889a-2b334d52c138", + "metadata": {}, + "source": [ + "Below, we have the implementations for the fitness functions `rastrigin` and `sphere`.\n", + "\n", + "Notice how these fitness functions are decorated via `evotorch.decorators.rowwise`.\n", + "This decorator allows the user to implement the function with the assumption that its received argument is a single row (i.e. a 1-dimensional tensor). 
As an additional behavior, if a function decorated via `@rowwise` receives a tensor with 2 or more dimensions, the operations defined within the decorated function are broadcast across the extra leftmost dimensions. In other words, the extra leftmost dimensions are interpreted as batch dimensions."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "5f1d45aa-50a4-4bc8-90a9-ed75b1bb1fbd",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "@rowwise\n",
+    "def rastrigin(x: torch.Tensor) -> torch.Tensor:\n",
+    "    from math import pi\n",
+    "    A = 10\n",
+    "    [n] = x.shape\n",
+    "    return A * n + torch.sum((x ** 2.0) - (A * torch.cos(2 * pi * x)))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "ea8ddfeb-48de-4eab-9296-9e7862950d77",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "@rowwise\n",
+    "def sphere(x: torch.Tensor) -> torch.Tensor:\n",
+    "    return torch.linalg.norm(x)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "fd1fe365-ae1b-4898-91d0-cb7517ef5f84",
+   "metadata": {},
+   "source": [
+    "In this notebook, the variable `f` points to the fitness function whose value we want to minimize:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "fc6d3958-61e7-4838-a9c8-0c9c3d0e343c",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "#f = sphere\n",
+    "f = rastrigin"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "da899fd2-5bfe-41bd-8c76-20d8b5c2bf71",
+   "metadata": {},
+   "source": [
+    "Various hyperparameters and problem settings:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "6625c172-0ec2-492e-9851-0a60594d9086",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "popsize = 1000  # population size\n",
+    "solution_length = 1000  # length of a solution\n",
+    "\n",
+    "# lower and upper bounds for the decision values of the initial population\n",
+    "lb = -5.12\n",
+    "ub = 5.12\n",
+    "\n",
+    "tournament_size = 8  # number of solutions that compete in each tournament\n",
+    "mutation_stdev = 0.01  # standard deviation for the Gaussian mutation\n",
+    "eta = 10.0  # eta value for the simulated binary cross-over\n",
+    "\n",
+    "num_generations = 1000  # number of generations"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "36782df7-9728-44e2-ab24-312a0e83b0d2",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Initialize the decision values of the initial population\n",
+    "parents = (torch.rand(popsize, solution_length, dtype=torch.float32) * (ub - lb)) + lb\n",
+    "\n",
+    "# Evaluate the initial population\n",
+    "parent_evals = f(parents)\n",
+    "\n",
+    "last_reporting_time = None\n",
+    "reporting_interval = 1\n",
+    "\n",
+    "# Main loop of the genetic algorithm\n",
+    "for generation in range(1, 1 + num_generations):\n",
+    "\n",
+    "    # Given the parent solutions and their evaluation results,\n",
+    "    # run a tournament, pick pairs, and apply cross-over for each pair:\n",
+    "    candidates = func_ops.simulated_binary_cross_over(\n",
+    "        parents,\n",
+    "        parent_evals,\n",
+    "        eta=eta,\n",
+    "        tournament_size=tournament_size,\n",
+    "        objective_sense=\"min\",\n",
+    "    )\n",
+    "\n",
+    "    # Instead of simulated binary cross-over, you could use two-point\n",
+    "    # cross-over:\n",
+    "    # candidates = func_ops.two_point_cross_over(\n",
+    "    #     parents,\n",
+    "    #     parent_evals,\n",
+    "    #     tournament_size=tournament_size,\n",
+    "    #     objective_sense=\"min\",\n",
+    "    # )\n",
+    "\n",
+    "    # Apply Gaussian mutation on the newly made candidate solutions\n",
+    "    candidates = candidates + (torch.randn_like(candidates) * mutation_stdev)\n",
+    "\n",
+    "    # On the parent solutions, apply the permutation operator of the CoSyNE algorithm\n",
+    "    permuted = func_ops.cosyne_permutation(parents, permute_all=True)\n",
+    "    # Add the permutation results onto the new population of candidates\n",
+    "    candidates = func_ops.combine(candidates, permuted)\n",
+    "\n",
+    "    # Evaluate all the candidate solutions\n",
+    "    candidate_evals = f(candidates)\n",
+    "\n",
+    "    # Combine the parent population and the candidate population to form an\n",
+    "    # extended population. This time, we also combine the evaluation results.\n",
+    "    extended_population, extended_evals = (\n",
+    "        func_ops.combine((parents, parent_evals), (candidates, candidate_evals))\n",
+    "    )\n",
+    "\n",
+    "    # From the extended population, take the best `popsize` number of solutions.\n",
+    "    # These solutions will serve as the parents of the next generation.\n",
+    "    parents, parent_evals = (\n",
+    "        func_ops.take_best(extended_population, extended_evals, popsize, objective_sense=\"min\")\n",
+    "    )\n",
+    "\n",
+    "    # Report how the evolution is progressing\n",
+    "    now = datetime.now()\n",
+    "    if (\n",
+    "        (last_reporting_time is None)\n",
+    "        or (generation == num_generations)\n",
+    "        or ((now - last_reporting_time).total_seconds() > reporting_interval)\n",
+    "    ):\n",
+    "        last_reporting_time = now\n",
+    "        print(\"Generation:\", generation, \" Best eval of population:\", parent_evals.min())"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "84b24feb-1e84-4de6-ae24-ca277f5f5f4d",
+   "metadata": {},
+   "source": [
+    "Decision values of the final population:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "21898601-e05d-489f-8cc1-3beef8bdd33b",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "parents"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "a067340a-5c87-4406-82f4-9cdcccaa3349",
+   "metadata": {},
+   "source": [
+    "Best solution of the final population:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "28fccbc1-c605-4b3f-9545-9e4de01b7f07",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "pop_best, pop_best_eval = func_ops.take_best(parents, parent_evals, objective_sense=\"min\")\n",
+    "\n",
+    "print(\"Best solution of the final population:\", pop_best)\n",
+    "print(\"Best evaluation result of the final population:\", pop_best_eval)"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.9.18"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/src/evotorch/operators/__init__.py b/src/evotorch/operators/__init__.py
index 073480b4..d205cfaa 100644
--- a/src/evotorch/operators/__init__.py
+++ b/src/evotorch/operators/__init__.py
@@ -69,6 +69,10 @@
 """

 __all__ = (
+    "base",
+    "functional",
+    "real",
+    "sequence",
     "CopyingOperator",
     "CosynePermutation",
     "CrossOver",
@@ -84,7 +88,7 @@
 )


-from . import base, real, sequence
+from . import base, functional, real, sequence
 from .base import CopyingOperator, CrossOver, Operator, SingleObjOperator
 from .real import (
     CosynePermutation,
diff --git a/src/evotorch/operators/functional.py b/src/evotorch/operators/functional.py
new file mode 100644
index 00000000..01004d2b
--- /dev/null
+++ b/src/evotorch/operators/functional.py
@@ -0,0 +1,1146 @@
+# Copyright 2024 NNAISENSE SA
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+Functional implementations of the genetic algorithm operators.
+
+Instead of the object-oriented genetic algorithm API
+([GeneticAlgorithm][evotorch.algorithms.ga.GeneticAlgorithm]), one might
+wish to adopt a style that is more compatible with the functional programming
+paradigm. For such cases, the functional operators within this namespace can
+be used.
+
+The operators within this namespace are designed to be called directly,
+so that the user can decide which of these operators are used, and how and
+when they are called, while implementing a genetic algorithm.
+
+**Reasons for using the functional operators.**
+
+- **Flexibility.** This API provides various genetic-algorithm-related
+  operators and otherwise stays out of the way. The user has complete
+  control over what happens between these operator calls, and over the
+  order in which these operators are used.
+- **Batched search.** These functional operators are designed in such a way
+  that, if they receive a batched population instead of a single population
+  (i.e. if they receive a 3-or-more-dimensional tensor instead of a
+  2-dimensional tensor), they will broadcast their operations across the
+  extra leftmost dimensions. This allows one to implement a genetic
+  algorithm that works across many populations at once, in a vectorized
+  manner.
+- **Nested optimization.** It could be the case that the optimization problem
+  at hand has an inner optimization problem within its fitness function.
+  This inner optimization problem could be tackled with the help of a
+  genetic algorithm built upon these functional operators.
+  Such an approach would allow the user to run a search for each inner
+  optimization problem across the entire population of the outer
+  problem, in a vectorized manner (see the previous point titled
+  "batched search").
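+
+As a small illustrative sketch of the batched behavior mentioned above
+(the tensor shapes below are arbitrary, chosen purely for demonstration),
+a 3-dimensional decision values tensor is interpreted as a batch of
+populations, and an operator such as `combine` broadcasts its operation
+across the extra leftmost dimension:
+
+```python
+import torch
+from evotorch.operators.functional import combine
+
+# A batch of 4 populations, each containing 10 solutions of length 5
+batch1 = torch.randn(4, 10, 5)
+batch2 = torch.randn(4, 10, 5)
+
+# The populations are combined pairwise across the batch dimension
+combined = combine(batch1, batch2)  # shape: (4, 20, 5)
+```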
+
+**Example usage.**
+Let us assume that we have the following cost function that is to be
+minimized:
+
+```python
+import torch
+
+
+def f(x: torch.Tensor) -> torch.Tensor:
+    return torch.linalg.norm(x - 1, dim=-1)
+```
+
+A genetic algorithm could be designed with the help of these functional
+operators as follows:
+
+```python
+import torch
+from evotorch.operators.functional import two_point_cross_over, combine, take_best
+
+
+def f(x: torch.Tensor) -> torch.Tensor:
+    return torch.linalg.norm(x - 1, dim=-1)
+
+
+popsize = 100  # population size
+solution_length = 20  # length of a solution
+num_generations = 200  # number of generations
+mutation_stdev = 0.01  # standard deviation for mutation
+tournament_size = 4  # size for the tournament selection
+
+# Randomly initialize a population, and compute the solution costs
+population = torch.randn(popsize, solution_length)
+costs = f(population)
+
+# Initialize the variables that will store the decision values and the cost
+# for the last population's best solution.
+pop_best_values = None
+pop_best_cost = None
+
+# main loop of the optimization
+for generation in range(1, 1 + num_generations):
+    # Given the population and the solution costs, pick parents and apply
+    # cross-over on them.
+    candidates = two_point_cross_over(
+        population,
+        costs,
+        tournament_size=tournament_size,
+        objective_sense="min",
+    )
+
+    # Apply Gaussian mutation on the candidates
+    candidates = candidates + (torch.randn(popsize, solution_length) * mutation_stdev)
+
+    # Compute the solution costs of the candidates
+    candidate_costs = f(candidates)
+
+    # Combine the parents and the candidates into an extended population
+    extended_values, extended_costs = combine(
+        (population, costs),
+        (candidates, candidate_costs),
+    )
+
+    # Take the best `popsize` number of solutions from the extended population
+    population, costs = take_best(
+        extended_values,
+        extended_costs,
+        popsize,
+        objective_sense="min",
+    )
+
+    # Take the best solution and its cost
+    pop_best_values, pop_best_cost = take_best(population, costs, objective_sense="min")
+
+    # Print the status
+    print("Generation:", generation, " Best cost within population:", pop_best_cost)
+
+# Print the result
+print()
+print("Best solution of the last population:")
+print(pop_best_values)
+print("Cost of the best solution of the last population:")
+print(pop_best_cost)
+```
+"""
+
+
+from typing import Optional, Union
+
+import torch
+
+from evotorch.decorators import expects_ndim
+
+
+@expects_ndim(2, 1, 1, None, randomness="different")
+def _pick_solution_via_tournament(
+    solutions: torch.Tensor,
+    evals: torch.Tensor,
+    indices: torch.Tensor,
+    objective_sense: str,
+) -> tuple:
+    """
+    Run a single tournament among multiple solutions to pick the best.
+
+    Args:
+        solutions: Decision values of the solutions, as a tensor of at least
+            2 dimensions. Extra leftmost dimensions will be considered as
+            batch dimensions.
+        evals: Evaluation results (i.e. fitnesses) of the solutions, as a
+            tensor with at least 1 dimension. Extra leftmost dimensions will
+            be considered as batch dimensions.
+        indices: Indices of solutions that participate in the tournament,
+            as a tensor of integers with at least 1 dimension. Extra leftmost
+            dimensions will be considered as batch dimensions.
+        objective_sense: A string with value 'min' or 'max', representing the
+            goal of the optimization.
+    Returns:
+        A tuple of the form `(decision_values, eval_result)` where
+        `decision_values` is the tensor that contains the decision values
+        of the winning solution(s), and `eval_result` is a tensor that
+        contains the evaluation result(s) (i.e. fitness(es)) of the
+        winning solution(s).
+    """
+    # Get the evaluation results of the solutions that participate in the tournament
+    competing_evals = torch.index_select(evals, 0, indices)
+
+    if objective_sense == "max":
+        # If the objective sense is 'max', we are looking for the solution with the highest evaluation result
+        argbest = torch.argmax
+    elif objective_sense == "min":
+        # If the objective sense is 'min', we are looking for the solution with the lowest evaluation result
+        argbest = torch.argmin
+    else:
+        raise ValueError(
+            "`objective_sense` was expected either as 'min' or as 'max'."
+            f" However, it was received as {repr(objective_sense)}."
+        )
+
+    # Among the competing solutions, which one is the best?
+    winner_competing_eval_index = argbest(competing_evals)
+
+    # Get the index (within the original `solutions`) of the winning solution
+    winner_solution_index = torch.index_select(indices, 0, winner_competing_eval_index.reshape(1))
+
+    # Get the decision values and the evaluation result of the winning solution
+    winner_solution = torch.squeeze(torch.index_select(solutions, 0, winner_solution_index), dim=0)
+    winner_eval = torch.squeeze(torch.index_select(evals, 0, winner_solution_index), dim=0)
+
+    # Return the winning solution's decision values and evaluation results
+    return winner_solution, winner_eval
+
+
+@expects_ndim(2, 1, None, None, None, randomness="different")
+def _tournament(
+    solutions: torch.Tensor,
+    evals: torch.Tensor,
+    num_tournaments: int,
+    tournament_size: int,
+    objective_sense: str,
+) -> tuple:
+    """
+    Randomly pick solutions, put them into a tournament, pick the winners.
+
+    Args:
+        solutions: Decision values of the solutions
+        evals: Evaluation results of the solutions
+        num_tournaments: Number of tournaments that will be applied.
+            In other words, number of winners that will be picked.
+        tournament_size: Number of solutions to be picked for the tournament
+        objective_sense: A string of value 'min' or 'max', representing the
+            goal of the optimization
+    Returns:
+        A tuple of the form `(decision_values, eval_results)` where
+        `decision_values` is the tensor that contains the decision values
+        of the winning solutions, and `eval_results` is a tensor that
+        contains the evaluation results (i.e. fitnesses) of the
+        winning solutions.
+    """
+    if tournament_size < 1:
+        raise ValueError(
+            "The argument `tournament_size` was expected to be greater than or equal to 1."
+            f" However, it was encountered as {tournament_size}."
+        )
+    popsize, _ = solutions.shape
+    indices_for_tournament = torch.randint_like(
+        solutions[:1, :1].expand(num_tournaments, tournament_size), 0, popsize, dtype=torch.int64
+    )
+    return _pick_solution_via_tournament(solutions, evals, indices_for_tournament, objective_sense)
+
+
+@expects_ndim(2, randomness="different")
+def _pair_solutions_for_cross_over(solutions: torch.Tensor) -> tuple:
+    """
+    Split the solutions so that the first half becomes the first set of
+    parents, and the second half becomes the second set of parents.
+
+    Args:
+        solutions: A tensor of decision values that are subject to pairing.
+            Must be at least 2-dimensional. Extra leftmost dimensions will
+            be interpreted as batch dimensions.
+    Returns:
+        A tuple of the form `(parents1, parents2)` where both parent items
+        are (at least) 2-dimensional tensors. In the non-batched case, this
+        resulting tuple indicates that `parents1[i, :]` is paired with
+        `parents2[i, :]`.
+    """
+    popsize, _ = solutions.shape
+
+    # Ensure that the number of solutions is divisible by 2.
+    if (popsize % 2) != 0:
+        raise ValueError(f"The number of `solutions` was expected as an even number. However, it is {popsize}.")
+
+    # Compute the number of pairs as half of the population size.
+    num_pairings = popsize // 2
+
+    return solutions[:num_pairings, :], solutions[num_pairings:, :]
+
+
+@expects_ndim(1, 1, None, randomness="different")
+def _do_cross_over_between_two_solutions(solution1: torch.Tensor, solution2: torch.Tensor, num_points: int) -> tuple:
+    """
+    Do cross-over between two solutions (or between batches of solutions).
+
+    Args:
+        solution1: A tensor, with at least 1 dimension, representing the
+            decision values of the first parent(s).
+        solution2: A tensor, with at least 1 dimension, representing the
+            decision values of the second parent(s).
+        num_points: Number of cutting points for the cross-over operation.
+    Returns:
+        A tuple of the form `(child1, child2)`, representing the decision
+        values of the generated child solution(s).
+    """
+    device = solution1.device
+    [solution_length] = solution1.shape
+
+    # Randomly generate the tensor `cut_points` that represents the indices at which the decision values of the
+    # parent solutions will be cut.
+    like_what = (solution1[:1] + solution2[:1]).reshape(tuple()).expand(num_points)
+    cut_points = torch.randint_like(like_what, 1, solution_length - 1, dtype=torch.int64)
+
+    item_indices = torch.arange(solution_length, dtype=torch.int64, device=device)
+
+    # Initialize the tensor `switch_parent` as a tensor filled with False.
+    switch_parent = torch.zeros(solution_length, dtype=torch.bool, device=device)
+
+    # For each cutting point, flip the booleans within `switch_parent` whose indices are greater than or equal to
+    # the encountered cutting point.
+    for i_num_point in range(num_points):
+        cut_point_index = torch.as_tensor([i_num_point], dtype=torch.int64, device=device)
+        cut_point = torch.index_select(cut_points, 0, cut_point_index).reshape(tuple())
+        switch_parent = (item_indices >= cut_point) ^ switch_parent
+
+    dont_switch_parent = ~switch_parent
+
+    # If `switch_parent` is False, child1 takes its value from solution1.
+    # If `switch_parent` is True, child1 takes its value from solution2.
+    child1 = (dont_switch_parent * solution1) + (switch_parent * solution2)
+
+    # If `switch_parent` is False, child2 takes its value from solution2.
+    # If `switch_parent` is True, child2 takes its value from solution1.
+    child2 = (dont_switch_parent * solution2) + (switch_parent * solution1)
+
+    # Return the generated child solutions
+    return child1, child2
+
+
+@expects_ndim(2, None, None, randomness="different")
+def _do_cross_over(solutions: torch.Tensor, num_points: int) -> torch.Tensor:
+    """
+    Apply cross-over on multiple solutions.
+
+    Args:
+        solutions: Decision values of the parent solutions, as a tensor with
+            at least 2 dimensions. Extra leftmost dimensions will be considered
+            batch dimensions.
+        num_points: Number of cutting points for the cross-over operation.
+ Returns: + A tensor with at least 2 dimensions, representing the decision values + of the child solutions. + """ + parents1, parents2 = _pair_solutions_for_cross_over(solutions) + children1, children2 = _do_cross_over_between_two_solutions(parents1, parents2, num_points) + return torch.vstack([children1, children2]) + + +def multi_point_cross_over( + parents: torch.Tensor, + evals: Optional[torch.Tensor] = None, + *, + num_points: int, + tournament_size: Optional[int] = None, + num_children: Optional[int] = None, + objective_sense: Optional[str] = None, +) -> torch.Tensor: + """ + Apply multi-point cross-over on the given `parents`. + + If `tournament_size` is given, parents for the cross-over operation will + be picked with the help of a tournament. Otherwise, the first half of the + given `parents` will be the first set of parents, and the second half + of the given `parents` will be the second set of parents. + + The return value of this function is a new tensor containing the decision + values of the child solutions. + + Args: + parents: A tensor with at least 2 dimensions, representing the decision + values of the parent solutions. If this tensor has more than 2 + dimensions, the extra leftmost dimension(s) will be considered as + batch dimensions. + evals: A tensor with at least 1 dimension, representing the evaluation + results (i.e. fitnesses) of the parent solutions. If this tensor + has more than 1 dimension, the extra leftmost dimension(s) will be + considered as batch dimensions. If `tournament_size` is not given, + `evals` can be left as None. + num_points: Number of points at which the decision values of the + parent solutions will be cut and recombined to form the child + solutions. + tournament_size: If given as an integer that is greater than or equal + to 1, the parents for the cross-over operation will be picked + with the help of a tournament. In more details, each parent will + be picked as the result of comparing multiple competing solutions, + the number of these competing solutions being equal to this + `tournament_size`. Please note that, if `tournament_size` is given + as an integer, the arguments `evals` and `objective_sense` are + also required. If `tournament_size` is left as None, the first half + of `parents` will be the first set of parents, and the second half + of `parents` will be the second set of parents. + num_children: Optionally the number of children to produce as the + result of tournament selection and cross-over, as an even integer. + If tournament selection is enabled (i.e. if `tournament_size` is + an integer) but `num_children` is omitted, the number of children + will be equal to the number of `parents`. + If there is no tournament selection (i.e. if `tournament_size` is + None), `num_children` is expected to be None. + objective_sense: Mandatory if `tournament_size` is not None. + `objective_sense` is expected as 'max' for when the goal of the + optimization is to maximize `evals`, 'min' for when the goal of + the optimization is to minimize `evals`. If `tournament_size` + is None, `objective_sense` can also be left as None. + Returns: + Decision values of the child solutions, as a new tensor. + """ + + if tournament_size is None: + if num_children is not None: + raise ValueError( + "`num_children` was received as something other than None." + " However, `num_children` is expected only when a `tournament_size` is given," + " which seems to be omitted (i.e. which is None)." 
+ ) + else: + # This is the case where the tournament selection feature is enabled. + # We first ensure that the required arguments `evals` and `objective_sense` are available. + if evals is None: + raise ValueError( + "When a `tournament_size` is given, the argument `evals` is also required." + " However, it was received as None." + ) + if num_children is None: + # If `num_children` is not given, we make it equal to the number of `parents`. + num_children = parents.shape[-2] + if objective_sense is None: + raise ValueError( + "When a `tournament_size` is given, the argument `objective_sense` is also required." + " However, it was received as None." + ) + + # Apply tournament selection on the original `parents` + parents, _ = _tournament(parents, evals, num_children, tournament_size, objective_sense) + + # Apply the cross-over operation on `parents`, and return the recombined decision values tensor. + return _do_cross_over(parents, num_points) + + +def one_point_cross_over( + parents: torch.Tensor, + evals: Optional[torch.Tensor] = None, + *, + tournament_size: Optional[int] = None, + num_children: Optional[int] = None, + objective_sense: Optional[str] = None, +) -> torch.Tensor: + """ + Apply one-point cross-over on the given `parents`. + + Let us assume that we have the following two parent solutions: + + ```text + ________________________ + parentA | a1 a2 a3 a4 a5 a6 | + parentB | b1 b2 b3 b4 b5 b6 | + |________________________| + ``` + + This cross-over operation will first randomly decide a cutting point: + + ```text + ________|________________ + parentA | a1 a2 | a3 a4 a5 a6 | + parentB | b1 b2 | b3 b4 b5 b6 | + |________|________________| + | + ``` + + ...and then form the following child solutions by recombining the decision + values of the parents: + + ```text + ________|________________ + child1 | a1 a2 | b3 b4 b5 b6 | + child2 | b1 b2 | a3 a4 a5 a6 | + |________|________________| + | + ``` + + If `tournament_size` is given, parents for the cross-over operation will + be picked with the help of a tournament. Otherwise, the first half of the + given `parents` will be the first set of parents, and the second half + of the given `parents` will be the second set of parents. + + The return value of this function is a new tensor containing the decision + values of the child solutions. + + Args: + parents: A tensor with at least 2 dimensions, representing the decision + values of the parent solutions. If this tensor has more than 2 + dimensions, the extra leftmost dimension(s) will be considered as + batch dimensions. + evals: A tensor with at least 1 dimension, representing the evaluation + results (i.e. fitnesses) of the parent solutions. If this tensor + has more than 1 dimension, the extra leftmost dimension(s) will be + considered as batch dimensions. If `tournament_size` is not given, + `evals` can be left as None. + tournament_size: If given as an integer that is greater than or equal + to 1, the parents for the cross-over operation will be picked + with the help of a tournament. In more details, each parent will + be picked as the result of comparing multiple competing solutions, + the number of these competing solutions being equal to this + `tournament_size`. Please note that, if `tournament_size` is given + as an integer, the arguments `evals` and `objective_sense` are + also required. If `tournament_size` is left as None, the first half + of `parents` will be the first set of parents, and the second half + of `parents` will be the second set of parents. 
+ num_children: Optionally the number of children to produce as the + result of tournament selection and cross-over, as an even integer. + If tournament selection is enabled (i.e. if `tournament_size` is + an integer) but `num_children` is omitted, the number of children + will be equal to the number of `parents`. + If there is no tournament selection (i.e. if `tournament_size` is + None), `num_children` is expected to be None. + objective_sense: Mandatory if `tournament_size` is not None. + `objective_sense` is expected as 'max' for when the goal of the + optimization is to maximize `evals`, 'min' for when the goal of + the optimization is to minimize `evals`. If `tournament_size` + is None, `objective_sense` can also be left as None. + Returns: + Decision values of the child solutions, as a new tensor. + """ + return multi_point_cross_over( + parents, + evals, + num_points=1, + num_children=num_children, + tournament_size=tournament_size, + objective_sense=objective_sense, + ) + + +def two_point_cross_over( + parents: torch.Tensor, + evals: Optional[torch.Tensor] = None, + *, + tournament_size: Optional[int] = None, + num_children: Optional[int] = None, + objective_sense: Optional[str] = None, +) -> torch.Tensor: + """ + Apply two-point cross-over on the given `parents`. + + Let us assume that we have the following two parent solutions: + + ```text + ________________________ + parentA | a1 a2 a3 a4 a5 a6 | + parentB | b1 b2 b3 b4 b5 b6 | + |________________________| + ``` + + This cross-over operation will first randomly decide two cutting points: + + ```text + ________|____________|____ + parentA | a1 a2 | a3 a4 a5 | a6 | + parentB | b1 b2 | b3 b4 b5 | b6 | + |________|____________|____| + | | + ``` + + ...and then form the following child solutions by recombining the decision + values of the parents: + + ```text + ________|____________|____ + child1 | a1 a2 | b3 b4 b5 | a6 | + child2 | b1 b2 | a3 a4 a5 | b6 | + |________|____________|____| + | | + ``` + + If `tournament_size` is given, parents for the cross-over operation will + be picked with the help of a tournament. Otherwise, the first half of the + given `parents` will be the first set of parents, and the second half + of the given `parents` will be the second set of parents. + + The return value of this function is a new tensor containing the decision + values of the child solutions. + + Args: + parents: A tensor with at least 2 dimensions, representing the decision + values of the parent solutions. If this tensor has more than 2 + dimensions, the extra leftmost dimension(s) will be considered as + batch dimensions. + evals: A tensor with at least 1 dimension, representing the evaluation + results (i.e. fitnesses) of the parent solutions. If this tensor + has more than 1 dimension, the extra leftmost dimension(s) will be + considered as batch dimensions. If `tournament_size` is not given, + `evals` can be left as None. + tournament_size: If given as an integer that is greater than or equal + to 1, the parents for the cross-over operation will be picked + with the help of a tournament. In more details, each parent will + be picked as the result of comparing multiple competing solutions, + the number of these competing solutions being equal to this + `tournament_size`. Please note that, if `tournament_size` is given + as an integer, the arguments `evals` and `objective_sense` are + also required. 
If `tournament_size` is left as None, the first half + of `parents` will be the first set of parents, and the second half + of `parents` will be the second set of parents. + num_children: Optionally the number of children to produce as the + result of tournament selection and cross-over, as an even integer. + If tournament selection is enabled (i.e. if `tournament_size` is + an integer) but `num_children` is omitted, the number of children + will be equal to the number of `parents`. + If there is no tournament selection (i.e. if `tournament_size` is + None), `num_children` is expected to be None. + objective_sense: Mandatory if `tournament_size` is not None. + `objective_sense` is expected as 'max' for when the goal of the + optimization is to maximize `evals`, 'min' for when the goal of + the optimization is to minimize `evals`. If `tournament_size` + is None, `objective_sense` can also be left as None. + Returns: + Decision values of the child solutions, as a new tensor. + """ + return multi_point_cross_over( + parents, + evals, + num_points=2, + num_children=num_children, + tournament_size=tournament_size, + objective_sense=objective_sense, + ) + + +@expects_ndim(1, 1, 0, randomness="different") +def _do_sbx_between_two_solutions(parent1: torch.Tensor, parent2: torch.Tensor, eta: torch.Tensor) -> tuple: + u = torch.rand_like(parent1) + + beta = torch.where( + u <= 0.5, + (2 * u) ** (1.0 / (eta + 1.0)), + (1 / (2 * (1.0 - u))) ** (1.0 / (eta + 1.0)), + ) + + child1 = 0.5 * (((1 + beta) * parent1) + ((1 - beta) * parent2)) + child2 = 0.5 * (((1 - beta) * parent1) + ((1 + beta) * parent2)) + + return child1, child2 + + +@expects_ndim(2, 0, randomness="different") +def _do_sbx(solutions: torch.Tensor, eta: Union[float, torch.Tensor]) -> torch.Tensor: + parents1, parents2 = _pair_solutions_for_cross_over(solutions) + children1, children2 = _do_sbx_between_two_solutions(parents1, parents2, eta) + return torch.vstack([children1, children2]) + + +def simulated_binary_cross_over( + parents: torch.Tensor, + evals: Optional[torch.Tensor] = None, + *, + eta: Union[float, torch.Tensor], + tournament_size: Optional[int] = None, + num_children: Optional[int] = None, + objective_sense: Optional[str] = None, +) -> torch.Tensor: + """ + Apply simulated binary cross-over (SBX) on the given `parents`. + + If `tournament_size` is given, parents for the cross-over operation will + be picked with the help of a tournament. Otherwise, the first half of the + given `parents` will be the first set of parents, and the second half + of the given `parents` will be the second set of parents. + + The return value of this function is a new tensor containing the decision + values of the child solutions. + + Args: + parents: A tensor with at least 2 dimensions, representing the decision + values of the parent solutions. If this tensor has more than 2 + dimensions, the extra leftmost dimension(s) will be considered as + batch dimensions. + evals: A tensor with at least 1 dimension, representing the evaluation + results (i.e. fitnesses) of the parent solutions. If this tensor + has more than 1 dimension, the extra leftmost dimension(s) will be + considered as batch dimensions. If `tournament_size` is not given, + `evals` can be left as None. + eta: The crowding index, expected as a real number. Bigger eta values + result in children closer to their parents. If `eta` is given as + an `n`-dimensional tensor instead of a scalar, those extra + dimensions will be considered as batch dimensions. 
+ tournament_size: If given as an integer that is greater than or equal + to 1, the parents for the cross-over operation will be picked + with the help of a tournament. In more details, each parent will + be picked as the result of comparing multiple competing solutions, + the number of these competing solutions being equal to this + `tournament_size`. Please note that, if `tournament_size` is given + as an integer, the arguments `evals` and `objective_sense` are + also required. If `tournament_size` is left as None, the first half + of `parents` will be the first set of parents, and the second half + of `parents` will be the second set of parents. + num_children: Optionally the number of children to produce as the + result of tournament selection and cross-over, as an even integer. + If tournament selection is enabled (i.e. if `tournament_size` is + an integer) but `num_children` is omitted, the number of children + will be equal to the number of `parents`. + If there is no tournament selection (i.e. if `tournament_size` is + None), `num_children` is expected to be None. + objective_sense: Mandatory if `tournament_size` is not None. + `objective_sense` is expected as 'max' for when the goal of the + optimization is to maximize `evals`, 'min' for when the goal of + the optimization is to minimize `evals`. If `tournament_size` + is None, `objective_sense` can also be left as None. + Returns: + Decision values of the child solutions, as a new tensor. + """ + if tournament_size is None: + if num_children is not None: + raise ValueError( + "`num_children` was received as something other than None." + " However, `num_children` is expected only when a `tournament_size` is given," + " which seems to be omitted (i.e. which is None)." + ) + else: + # This is the case where the tournament selection feature is enabled. + # We first ensure that the required arguments `evals` and `objective_sense` are available. + if evals is None: + raise ValueError( + "When a `tournament_size` is given, the argument `evals` is also required." + " However, it was received as None." + ) + if num_children is None: + # If `num_children` is not given, we make it equal to the number of `parents`. + num_children = parents.shape[-2] + if objective_sense is None: + raise ValueError( + "When a `tournament_size` is given, the argument `objective_sense` is also required." + " However, it was received as None." + ) + + # Apply tournament selection on the original `parents` + parents, _ = _tournament(parents, evals, num_children, tournament_size, objective_sense) + + return _do_sbx(parents, eta) + + +@expects_ndim(1, None, None) +def _utility(evals: torch.Tensor, objective_sense: str, ranking_method: Optional[str] = "centered") -> torch.Tensor: + """ + Return utility values representing how good the evaluation results are. + + Args: + evals: An at least 1-dimensional tensor that stores evaluation results + (i.e. fitness values). Extra leftmost dimensions will be taken as + batch dimensions. + objective_sense: A string whose value is either 'min' or 'max', which + represents the goal of the optimization (minimization or + maximization). + ranking_method: Ranking method according to which the utilities will + be computed. 
Currently, this function supports: + 'centered' (worst one gets -0.5, best one gets 0.5); + 'linear' (worst one gets 0.0, best one gets 1.0); + 'raw' (evaluation results themselves are returned, with the + additional behavior of flipping the signs if `objective_sense` + is 'min', ensuring that the worst evaluation result gets the + lowest value, and the best evaluation result gets the highest + value). None also means 'raw'. + Returns: + Utility values, as a tensor whose shape is the same with the shape of + `evals`. + """ + if objective_sense == "min": + # If the objective sense is 'min', we set `descending=True`, so that the order will be inverted, and the + # highest number in `evals` will end up at index 0 (and therefore with the lowest rank). + descending = True + elif objective_sense == "max": + # If the objective sense is 'max', we set `descending=False`, so that the order of sorting will be from + # lowest to highest, and therefore, the highest number in `evals` will end up at the highest index + # (and therefore with the highest rank). + descending = False + else: + raise ValueError(f"Expected `objective_sense` as 'min' or 'max', but received it as {repr(objective_sense)}") + + if (ranking_method is None) or (ranking_method == "raw"): + # This is the case where `ranking_method` is "raw" (or is None), which means that we do not even need to + # do sorting. We can just use `evals` itself. + if descending: + # If `descending` is True, we are in the case that the objective sense is 'min'. + # In this case, the highest number within `evals` should have the lowest utility, and the lowest number + # within `evals` should have the highest utility. To ensure this, we flip the signs and return the result. + return -evals + else: + # If `descending` is False, there is nothing to do. We can just return `evals` as it is. + return evals + + [n] = evals.shape + increasing_indices = torch.arange(n, device=evals.device) + + # Compute the ranks, initially in the form of indices (i.e. worst one gets 0, best one gets n-1) + indices_for_sorting = torch.argsort(evals, descending=descending) + ranks = torch.empty_like(indices_for_sorting) + ranks[indices_for_sorting] = increasing_indices + + if ranking_method == "linear": + # Rescale the ranks so that the worst one gets 0.0, and the best one gets 1.0. + ranks = ranks / (n - 1) + elif ranking_method == "centered": + # Rescale and shift the ranks so that the worst one gets -0.5, and the best one gets +0.5. + ranks = (ranks / (n - 1)) - 0.5 + else: + raise ValueError(f"Unrecognized ranking method: {repr(ranking_method)}") + + return ranks + + +def utility(evals: torch.Tensor, *, objective_sense: str, ranking_method: Optional[str] = "centered") -> torch.Tensor: + """ + Return utility values representing how good the evaluation results are. + + A utility number is different from `evals` in the sense that, worst + solution has the lowest utility, and the best solution has the highest + utility, regardless of the objective sense ('min' or 'max'). + On the other hand, the lowest number within `evals` could represent + the fitness of the best solution or of the worst solution, depending + on the objective sense. + + The "centered" ranking is the same ranking method that was used within: + + ``` + Tim Salimans, Jonathan Ho, Xi Chen, Szymon Sidor, Ilya Sutskever (2017). + Evolution Strategies as a Scalable Alternative to Reinforcement Learning + ``` + + Args: + evals: An at least 1-dimensional tensor that stores evaluation results + (i.e. fitness values). 
Extra leftmost dimensions will be taken as + batch dimensions. + objective_sense: A string whose value is either 'min' or 'max', which + represents the goal of the optimization (minimization or + maximization). + ranking_method: Ranking method according to which the utilities will + be computed. Currently, this function supports: + 'centered' (worst one gets -0.5, best one gets 0.5); + 'linear' (worst one gets 0.0, best one gets 1.0); + 'raw' (evaluation results themselves are returned, with the + additional behavior of flipping the signs if `objective_sense` + is 'min', ensuring that the worst evaluation result gets the + lowest value, and the best evaluation result gets the highest + value). None also means 'raw'. + Returns: + Utility values, as a tensor whose shape is the same with the shape of + `evals`. + """ + return _utility(evals, objective_sense, ranking_method) + + +@expects_ndim(1, randomness="different") +def _cosyne_permutation_for_entire_subpopulation(subpopulation: torch.Tensor) -> torch.Tensor: + """ + Return the permuted (i.e. shuffled) version of the given subpopulation. + + In the context of the Cosyne algorithm, a "subpopulation" is the population + of decision values for a single decision variable. Therefore, subpopulation + represents a column of an entire population. + + Args: + subpopulation: Population of decision values for a single decision + variable. Expected as an at least 1-dimensional tensor. Extra + leftmost dimensions will be considered as batch dimensions. + Returns: + Shuffled version of the given subpopulation. + """ + return subpopulation[torch.argsort(torch.rand_like(subpopulation))] + + +@expects_ndim(1, 1, None, randomness="different") +def _partial_cosyne_permutation_for_subpopulation( + subpopulation: torch.Tensor, evals: torch.Tensor, objective_sense: Optional[str] +) -> torch.Tensor: + """ + Return the permuted (i.e. shuffled) version of the given subpopulation. + + In the context of the Cosyne algorithm, a "subpopulation" is the population + of decision values for a single decision variable. Therefore, subpopulation + represents a column of an entire population. + + Probabilistically, some items within the given subpopulation stay the same. + In more details, if an item belongs to a solution that has better fitness, + that item has lower probability to change. + + Args: + subpopulation: Population of decision values for a single decision + variable. Expected as an at least 1-dimensional tensor. Extra + leftmost dimensions will be considered as batch dimensions. + evals: Evaluation results (i.e. fitnesses). + objective_sense: A string whose value is either 'min' or 'max', + representing the goal of the optimization. + Returns: + Shuffled version of the given subpopulation. + """ + permuted = _cosyne_permutation_for_entire_subpopulation(subpopulation) + + [n] = subpopulation.shape + [num_evals] = evals.shape + + if n != num_evals: + raise ValueError(f"The population size is {n}, but the number of evaluations is different ({num_evals})") + + ranks = utility(evals, objective_sense=objective_sense, ranking_method="linear") + permutation_probs = 1 - ranks.pow(1 / float(n)) + to_permute = torch.rand_like(subpopulation) < permutation_probs + return torch.where(to_permute, permuted, subpopulation) + + +@expects_ndim(2) +def _cosyne_permutation_for_entire_population(population: torch.Tensor) -> torch.Tensor: + """ + Return the permuted (i.e. shuffled) version of a population. + + Shuffling of the values is done columnwise. 
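+    (In other words, for each decision variable, the values stored by the
+    population for that variable are shuffled among the solutions,
+    independently of the other decision variables.)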
+
+    Args:
+        population: A tensor with at least 2 dimensions, representing the
+            decision values of the solutions.
+    Returns:
+        The shuffled counterpart of the given population, as a new tensor.
+    """
+    return _cosyne_permutation_for_entire_subpopulation(population.T).T
+
+
+@expects_ndim(2, 1, None, randomness="different")
+def _partial_cosyne_permutation_for_population(
+    population: torch.Tensor, evals: torch.Tensor, objective_sense: str
+) -> torch.Tensor:
+    """
+    Return the permuted (i.e. shuffled) version of a population.
+
+    Shuffling of the values is done columnwise. For each column, while doing
+    the shuffling, each item is given a probability of staying the same.
+    This probability is higher for items that belong to solutions with better
+    fitnesses.
+
+    Args:
+        population: A tensor with at least 2 dimensions, representing the
+            decision values of the solutions. Extra leftmost dimensions will
+            be considered as batch dimensions.
+        evals: Evaluation results (i.e. fitnesses), as a tensor with at least
+            one dimension. Extra leftmost dimensions will be considered as
+            batch dimensions.
+        objective_sense: A string whose value is either 'min' or 'max',
+            representing the goal of the optimization.
+    Returns:
+        The shuffled counterpart of the given population, as a new tensor.
+    """
+    return _partial_cosyne_permutation_for_subpopulation(population.T, evals, objective_sense).T
+
+
+def cosyne_permutation(
+    values: torch.Tensor,
+    evals: Optional[torch.Tensor] = None,
+    *,
+    permute_all: bool = True,
+    objective_sense: Optional[str] = None,
+) -> torch.Tensor:
+    """
+    Return the permuted (i.e. shuffled) version of the given decision values.
+
+    Shuffling of the decision values is done columnwise.
+
+    If `permute_all` is given as True, each item within each column will be
+    subject to permutation. In this mode, the arguments `evals` and
+    `objective_sense` can be omitted (i.e. can be left as None).
+
+    If `permute_all` is given as False, each item within each column is given
+    a probability of staying the same. This probability is higher for items
+    that belong to solutions with better fitnesses. In this mode, the
+    arguments `evals` and `objective_sense` are mandatory.
+
+    Reference:
+
+    ```
+    Gomez, F., Schmidhuber, J., Miikkulainen, R., & Mitchell, M. (2008).
+    Accelerated Neural Evolution through Cooperatively Coevolved Synapses.
+    Journal of Machine Learning Research, 9(5).
+    ```
+
+    Args:
+        values: A tensor with at least 2 dimensions, representing the
+            decision values of the solutions. Extra leftmost dimensions will
+            be considered as batch dimensions.
+        evals: Evaluation results (i.e. fitnesses), as a tensor with at least
+            one dimension. Extra leftmost dimensions will be considered as
+            batch dimensions. If `permute_all` is True, this argument can be
+            left as None.
+        permute_all: Whether or not each item within each column will be
+            subject to the permutation operation. If given as False, items
+            with better fitnesses have greater probabilities of staying the
+            same. The default is True.
+        objective_sense: A string whose value is either 'min' or 'max',
+            representing the goal of the optimization. If `permute_all` is
+            True, this argument can be left as None.
+    Returns:
+        The shuffled counterpart of the given population, as a new tensor.
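+
+    As a small usage sketch (the population below is random and purely
+    illustrative):
+
+    ```python
+    import torch
+    from evotorch.operators.functional import cosyne_permutation
+
+    population = torch.randn(10, 5)  # 10 solutions, each of length 5
+
+    # Shuffle each column of the population independently
+    shuffled = cosyne_permutation(population, permute_all=True)
+    ```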
+    """
+    if permute_all:
+        return _cosyne_permutation_for_entire_population(values)
+    else:
+        if evals is None:
+            raise ValueError("When `permute_all` is False, `evals` is required")
+        if objective_sense is None:
+            raise ValueError("When `permute_all` is False, `objective_sense` is required")
+        return _partial_cosyne_permutation_for_population(values, evals, objective_sense)
+
+
+@expects_ndim(2, 2)
+def _combine_values(values1: torch.Tensor, values2: torch.Tensor) -> torch.Tensor:
+    return torch.vstack([values1, values2])
+
+
+@expects_ndim(2, 1, 2, 1)
+def _combine_values_and_evals(
+    values1: torch.Tensor, evals1: torch.Tensor, values2: torch.Tensor, evals2: torch.Tensor
+) -> tuple:
+    return torch.vstack([values1, values2]), torch.hstack([evals1, evals2])
+
+
+def combine(a: Union[torch.Tensor, tuple], b: Union[torch.Tensor, tuple]) -> Union[torch.Tensor, tuple]:
+    """
+    Combine two populations into one.
+
+    This function can be used in two forms.
+
+    **First usage: without evaluation results.**
+    Let us assume that we have two decision values matrices, `values1` and
+    `values2`. The shapes of these matrices are (n1, L) and (n2, L)
+    respectively, where L represents the length of a solution.
+    Let us assume that the solutions that these decision values
+    represent are not evaluated yet. Therefore, we do not have evaluation
+    results (i.e. we do not have fitnesses). To combine these two
+    unevaluated populations, we use this function as follows:
+
+    ```python
+    combined_population = combine(values1, values2)
+
+    # We now have a combined decision values matrix, shaped (n1+n2, L).
+    ```
+
+    **Second usage: with evaluation results.**
+    Let us now assume that we have two decision values matrices, `values1`
+    and `values2`. Like in our previous example, these matrices are shaped
+    (n1, L) and (n2, L), respectively. Additionally, let us assume that we
+    know the evaluation results for the solutions represented by `values1`
+    and `values2`. These evaluation results are represented by the tensors
+    `evals1` and `evals2`, shaped (n1,) and (n2,), respectively. To
+    combine these two evaluated populations, we use this function as follows:
+
+    ```python
+    c_values, c_evals = combine((values1, evals1), (values2, evals2))
+
+    # We now have a combined decision values matrix and a combined evaluations
+    # vector.
+    # `c_values` is shaped (n1+n2, L), and `c_evals` is shaped (n1+n2,).
+    ```
+
+    Args:
+        a: A decision values tensor with at least 2 dimensions, or a tuple
+            of the form `(values, evals)`, where `values` is an at least
+            2-dimensional decision values tensor, and `evals` is an at least
+            1-dimensional evaluation results tensor.
+            Extra leftmost dimensions are taken as batch dimensions.
+            If this positional argument is a tensor, the second positional
+            argument must also be a tensor. If this positional argument is a
+            tuple, the second positional argument must also be a tuple.
+        b: A decision values tensor with at least 2 dimensions, or a tuple
+            of the form `(values, evals)`, where `values` is an at least
+            2-dimensional decision values tensor, and `evals` is an at least
+            1-dimensional evaluation results tensor.
+            Extra leftmost dimensions are taken as batch dimensions.
+            If this positional argument is a tensor, the first positional
+            argument must also be a tensor. If this positional argument is a
+            tuple, the first positional argument must also be a tuple.
+ Returns: + The combined decision values tensor, or a tuple of the form + `(values, evals)` where `values` is the combined decision values + tensor, and `evals` is the combined evaluation results tensor. + """ + if isinstance(a, tuple): + values1, evals1 = a + if not isinstance(b, tuple): + raise TypeError( + "The first positional argument was received as a tuple." + " Therefore, the second positional argument was also expected as a tuple." + f" However, the second argument is {repr(b)} (of type {type(b)})." + ) + values2, evals2 = b + return _combine_values_and_evals(values1, evals1, values2, evals2) + elif isinstance(a, torch.Tensor): + if not isinstance(b, torch.Tensor): + raise TypeError( + "The first positional argument was received as a tensor." + " Therefore, the second positional argument was also expected as a tensor." + f" However, the second argument is {repr(b)} (of type {type(b)})." + ) + return _combine_values(a, b) + else: + raise TypeError( + "Expected both positional arguments as tensors, or as tuples." + f" However, the first positional argument is {repr(a)} (of type {type(a)})." + ) + + +@expects_ndim(2, 1, None) +def _take_single_best(values: torch.Tensor, evals: torch.Tensor, objective_sense: str) -> tuple: + if objective_sense == "min": + argfn = torch.argmin + elif objective_sense == "max": + argfn = torch.argmax + else: + raise ValueError( + f"`objective_sense` was expected as 'min' or 'max', but was received as {repr(objective_sense)}" + ) + + _, solution_length = values.shape + index_of_best = argfn(evals).reshape(1) + best_row = torch.index_select(values, 0, index_of_best).reshape(solution_length) + best_eval = torch.index_select(evals, 0, index_of_best).reshape(tuple()) + return best_row, best_eval + + +@expects_ndim(2, 1, None, None) +def _take_multiple_best(values: torch.Tensor, evals: torch.Tensor, n: int, objective_sense: str) -> tuple: + if objective_sense == "min": + descending = False + elif objective_sense == "max": + descending = True + else: + raise ValueError( + f"`objective_sense` was expected as 'min' or 'max', but was received as {repr(objective_sense)}" + ) + + indices_of_best = torch.argsort(evals, descending=descending)[:n] + best_rows = torch.index_select(values, 0, indices_of_best) + best_evals = torch.index_select(evals, 0, indices_of_best) + return best_rows, best_evals + + +def take_best(values: torch.Tensor, evals: torch.Tensor, n: Optional[int] = None, *, objective_sense: str) -> tuple: + """ + Take the best solution, or the best `n` number of solutions. + + Args: + values: Decision values tensor, with at least 2 dimensions. + Extra leftmost dimensions will be taken as batch dimensions. + evals: Evaluation results tensor, with at least 1 dimension. + Extra leftmost dimensions will be taken as batch dimensions. + n: If left as None, the single best solution will be taken. + If given as an integer, this number of best solutions will be + taken. + objective_sense: A string whose value is either 'min' or 'max', + representing the goal of the optimization. 
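+
+    As an illustrative usage sketch (the tensors below are hypothetical,
+    made up just for this example):
+
+    ```python
+    # A population of 50 solutions, each of length 10, together with
+    # hypothetical evaluation results:
+    population = torch.randn(50, 10)
+    evals = torch.rand(50)
+
+    # Take the single best solution, assuming minimization:
+    best_row, best_eval = take_best(population, evals, objective_sense="min")
+
+    # Take the best 10 solutions, assuming minimization:
+    best_rows, best_evals = take_best(population, evals, 10, objective_sense="min")
+    ```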
+ """ + if n is None: + return _take_single_best(values, evals, objective_sense) + else: + return _take_multiple_best(values, evals, n, objective_sense) From 08bfa01377daa86bcc2c35b9ab740eabd7eb8bf2 Mon Sep 17 00:00:00 2001 From: Nihat Engin Toklu Date: Mon, 29 Jul 2024 19:44:27 +0200 Subject: [PATCH 2/6] Add multiobjective capabilities for functional operators --- examples/notebooks/Functional_API/README.md | 1 + .../Functional_API/multiobj_batched_ops.ipynb | 445 +++++++++++++ src/evotorch/operators/functional.py | 588 ++++++++++++++++-- 3 files changed, 987 insertions(+), 47 deletions(-) create mode 100644 examples/notebooks/Functional_API/multiobj_batched_ops.ipynb diff --git a/examples/notebooks/Functional_API/README.md b/examples/notebooks/Functional_API/README.md index 911338f9..d315c6f0 100644 --- a/examples/notebooks/Functional_API/README.md +++ b/examples/notebooks/Functional_API/README.md @@ -6,5 +6,6 @@ Here are the examples demonstrating various features of this functional API: - **[Maintaining a batch of populations using the functional EvoTorch API](batched_searches.ipynb)**: This notebook shows how one can efficiently run multiple searches simultaneously, each with its own population and hyperparameter configuration, by maintaining a batch of populations. - **[Functional genetic algorithm operators](functional_ops.ipynb)**: This notebook shows how one can implement a custom genetic algorithm by combining the genetic algorithm operator implementations provided by the functional API of EvoTorch. +- **[Functional operators for multi-objective optimization](multiobj_batched_ops.ipynb)**: This notebook shows how one can use the functional operators of EvoTorch for multi-objective optimization. Additionally, batched optimization capabilities of these operators are demonstrated. - **[Solving constrained optimization problems](constrained.ipynb)**: EvoTorch provides batching-friendly constraint penalization functions that can be used with both the object-oriented API and the functional API. In addition, these constraint penalization functions can be used with gradient-based optimization. This notebook demonstrates these features. - **[Solving reinforcement learning tasks using functional evolutionary algorithms](problem.ipynb)**: The functional evolutionary algorithm implementations of EvoTorch can be used to solve problems that are expressed using the object-oriented core API of EvoTorch. To demonstrate this, this notebook instantiates a `GymNE` problem for the reinforcement learning task "CartPole-v1", and solves it using the functional `pgpe` implementation. diff --git a/examples/notebooks/Functional_API/multiobj_batched_ops.ipynb b/examples/notebooks/Functional_API/multiobj_batched_ops.ipynb new file mode 100644 index 00000000..3c9e6ddc --- /dev/null +++ b/examples/notebooks/Functional_API/multiobj_batched_ops.ipynb @@ -0,0 +1,445 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "3a0bd5c3-661c-4b5d-9960-6f679ac6e1c7", + "metadata": {}, + "source": [ + "# Multiobjective optimization via functional operators API\n", + "\n", + "The functional operators API of EvoTorch (`evotorch.operators.functional`) can be used for multiobjective optimization.\n", + "In this notebook, we demonstrate how this functional operators API can be used to tackle the Kursawe function, which has two objectives to be minimized." 
+ ] + }, + { + "cell_type": "markdown", + "id": "f37a843a-5d3e-4bdd-a743-244c156d1408", + "metadata": {}, + "source": [ + "---\n", + "\n", + "We begin with the necessary imports:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fda7ae3a-970b-4803-a222-35cf0550bf19", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "\n", + "import evotorch.operators.functional as func_ops\n", + "from evotorch.decorators import rowwise\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np" + ] + }, + { + "cell_type": "markdown", + "id": "3019da0f-4545-4cbb-bf28-ee0fc5ea2f62", + "metadata": {}, + "source": [ + "Below, we implement Kursawe's function.\n", + "\n", + "Notice how we decorate the function via `evotorch.decorators.rowwise`. This `@rowwise` decorator allows us to implement the function `f` with the simple assumption that `x` is a single vector. However, when calling this decorated function `f` from outside, we will be able to provide `x` as a matrix, in which case the `@rowwise` decorator will broadcast the function `f` such that it will be applied for each row of the matrix. In fact, we can even give a tensor with 3 or more dimensions as `x`, and the decorated `f` will interpret all of the extra leftmost dimensions as the batch dimensions." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0455f2df-c438-470c-a9e2-c35086c8eed3", + "metadata": {}, + "outputs": [], + "source": [ + "@rowwise\n", + "def f(x: torch.Tensor) -> torch.Tensor:\n", + " # Kursawe's function\n", + "\n", + " f1 = torch.sum(\n", + " -10 * torch.exp(\n", + " -0.2 * torch.sqrt(x[0:2] ** 2.0 + x[1:3] ** 2.0)\n", + " ),\n", + " )\n", + "\n", + " f2 = torch.sum(\n", + " (torch.abs(x) ** 0.8) + (5 * torch.sin(x ** 3)),\n", + " )\n", + " fitnesses = torch.hstack([f1, f2])\n", + " return fitnesses" + ] + }, + { + "cell_type": "markdown", + "id": "9b904d3b-8008-4782-9d8d-b9be3f7ec770", + "metadata": {}, + "source": [ + "Below, we have the constants regarding the problem, and hyperparameters:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9c634129-8b1e-4050-a618-d5cce7b6843f", + "metadata": {}, + "outputs": [], + "source": [ + "device = \"cpu\"\n", + "solution_length = 3\n", + "objective_sense = [\"min\", \"min\"]\n", + "lb = -5.0\n", + "ub = 5.0\n", + "\n", + "popsize = 200\n", + "num_generations = 100\n", + "mutation_stdev = 0.03\n", + "tournament_size = 4" + ] + }, + { + "cell_type": "markdown", + "id": "b29688fc-73ae-4a07-8377-b2c7ddeb9640", + "metadata": {}, + "source": [ + "Initialize a population, and store it via the variable `population`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ab53d77a-c30a-4ed3-ab05-6969a9933280", + "metadata": {}, + "outputs": [], + "source": [ + "population = (torch.rand(popsize, solution_length, device=device) * (ub - lb)) + lb\n", + "population.shape" + ] + }, + { + "cell_type": "markdown", + "id": "4c7b7c69-d728-4a29-8b5e-6dbb26b4f7b8", + "metadata": {}, + "source": [ + "Evaluate the initial population, and store the evaluation results within the variable `evals`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a3d63e29-6849-4593-bc18-49d200a1298a", + "metadata": {}, + "outputs": [], + "source": [ + "evals = f(population)\n", + "evals.shape" + ] + }, + { + "cell_type": "markdown", + "id": "8022fda7-0d73-432e-afe3-d5c1c2ca7d12", + "metadata": {}, + "source": [ + "Main loop of the optimization:" + ] + }, + { + "cell_type": "code", + 
"execution_count": null, + "id": "378e59ee-31f7-41e0-b3e1-a9059e4c6aef", + "metadata": {}, + "outputs": [], + "source": [ + "for generation in range(1, 1 + num_generations):\n", + "\n", + " # Apply a tournament selection, and a simulated binary cross-over (SBX)\n", + " # on the selected parents\n", + " candidates = func_ops.simulated_binary_cross_over(\n", + " population,\n", + " evals,\n", + " tournament_size=tournament_size,\n", + " eta=1,\n", + " objective_sense=objective_sense,\n", + " )\n", + "\n", + " # Instead of a simulated binary cross-over, we could also use a two-point\n", + " # cross-over, as follows:\n", + " #\n", + " # candidates = func_ops.two_point_cross_over(\n", + " # population,\n", + " # evals,\n", + " # tournament_size=tournament_size,\n", + " # objective_sense=objective_sense,\n", + " # )\n", + "\n", + " # Apply Gaussian mutation on the results of the cross-over operation\n", + " candidates = candidates + (torch.randn_like(candidates) * mutation_stdev)\n", + "\n", + " # Evaluate the mutated candidate solutions\n", + " candidate_evals = f(candidates)\n", + "\n", + " # Form an extended population by combining the parent solutions and the\n", + " # candidate solutions.\n", + " extended_population, extended_evals = func_ops.combine(\n", + " (population, evals),\n", + " (candidates, candidate_evals),\n", + " # We are passing `objective_sense` to inform the `combine` function\n", + " # that the problem at hand is multi-objective:\n", + " objective_sense=objective_sense,\n", + " )\n", + "\n", + " # Take the `popsize` number of solutions from best pareto-fronts.\n", + " population, evals = func_ops.take_best(\n", + " extended_population,\n", + " extended_evals,\n", + " popsize,\n", + " objective_sense=objective_sense,\n", + " # When selecting the solutions, we want the crowding distances of the\n", + " # solutions to be taken into account:\n", + " crowdsort=True\n", + " )\n", + "\n", + " # Print the current status:\n", + " print(\"Generation:\", generation, \" Best evals of the population:\", torch.max(evals, dim=0).values)" + ] + }, + { + "cell_type": "markdown", + "id": "7fae1133-c0a5-4867-ad65-58bd4d25cb77", + "metadata": {}, + "source": [ + "Considering that `evals` now stores the evaluation results of the latest population, we can take the best solutions belonging to the best pareto-front as follows:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2a26ede1-490d-4925-914a-475c731b119c", + "metadata": {}, + "outputs": [], + "source": [ + "# Compute domination count (i.e. how many times a solution was dominated)\n", + "# for each solution\n", + "dcounts = func_ops.domination_counts(evals, objective_sense=objective_sense)\n", + "\n", + "# Make a mask in which the i-th element is True if the i-th solution of the\n", + "# population has never been dominated\n", + "# (i.e. 
if the i-th solution is at the best pareto-front)\n", + "at_best_front = (dcounts == 0)\n", + "\n", + "# Filter both the decision values tensor and the evaluation results tensor\n", + "# such that only the solutions on the best pareto-front will be included.\n", + "# The results of this filtering operation will be stored by the variables\n", + "# `best_pop` and `best_evals`.\n", + "best_pop = population[at_best_front]\n", + "best_evals = evals[at_best_front]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4349eb26-1d33-403b-8f59-b9a3c506f4dc", + "metadata": {}, + "outputs": [], + "source": [ + "best_pop.shape, best_evals.shape" + ] + }, + { + "cell_type": "markdown", + "id": "6027aa8a-cc0a-433a-b0de-2956137becaa", + "metadata": {}, + "source": [ + "Plot the fitnesses of the best solutions:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f2e29903-a3e5-4f2a-9f35-2e73fd5988ec", + "metadata": {}, + "outputs": [], + "source": [ + "plt.title(f\"Fitnesses after {num_generations} generations\")\n", + "plt.scatter(best_evals[:, 0].cpu().numpy(), best_evals[:, 1].cpu().numpy())\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "6d5be444-bad7-42eb-be8f-3bdb57f358c0", + "metadata": {}, + "source": [ + "---\n", + "\n", + "# Batched multiobjective optimization\n", + "\n", + "The functional operators API of EvoTorch is written in such a way that a single call to an operator can work on not just a single population, but on a batch of multiple populations, in a vectorized manner.\n", + "\n", + "Below, we demonstrate this feature by modifying the multiobjective example above." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8f38b144-9bf3-4d4e-a266-14efbe909905", + "metadata": {}, + "outputs": [], + "source": [ + "# Let us consider 4 populations:\n", + "num_populations = 4\n", + "\n", + "# Size for each population:\n", + "popsize = 200\n", + "\n", + "# Shared hyperparameters\n", + "num_generations = 30\n", + "tournament_size = 4\n", + "mutation_stdev = 0.03\n", + "\n", + "# Hyperparameters that vary for each population:\n", + "eta = torch.tensor([1.0, 8.0, 20.0, 40.0], device=device)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1c82431e-e9f4-4d6e-9dba-b4a711139480", + "metadata": {}, + "outputs": [], + "source": [ + "# Initialize a batch of populations (each population is initialized the same)\n", + "population = (torch.rand(popsize, solution_length, device=device) * (ub - lb)) + lb\n", + "evals = f(population)\n", + "broadcaster = torch.ones(num_populations, 1, 1, device=device)\n", + "population = population * broadcaster\n", + "evals = evals * broadcaster\n", + "\n", + "# Alternatively, in some cases, you might want to do the initialization in\n", + "# such a way that each population within the batch is different:\n", + "#\n", + "# population = (torch.rand(num_populations, popsize, solution_length, device=device) * (ub - lb)) + lb\n", + "# evals = f(population)\n", + "\n", + "population.shape, evals.shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "24f5a201-7d33-4bee-a569-dbfa97d0efd3", + "metadata": {}, + "outputs": [], + "source": [ + "for generation in range(1, 1 + num_generations):\n", + " candidates = func_ops.simulated_binary_cross_over(\n", + " population,\n", + " evals,\n", + " tournament_size=tournament_size,\n", + " objective_sense=objective_sense,\n", + " #\n", + " # Upon seeing that `eta` is given as a vector (instead of a scalar),\n", + " # the 
function `simulated_binary_cross_over` will treat the `eta` as\n", + " # a batch of hyperparameters. In more details, the first `eta` is\n", + " # used on the first population of the batch, the second `eta` is used\n", + " # on the second population of the batch, and so on...\n", + " eta=eta,\n", + " )\n", + "\n", + " candidates = candidates + (torch.randn_like(candidates) * mutation_stdev)\n", + " candidate_evals = f(candidates)\n", + "\n", + " extended_population, extended_evals = func_ops.combine(\n", + " (population, evals),\n", + " (candidates, candidate_evals),\n", + " objective_sense=objective_sense,\n", + " )\n", + "\n", + " population, evals = func_ops.take_best(\n", + " extended_population,\n", + " extended_evals,\n", + " popsize,\n", + " objective_sense=objective_sense,\n", + " crowdsort=True\n", + " )\n", + "\n", + " # Print the current status:\n", + " print(\"Generation:\", generation)\n", + " print(\"Best evals of the populations:\")\n", + " print(torch.max(evals, dim=1).values)\n", + " print()" + ] + }, + { + "cell_type": "markdown", + "id": "66418331-30da-48b0-a2ea-425ba8b30730", + "metadata": {}, + "source": [ + "For each solution within each population, compute the domination count:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4f7cd61d-0396-4e52-a4a2-e793791991c7", + "metadata": {}, + "outputs": [], + "source": [ + "dcounts = func_ops.domination_counts(evals, objective_sense=objective_sense)\n", + "dcounts.shape" + ] + }, + { + "cell_type": "markdown", + "id": "403568ab-5463-4c4a-bccc-3dd918eb60a5", + "metadata": {}, + "source": [ + "From each population, take the best pareto-front, and plot the fitnesses belonging to that pareto-front:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ccc9ffcc-06ff-44b0-998f-5874468c2383", + "metadata": {}, + "outputs": [], + "source": [ + "for i_population in range(num_populations):\n", + " single_pop_dcounts = dcounts[i_population, :]\n", + " single_pop = population[i_population, :, :]\n", + " single_pop_evals = evals[i_population, :, :]\n", + "\n", + " at_best_front = (single_pop_dcounts == 0)\n", + "\n", + " best_pop = single_pop[at_best_front]\n", + " best_evals = single_pop_evals[at_best_front]\n", + "\n", + " plt.title(f\"Fitnesses with eta={float(eta[i_population].cpu())}, after {num_generations} generations\")\n", + " plt.scatter(best_evals[:, 0].cpu().numpy(), best_evals[:, 1].cpu().numpy())\n", + " plt.show()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.18" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/src/evotorch/operators/functional.py b/src/evotorch/operators/functional.py index 01004d2b..c089b6c4 100644 --- a/src/evotorch/operators/functional.py +++ b/src/evotorch/operators/functional.py @@ -133,13 +133,326 @@ def f(x: torch.Tensor) -> torch.Tensor: """ -from typing import Optional, Union +from typing import Iterable, Optional, Union import torch from evotorch.decorators import expects_ndim +def _index_comparison_matrices(n: int, *, device: Union[str, torch.dtype]) -> tuple: + """ + Return index tensors that are meant for pairwise comparisons. + + In more details, suppose that the argument `n` is given as 4. 
+
+    What is returned by this function is a 3-element tuple of the form
+    `(indices_matrix1, indices_matrix2, index_row)`. In this returned
+    tuple, `indices_matrix1` is:
+
+    ```
+    0 0 0 0
+    1 1 1 1
+    2 2 2 2
+    3 3 3 3
+    ```
+
+    `indices_matrix2` is:
+
+    ```
+    0 1 2 3
+    0 1 2 3
+    0 1 2 3
+    0 1 2 3
+    ```
+
+    `index_row` is:
+
+    ```
+    0 1 2 3
+    ```
+
+    Note: `indices_matrix1` and `indices_matrix2` are expanded views of the
+    tensor `index_row`. Do not mutate any of these returned tensors, because
+    such mutations would be reflected in all of them in unexpected ways.
+
+    Args:
+        n: Size for the index row and matrices
+        device: The device in which the index tensors will be generated
+    Returns:
+        A tuple of the form `(indices_matrix1, indices_matrix2, index_row)`
+        where each item is a PyTorch tensor.
+    """
+    increasing_indices = torch.arange(n, device=device)
+    indices1 = increasing_indices.reshape(n, 1).expand(n, n)
+    indices2 = increasing_indices.reshape(1, n).expand(n, n)
+    return indices1, indices2, increasing_indices
+
+
+@expects_ndim(1, 1, None)
+def _dominates(
+    evals1: torch.Tensor,
+    evals2: torch.Tensor,
+    objective_sense: list,
+) -> torch.Tensor:
+    [num_objs] = evals1.shape
+    [n2] = evals2.shape
+    if num_objs != n2:
+        raise ValueError("The lengths of the evaluation results vectors do not match.")
+    if num_objs != len(objective_sense):
+        raise ValueError(
+            "The length of the evaluation results vectors does not match the number of objectives"
+            " specified by `objective_sense`."
+        )
+
+    # For easier internal representation, we generate a sign adjustment tensor.
+    # The motivation is to be able to multiply the evaluation tensors with this adjustment tensor,
+    # resulting in new evaluation tensors that guarantee that better results are higher values.
+    dtype = evals1.dtype
+    device = evals1.device
+    sign_adjustment = torch.empty(num_objs, dtype=dtype, device=device)
+    for i_obj, obj in enumerate(objective_sense):
+        if obj == "min":
+            sign_adjustment[i_obj] = -1
+        elif obj == "max":
+            sign_adjustment[i_obj] = 1
+        else:
+            raise ValueError(
+                "`objective_sense` was expected as a list that consists only of the strings 'min' or 'max'."
+                f" However, one of the items encountered within `objective_sense` is: {repr(obj)}."
+            )
+
+    # Adjust the signs of the evaluation tensors
+    evals1 = sign_adjustment * evals1
+    evals2 = sign_adjustment * evals2
+
+    # Count the number of victories for each solution
+    num_victories_of_first = (evals1 > evals2).to(dtype=torch.int64).sum()
+    num_victories_of_second = (evals2 > evals1).to(dtype=torch.int64).sum()
+
+    # If the first solution has won at least 1 time, and the second solution never won, we can say that the
+    # first solution pareto-dominates the second one.
+    return (num_victories_of_first >= 1) & (num_victories_of_second == 0)
+
+
+def dominates(
+    evals1: torch.Tensor,
+    evals2: torch.Tensor,
+    *,
+    objective_sense: list,
+) -> torch.Tensor:
+    """
+    Return whether or not the first solution pareto-dominates the second one.
+
+    Args:
+        evals1: Evaluation results of the first solution. Expected as an
+            at-least-1-dimensional tensor, the length of which must be
+            equal to the number of objectives. Extra leftmost dimensions
+            will be considered as batch dimensions.
+        evals2: Evaluation results of the second solution. Expected as an
+            at-least-1-dimensional tensor, the length of which must be
+            equal to the number of objectives. Extra leftmost dimensions
+            will be considered as batch dimensions.
+ objective_sense: Expected as a list of strings, where each + string is either 'min' or 'max', expressing the direction of + optimization regarding each objective. + Returns: + A tensor of boolean(s), indicating whether or not the first + solution(s) dominate(s) the second solution(s). + """ + if isinstance(objective_sense, str): + raise ValueError( + "`objective_sense` was received as a string, implying that the problem at hand has a single objective." + " However, this `dominates(...)` function does not support single-objective cases." + ) + elif isinstance(objective_sense, Iterable): + return _dominates(evals1, evals2, objective_sense) + else: + raise TypeError(f"Unrecognized `objective_sense`: {repr(objective_sense)}") + + +@expects_ndim(2, 0, 0, None) +def _domination_check_via_indices( + population_evals: torch.Tensor, + solution1_index: torch.Tensor, + solution2_index: torch.Tensor, + objective_sense: list, +) -> torch.Tensor: + evals1 = torch.index_select(population_evals, 0, solution1_index.reshape(1))[0] + evals2 = torch.index_select(population_evals, 0, solution2_index.reshape(1))[0] + return _dominates(evals1, evals2, objective_sense) + + +@expects_ndim(2, None) +def _domination_matrix( + evals: torch.Tensor, + objective_sense: list, +) -> torch.Tensor: + num_solutions, _ = evals.shape + indices1, indices2, _ = _index_comparison_matrices(num_solutions, device=evals.device) + return _domination_check_via_indices(evals, indices2, indices1, objective_sense) + + +def domination_matrix(evals: torch.Tensor, *, objective_sense: list) -> torch.Tensor: + """ + Compute and return a pareto-domination matrix. + + In this pareto-domination matrix `P`, the item `P[i,j]` is True if the + `i`-th solution is dominated by the `j`-th solution. + + Args: + evals: Evaluation results of the solutions, expected as a tensor + with at least 2 dimensions. In a 2-dimensional `evals` tensor, + the item `i,j` represents the evaluation result of the + `i`-th solution according to the `j`-th objective. + Extra leftmost dimensions are interpreted as batch dimensions. + objective_sense: A list of strings, where each string is either + 'min' or 'max', expressing the direction of optimization regarding + each objective. + Returns: + A boolean tensor of size `(n,n)`, where `n` is the number of solutions. + """ + return _domination_matrix(evals, objective_sense) + + +@expects_ndim(2, None) +def _domination_counts(evals: torch.Tensor, objective_sense: list) -> torch.Tensor: + return _domination_matrix(evals, objective_sense).to(dtype=torch.int64).sum(dim=-1) + + +def domination_counts(evals: torch.Tensor, *, objective_sense: list) -> torch.Tensor: + """ + Return a tensor expressing how many times each solution gets dominated + + In this returned tensor, the `i`-th item is an integer which specifies how + many times the `i`-th solution is dominated. + + Args: + evals: Expected as an at-least-2-dimensional tensor. In such a + 2-dimensional evaluation tensor, the item `i,j` represents the + evaluation result of the `i`-th solution according to the `j`-th + objective. Extra leftmost dimensions are interpreted as batch + dimensions. + objective_sense: A list of strings, where each string is either + 'min' or 'max', expressing the direction of optimization regarding + each objective. + Returns: + An integer tensor of length `n`, where `n` is the number of solutions. 
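+
+    As an illustrative usage sketch (the evaluation tensor below is
+    hypothetical, made up just for this example):
+
+    ```python
+    # Evaluation results of 4 solutions on 2 objectives, both subject
+    # to minimization:
+    evals = torch.tensor(
+        [
+            [1.0, 4.0],
+            [2.0, 3.0],
+            [2.5, 3.5],  # dominated by [2.0, 3.0]
+            [5.0, 5.0],  # dominated by all the other solutions
+        ]
+    )
+
+    dcounts = domination_counts(evals, objective_sense=["min", "min"])
+    # dcounts is now: tensor([0, 0, 1, 3])
+
+    # Solutions that were never dominated form the best pareto-front:
+    at_best_front = dcounts == 0
+    ```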
+ """ + return _domination_counts(evals, objective_sense) + + +@expects_ndim(2, 1, 0, 0) +def _crowding_distance_of_solution_considering_objective( + population_evals: torch.Tensor, + domination_counts: torch.Tensor, + solution_index: torch.Tensor, + objective_index: torch.Tensor, +) -> torch.Tensor: + num_solutions, _ = population_evals.shape + + [num_domination_counts] = domination_counts.shape + if num_domination_counts != num_solutions: + raise ValueError( + "The number of solutions stored within `evals` does not match the length of `domination_counts`." + ) + + # Get the evaluation results vector for the considered objective + eval_vector = torch.index_select(population_evals, -1, objective_index.reshape(1)).reshape(num_solutions) + + # Get the evaluation result and the domination count for the considered solution + solution_eval = torch.index_select(eval_vector, 0, solution_index.reshape(1))[0] + solution_domination_count = torch.index_select(domination_counts, 0, solution_index.reshape(1))[0] + + # Prepare the masks `got_lower_eval` and `got_higher_eval`. These masks store True for any solution in the + # same pareto-front with lower evaluation result, and with higher evaluation result, respectively. + within_same_front = domination_counts == solution_domination_count + got_lower_eval = within_same_front & (eval_vector < solution_eval) + got_higher_eval = within_same_front & (eval_vector > solution_eval) + + # Compute a large-enough constant that will be the crowding distance for when the considered solution is + # pareto-extreme + large_constant = 2 * (eval_vector.max() - eval_vector.min()) + + # For each solution within the same pareto-front with lower evaluation result, compute the fitness distance + distances_from_below = torch.where(got_lower_eval, solution_eval - eval_vector, large_constant) + # For each solution within the same pareto-front with higher evaluation result, compute the fitness distance + distances_from_above = torch.where(got_higher_eval, eval_vector - solution_eval, large_constant) + + # Sum of the nearest (min) distance from below and the nearest (min) distance from above is the crowding distance + # for the considered objective. + return distances_from_below.min() + distances_from_above.min() + + +@expects_ndim(2, 1, 0) +def _crowding_distance_of_solution( + population_evals: torch.Tensor, + domination_counts: torch.Tensor, + solution_index: torch.Tensor, +) -> torch.Tensor: + _, num_objectives = population_evals.shape + objective_indices = torch.arange(num_objectives, dtype=torch.int64, device=population_evals.device) + + # Compute the crowding distances for all objectives, then sum those distances, then return the result. + return _crowding_distance_of_solution_considering_objective( + population_evals, domination_counts, solution_index, objective_indices + ).sum() + + +@expects_ndim(2, 1) +def _crowding_distances(population_evals: torch.Tensor, domination_counts: torch.Tensor) -> torch.Tensor: + num_solutions, _ = population_evals.shape + all_solution_indices = torch.arange(num_solutions, dtype=torch.int64, device=population_evals.device) + return _crowding_distance_of_solution(population_evals, domination_counts, all_solution_indices) + + +@expects_ndim(2, None, None) +def _pareto_utility(evals: torch.Tensor, objective_sense: list, crowdsort: bool) -> torch.Tensor: + num_solutions, _ = evals.shape + domination_counts = _domination_counts(evals, objective_sense) + + # Compute utility values such that a solution that has less domination counts (i.e. 
a solution that has been + # dominated less) will have a higher utility value. + result = torch.as_tensor(num_solutions - domination_counts, dtype=evals.dtype) + + if crowdsort: + # Compute the crowding distances + distances = _crowding_distances(evals, domination_counts) + # Rescale the crowding distances so that they are between 0 and 0.99 + min_distance = distances.min() + max_distance = distances.max() + distance_range = (max_distance - min_distance) + 1e-8 + rescaled_distances = 0.99 * ((distances - min_distance) / distance_range) + # Add the rescaled distances to the resulting utility values + result = result + rescaled_distances + + return result + + +def pareto_utility(evals: torch.Tensor, *, objective_sense: list, crowdsort: bool = True) -> torch.Tensor: + """ + Compute utility values for the solutions of a multi-objective problem. + + A solution on a better pareto-front is assigned a higher utility value. + Additionally, if `crowdsort` is given as True crowding distances will also + be taken into account. In more details, in the same pareto-front, + solutions with higher crowding distances will have increased utility + values. + + Args: + evals: Evaluation results, expected as a tensor with at least two + dimensions. A 2-dimensional `evals` tensor is expected to be + shaped as (numberOfSolutions, numberOfObjectives). Extra + leftmost dimensions will be interpreted as batch dimensions. + objective_sense: Expected as a list of strings, where each string + is either 'min' or 'max'. The i-th item within this list + represents the direction of the optimization for the i-th + objective. + Returns: + A utility tensor. Considering the non-batched case (i.e. considering + that `evals` was given as a 2-dimensional tensor), the i-th item + within the returned utility tensor represents the utility value + assigned to the i-th solution. + """ + return _pareto_utility(evals, objective_sense, crowdsort) + + @expects_ndim(2, 1, 1, None, randomness="different") def _pick_solution_via_tournament( solutions: torch.Tensor, @@ -199,24 +512,52 @@ def _pick_solution_via_tournament( @expects_ndim(2, 1, None, None, None, randomness="different") -def _tournament( +def _single_objective_tournament( solutions: torch.Tensor, evals: torch.Tensor, num_tournaments: int, tournament_size: int, objective_sense: str, +) -> tuple: + if tournament_size < 1: + raise ValueError( + "The argument `tournament_size` was expected to be greater than or equal to 1." + f" However, it was encountered as {tournament_size}." + ) + popsize, _ = solutions.shape + indices_for_tournament = torch.randint_like( + solutions[:1, :1].expand(num_tournaments, tournament_size), 0, popsize, dtype=torch.int64 + ) + return _pick_solution_via_tournament(solutions, evals, indices_for_tournament, objective_sense) + + +def _tournament( + solutions: torch.Tensor, + evals: torch.Tensor, + num_tournaments: int, + tournament_size: int, + objective_sense: Union[str, list], ) -> tuple: """ Randomly pick solutions, put them into a tournament, pick the winners. Args: - solutions: Decision values of the solutions - evals: Evaluation results of the solutions + solutions: Decision values of the solutions. + evals: Evaluation results of the solutions. + In the single-objective case, this is expected as an + at-least-1-dimensional tensor, the `i`-th item expressing + the evaluation result of the `i`-th solution. 
+ In the multi-objective case, this is expected as an + at-least-2-dimensional tensor, the `(i,j)`-th item + expressing the evaluation result of the `i`-th solution + according to the `j`-th objective. + Extra leftmost dimensions are interpreted as batch dimensions. num_tournaments: Number of tournaments that will be applied. In other words, number of winners that will be picked. tournament_size: Number of solutions to be picked for the tournament - objective_sense: A string of value 'min' or 'max', representing the - goal of the optimization + objective_sense: A string or a list of strings, where (each) string + has either the value 'min' for minimization or 'max' for + maximization. Returns: A tuple of the form `(decision_values, eval_results)` where `decision_values` is the tensor that contains the decision values @@ -224,16 +565,19 @@ def _tournament( contains the evaluation results (i.e. fitnesses) of the winning solutions. """ - if tournament_size < 1: - raise ValueError( - "The argument `tournament_size` was expected to be greater than or equal to 1." - f" However, it was encountered as {tournament_size}." + if isinstance(objective_sense, str): + pass # nothing to do + elif isinstance(objective_sense, Iterable): + objective_sense = list(objective_sense) + evals = pareto_utility(evals, objective_sense=objective_sense, crowdsort=False) + objective_sense = "max" + else: + raise TypeError( + "The argument `objective_sense` was expected as a string for the single-objective case," + " or as a list of strings for the multi-objective case." + f" However, the encountered `objective_sense` is {repr(objective_sense)}." ) - popsize, _ = solutions.shape - indices_for_tournament = torch.randint_like( - solutions[:1, :1].expand(num_tournaments, tournament_size), 0, popsize, dtype=torch.int64 - ) - return _pick_solution_via_tournament(solutions, evals, indices_for_tournament, objective_sense) + return _single_objective_tournament(solutions, evals, num_tournaments, tournament_size, objective_sense) @expects_ndim(2, randomness="different") @@ -342,7 +686,7 @@ def multi_point_cross_over( num_points: int, tournament_size: Optional[int] = None, num_children: Optional[int] = None, - objective_sense: Optional[str] = None, + objective_sense: Optional[Union[str, list]] = None, ) -> torch.Tensor: """ Apply multi-point cross-over on the given `parents`. @@ -386,10 +730,14 @@ def multi_point_cross_over( If there is no tournament selection (i.e. if `tournament_size` is None), `num_children` is expected to be None. objective_sense: Mandatory if `tournament_size` is not None. - `objective_sense` is expected as 'max' for when the goal of the - optimization is to maximize `evals`, 'min' for when the goal of - the optimization is to minimize `evals`. If `tournament_size` - is None, `objective_sense` can also be left as None. + For when there is only one objective, `objective_sense` is + expected as 'max' for when the goal of the optimization is to + maximize `evals`, 'min' for when the goal of the optimization is + to minimize `evals`. For when there are multiple objectives, + `objective_sense` is expected as a list of strings, where each + string is either 'min' or 'max'. + If `tournament_size` is None, `objective_sense` can also be left + as None. Returns: Decision values of the child solutions, as a new tensor. """ @@ -502,10 +850,14 @@ def one_point_cross_over( If there is no tournament selection (i.e. if `tournament_size` is None), `num_children` is expected to be None. 
objective_sense: Mandatory if `tournament_size` is not None. - `objective_sense` is expected as 'max' for when the goal of the - optimization is to maximize `evals`, 'min' for when the goal of - the optimization is to minimize `evals`. If `tournament_size` - is None, `objective_sense` can also be left as None. + For when there is only one objective, `objective_sense` is + expected as 'max' for when the goal of the optimization is to + maximize `evals`, 'min' for when the goal of the optimization is + to minimize `evals`. For when there are multiple objectives, + `objective_sense` is expected as a list of strings, where each + string is either 'min' or 'max'. + If `tournament_size` is None, `objective_sense` can also be left + as None. Returns: Decision values of the child solutions, as a new tensor. """ @@ -596,10 +948,14 @@ def two_point_cross_over( If there is no tournament selection (i.e. if `tournament_size` is None), `num_children` is expected to be None. objective_sense: Mandatory if `tournament_size` is not None. - `objective_sense` is expected as 'max' for when the goal of the - optimization is to maximize `evals`, 'min' for when the goal of - the optimization is to minimize `evals`. If `tournament_size` - is None, `objective_sense` can also be left as None. + For when there is only one objective, `objective_sense` is + expected as 'max' for when the goal of the optimization is to + maximize `evals`, 'min' for when the goal of the optimization is + to minimize `evals`. For when there are multiple objectives, + `objective_sense` is expected as a list of strings, where each + string is either 'min' or 'max'. + If `tournament_size` is None, `objective_sense` can also be left + as None. Returns: Decision values of the child solutions, as a new tensor. """ @@ -688,10 +1044,14 @@ def simulated_binary_cross_over( If there is no tournament selection (i.e. if `tournament_size` is None), `num_children` is expected to be None. objective_sense: Mandatory if `tournament_size` is not None. - `objective_sense` is expected as 'max' for when the goal of the - optimization is to maximize `evals`, 'min' for when the goal of - the optimization is to minimize `evals`. If `tournament_size` - is None, `objective_sense` can also be left as None. + For when there is only one objective, `objective_sense` is + expected as 'max' for when the goal of the optimization is to + maximize `evals`, 'min' for when the goal of the optimization is + to minimize `evals`. For when there are multiple objectives, + `objective_sense` is expected as a list of strings, where each + string is either 'min' or 'max'. + If `tournament_size` is None, `objective_sense` can also be left + as None. Returns: Decision values of the child solutions, as a new tensor. """ @@ -794,7 +1154,12 @@ def _utility(evals: torch.Tensor, objective_sense: str, ranking_method: Optional return ranks -def utility(evals: torch.Tensor, *, objective_sense: str, ranking_method: Optional[str] = "centered") -> torch.Tensor: +def utility( + evals: torch.Tensor, + *, + objective_sense: str, + ranking_method: Optional[str] = "centered", +) -> torch.Tensor: """ Return utility values representing how good the evaluation results are. @@ -832,7 +1197,17 @@ def utility(evals: torch.Tensor, *, objective_sense: str, ranking_method: Option Utility values, as a tensor whose shape is the same with the shape of `evals`. 
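+
+    As an illustrative usage sketch (the evaluation tensor below is
+    hypothetical, made up just for this example):
+
+    ```python
+    evals = torch.tensor([12.0, 3.5, 8.1, 2.0])
+
+    # With objective_sense="min" and the default "centered" ranking
+    # method, lower evaluation results receive higher utility values:
+    utils = utility(evals, objective_sense="min")
+    ```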
""" - return _utility(evals, objective_sense, ranking_method) + if isinstance(objective_sense, str): + return _utility(evals, objective_sense, ranking_method) + elif isinstance(objective_sense, Iterable): + raise ValueError( + "The argument `objective_sense` was received as an iterable other than string," + " implying that the problem at hand has multiple objectives." + " However, this `utility(...)` function does not support multiple objectives." + " Consider using `pareto_utility(...)`." + ) + else: + raise TypeError(f"Unrecognized `objective_sense`: {repr(objective_sense)}") @expects_ndim(1, randomness="different") @@ -1005,19 +1380,31 @@ def _combine_values_and_evals( return torch.vstack([values1, values2]), torch.hstack([evals1, evals2]) -def combine(a: Union[torch.Tensor, tuple], b: Union[torch.Tensor, tuple]) -> Union[torch.Tensor, tuple]: +@expects_ndim(2, 2, 2, 2) +def _combine_values_and_multiobjective_evals( + values1: torch.Tensor, evals1: torch.Tensor, values2: torch.Tensor, evals2: torch.Tensor +) -> tuple: + return torch.vstack([values1, values2]), torch.vstack([evals1, evals2]) + + +def combine( + a: Union[torch.Tensor, tuple], + b: Union[torch.Tensor, tuple], + *, + objective_sense: Optional[Union[str, Iterable]] = None, +) -> Union[torch.Tensor, tuple]: """ Combine two populations into one. This function can be used in two forms. - **First usage: without evaluation results.** + **Usage 1: without evaluation results.** Let us assume that we have two decision values matrices, `values1` `values2`. The shapes of these matrices are (n1, L) and (n2, L) respectively, where L represents the length of a solution. Let us assume that the solutions that these decision values represent are not evaluated yet. Therefore, we do not have evaluation - results (i.e. we do not have fitnesses). Two combine these two + results (i.e. we do not have fitnesses). To combine these two unevaluated populations, we use this function as follows: ```python @@ -1026,14 +1413,14 @@ def combine(a: Union[torch.Tensor, tuple], b: Union[torch.Tensor, tuple]) -> Uni # We now have a combined decision values matrix, shaped (n1+n2, L). ``` - **Second usage: with evaluation results.** - Let us now assume that we have two decision values matrices, `values1` + **Usage 2: with evaluation results, single-objective.** + We again assume that we have two decision values matrices, `values1` and `values2`. Like in our previous example, these matrices are shaped (n1, L) and (n2, L), respectively. Additionally, let us assume that we know the evaluation results for the solutions represented by `values1` and `values2`. These evaluation results are represented by the tensors - `evals1` and `evals2`, shaped (n1,) and (n2,), respectively. Two - combine these two evaluated populations, we use this function as follows: + `evals1` and `evals2`, shaped (n1,) and (n2,), respectively. To combine + these two evaluated populations, we use this function as follows: ```python c_values, c_evals = combine((values1, evals1), (values2, evals2)) @@ -1043,6 +1430,27 @@ def combine(a: Union[torch.Tensor, tuple], b: Union[torch.Tensor, tuple]) -> Uni # `c_values` is shaped (n1+n2, L), and `c_evals` is shaped (n1+n2,). ``` + **Usage 3: with evaluation results, multi-objective.** + We again assume that we have two decision values matrices, `values1` + and `values2`. Like in our previous example, these matrices are shaped + (n1, L) and (n2, L), respectively. Additionally, we assume that we know + the evaluation results for these solutions. 
The evaluation results are
+    stored within the tensors `evals1` and `evals2`, whose shapes are
+    (n1, M) and (n2, M), where M is the number of objectives. To combine
+    these two evaluated populations, we use this function as follows:
+
+    ```python
+    c_values, c_evals = combine(
+        (values1, evals1),
+        (values2, evals2),
+        objective_sense=["min", "min"],  # Assuming we have 2 min objectives
+    )
+
+    # We now have a combined decision values matrix and a combined evaluations
+    # matrix.
+    # `c_values` is shaped (n1+n2, L), and `c_evals` is shaped (n1+n2, M).
+    ```
+
     Args:
         a: A decision values tensor with at least 2 dimensions, or a tuple
             of the form `(values, evals)`, where `values` is an at least
@@ -1060,6 +1468,17 @@ def combine(a: Union[torch.Tensor, tuple], b: Union[torch.Tensor, tuple]) -> Uni
             If this positional argument is a tensor, the first positional
             argument must also be a tensor. If this positional argument is a
             tuple, the first positional argument must also be a tuple.
+        objective_sense: In the case of single-objective optimization,
+            `objective_sense` can be left as None, or can be 'min' or 'max',
+            representing the direction of the optimization.
+            In the case of multi-objective optimization, `objective_sense`
+            is expected as a list of strings, each string being 'min' or
+            'max', representing the direction for each objective.
+            Please also note that, if this combination operation is done
+            without evaluation results (i.e. if the first two positional
+            arguments are given as tensors, not tuples), `objective_sense`
+            is not needed, and can be omitted, regardless of whether or
+            not the problem at hand is single-objective.
     Returns:
         The combined decision values tensor, or a tuple of the form
         `(values, evals)` where `values` is the combined decision values
@@ -1074,7 +1493,12 @@ def combine(a: Union[torch.Tensor, tuple], b: Union[torch.Tensor, tuple]) -> Uni
             f" However, the second argument is {repr(b)} (of type {type(b)})."
         )
         values2, evals2 = b
-        return _combine_values_and_evals(values1, evals1, values2, evals2)
+        if (objective_sense is None) or isinstance(objective_sense, str):
+            return _combine_values_and_evals(values1, evals1, values2, evals2)
+        elif isinstance(objective_sense, Iterable):
+            return _combine_values_and_multiobjective_evals(values1, evals1, values2, evals2)
+        else:
+            raise TypeError(f"Unrecognized `objective_sense`: {repr(objective_sense)}")
     elif isinstance(a, torch.Tensor):
         if not isinstance(b, torch.Tensor):
             raise TypeError(
@@ -1125,10 +1549,48 @@ def _take_multiple_best(values: torch.Tensor, evals: torch.Tensor, n: int, objec
     return best_rows, best_evals
 
 
-def take_best(values: torch.Tensor, evals: torch.Tensor, n: Optional[int] = None, *, objective_sense: str) -> tuple:
+@expects_ndim(2, 2, None, None, None)
+def _take_multiple_best_with_multiobjective(
+    values: torch.Tensor,
+    evals: torch.Tensor,
+    n: int,
+    objective_sense: list,
+    crowdsort: bool,
+) -> tuple:
+    utils = pareto_utility(evals, objective_sense=objective_sense, crowdsort=crowdsort)
+    indices_of_best = torch.argsort(utils, descending=True)[:n]
+    best_rows = torch.index_select(values, 0, indices_of_best)
+    best_evals = torch.index_select(evals, 0, indices_of_best)
+    return best_rows, best_evals
+
+
+def take_best(
+    values: torch.Tensor,
+    evals: torch.Tensor,
+    n: Optional[int] = None,
+    *,
+    objective_sense: Union[str, list],
+    crowdsort: bool = True,
+) -> tuple:
    """
    Take the best solution, or the best `n` number of solutions.
+
+    **Single-objective case.**
+    If the positional argument `n` is omitted (i.e. is left as None), the
+    decision values and the evaluation result of the single best solution
+    will be returned.
+    If the positional argument `n` is provided, top-`n` solutions, together
+    with their evaluation results, will be returned.
+
+    **Multi-objective case.**
+    In the multi-objective case, the positional argument `n` is mandatory.
+    With a valid value for `n` given, `n` number of solutions will be taken
+    from the best pareto-fronts. If `crowdsort` is given as True (which is
+    the default), crowding distances of the solutions within the same
+    pareto-fronts will be an additional criterion when deciding which
+    solutions to take. Like in the single-objective case, the decision values
+    and the evaluation results of the taken solutions will be returned.
+
     Args:
         values: Decision values tensor, with at least 2 dimensions.
             Extra leftmost dimensions will be taken as batch dimensions.
@@ -1136,11 +1598,43 @@ def take_best(values: torch.Tensor, evals: torch.Tensor, n: Optional[int] = None
             Extra leftmost dimensions will be taken as batch dimensions.
         n: If left as None, the single best solution will be taken.
             If given as an integer, this number of best solutions will be
-            taken.
-        objective_sense: A string whose value is either 'min' or 'max',
-            representing the goal of the optimization.
+            taken. Please note that, if the problem at hand has multiple
+            objectives, this argument cannot be omitted.
+        objective_sense: In the single-objective case, `objective_sense` is
+            expected as a string 'min' or 'max', representing the direction
+            of the optimization. In the multi-objective case,
+            `objective_sense` is expected as a list of strings, each string
+            being 'min' or 'max', representing the goal of optimization for
+            each objective.
+        crowdsort: Relevant only when there are multiple objectives.
+            If `crowdsort` is True, the crowding distances of the solutions
+            within the given population will be an additional criterion
+            when choosing the best `n` solutions. If `crowdsort` is False,
+            how many times a solution was dominated will be the only factor
+            when deciding whether or not it is among the best `n` solutions.
+    Returns:
+        A tuple of the form `(decision_values, evaluation_results)`, where
+        `decision_values` is the decision values tensor for the taken
+        solution(s), and `evaluation_results` is the evaluation results tensor
+        for the taken solution(s).
     """
+    if isinstance(objective_sense, str):
+        multi_objective = False
+    elif isinstance(objective_sense, Iterable):
+        multi_objective = True
+    else:
+        raise TypeError(f"Unrecognized `objective_sense`: {repr(objective_sense)}")
+
     if n is None:
+        if multi_objective:
+            raise ValueError(
+                "`objective_sense` not given as a string, implying that there are multiple objectives."
+                " When there are multiple objectives, the argument `n` (i.e. number of solutions to take)"
+                " must not be omitted. However, `n` was encountered as None."
+ ) return _take_single_best(values, evals, objective_sense) else: - return _take_multiple_best(values, evals, n, objective_sense) + if multi_objective: + return _take_multiple_best_with_multiobjective(values, evals, n, objective_sense, crowdsort) + else: + return _take_multiple_best(values, evals, n, objective_sense) From c95af8701c031236fcd87ca5ec0ac6e9f379cf3b Mon Sep 17 00:00:00 2001 From: Nihat Engin Toklu Date: Mon, 19 Aug 2024 17:47:15 +0200 Subject: [PATCH 3/6] Add ObjectArray support to functional genetic algorithm operators --- src/evotorch/core.py | 109 +++- src/evotorch/operators/functional.py | 737 ++++++++++++++++++++++---- src/evotorch/tools/immutable.py | 8 +- tests/test_func_ops.py | 740 +++++++++++++++++++++++++++ 4 files changed, 1464 insertions(+), 130 deletions(-) create mode 100644 tests/test_func_ops.py diff --git a/src/evotorch/core.py b/src/evotorch/core.py index f44d0efc..5acb6194 100644 --- a/src/evotorch/core.py +++ b/src/evotorch/core.py @@ -3367,7 +3367,7 @@ def make_callable_evaluator(self, *, obj_index: Optional[int] = None) -> "Proble ``` **Parallelized fitness evaluation.** - If a `Problem` object is condifured to use parallelized evaluation with + If a `Problem` object is configured to use parallelized evaluation with the help of multiple actors, a callable evaluator made out of that `Problem` object will also make use of those multiple actors. @@ -3375,7 +3375,8 @@ def make_callable_evaluator(self, *, obj_index: Optional[int] = None) -> "Proble If a callable evaluator receives a tensor with 3 or more dimensions, those extra leftmost dimensions will be considered as batch dimensions. The returned fitness tensor will also preserve those batch - dimensions. + dimensions. Please note, however, that if the `dtype` of the problem + is `object`, additional batch dimensions are not supported. **Notes on vmap.** `ProblemBoundEvaluator` is a shallow wrapper around a `Problem` object. @@ -3388,12 +3389,20 @@ def make_callable_evaluator(self, *, obj_index: Optional[int] = None) -> "Proble Args: obj_index: The index of the objective according to which the evaluations will be done. If the problem is single-objective, - this is not required. If the problem is multi-objective, this - needs to be given as an integer. + `obj_index` can be omitted. If the problem is multi-objective + and `obj_index` is omitted, the callable evaluator will return + multi-dimensional tensors that express the fitnesses for all + objectives (where the rightmost dimension size is equal to + the number of objectives). If the problem is multi-objective + and `obj_index` is given, the callable evaluator will return + tensors that express the fitnesses of the specified objective. Returns: A callable fitness evaluator, bound to this problem object. """ - return ProblemBoundEvaluator(self, obj_index=obj_index) + if self.dtype is object: + return ObjectTypedProblemBoundEvaluator(self, obj_index=obj_index) + else: + return ProblemBoundEvaluator(self, obj_index=obj_index) SolutionBatchSliceInfo = NamedTuple("SolutionBatchSliceInfo", source="SolutionBatch", slice=IndicesOrSlice) @@ -5114,20 +5123,28 @@ def __init__(self, problem: Problem, *, obj_index: Optional[int] = None): is multi-objective, this is expected as an integer. 
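+
+        As an illustrative sketch (where `my_problem` stands for a
+        hypothetical, already-instantiated multi-objective `Problem`,
+        whose evaluators are typically obtained via
+        `Problem.make_callable_evaluator`):
+
+        ```python
+        # An evaluator that returns the fitnesses for all objectives:
+        f_all = my_problem.make_callable_evaluator()
+
+        # An evaluator bound to the objective with index 0 only:
+        f_first = my_problem.make_callable_evaluator(obj_index=0)
+        ```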
""" self._problem = problem + self._ensure_valid_problem() + if self._problem.is_multi_objective and (obj_index is None): + self._multi_objective = True + self._obj_index = None + else: + self._multi_objective = False + self._obj_index = self._problem.normalize_obj_index(obj_index) + self._problem.ensure_numeric() + # if problem.dtype != problem.eval_dtype: + # raise TypeError( + # "The dtype of the decision values is not the same with the dtype of the evaluations." + # " Currently, it is not supported to make callable evaluators out of problems whose" + # " decision value dtypes are different than their evaluation dtypes." + # ) + + def _ensure_valid_problem(self): if not isinstance(self._problem, Problem): clsname = type(self).__name__ raise TypeError( f"In its initialization phase, {clsname} expected a `Problem` object," f" but found: {repr(self._problem)} (of type {repr(type(self._problem))})" ) - self._obj_index = self._problem.normalize_obj_index(obj_index) - self._problem.ensure_numeric() - if problem.dtype != problem.eval_dtype: - raise TypeError( - "The dtype of the decision values is not the same with the dtype of the evaluations." - " Currently, it is not supported to make callable evaluators out of problems whose" - " decision value dtypes are different than their evaluation dtypes." - ) def _make_empty_solution_batch(self, popsize: int) -> SolutionBatch: return SolutionBatch(self._problem, popsize=popsize, empty=True, device="meta") @@ -5165,6 +5182,68 @@ def __call__(self, values: torch.Tensor) -> torch.Tensor: values = values.reshape(-1, solution_length) evaluated_batch = self._prepare_evaluated_solution_batch(values) - evals = evaluated_batch.evals[:, self._obj_index] - return evals.reshape(original_batch_shape).as_subclass(torch.Tensor) + if self._multi_objective: + evals = evaluated_batch.evals + num_objs = evals.shape[-1] + original_evals_shape = tuple([*original_batch_shape, num_objs]) + else: + evals = evaluated_batch.evals[:, self._obj_index] + original_evals_shape = original_batch_shape + return evals.reshape(original_evals_shape).as_subclass(torch.Tensor) + + +class ObjectTypedProblemBoundEvaluator(ProblemBoundEvaluator): + """ + A callable fitness evaluator, bound to a `Problem` whose dtype is object. + + A callable evaluator returned by the method + `Problem.make_callable_evaluator` is an instance of this class, if the + dtype of the problem is `object`. + For details, please see the documentation of + [Problem][evotorch.core.Problem], and of its method + `make_callable_evaluator`. + """ + + def __init__(self, problem: Problem, *, obj_index: Optional[int] = None): + self._problem = problem + self._ensure_valid_problem() + if self._problem.is_multi_objective and (obj_index is None): + self._multi_objective = True + self._obj_index = None + else: + self._multi_objective = False + self._obj_index = self._problem.normalize_obj_index(obj_index) + if self._problem.dtype is not object: + raise TypeError( + "Expected a problem whose dtype is `object`." + f" However, the dtype of the problem is {self._problem.dtype}." + " Hint: did you mean to instantiate a `ProblemBoundEvaluator`, instead of an" + " `ObjectTypedProblemBoundEvaluator`?" 
+ ) + + def _make_empty_solution_batch(self, popsize: int) -> SolutionBatch: + return SolutionBatch(self._problem, popsize=popsize, empty=True, device="cpu") + + def _prepare_evaluated_solution_batch(self, values: ObjectArray) -> SolutionBatch: + num_solutions = len(values) + batch = self._make_empty_solution_batch(num_solutions) + batch.access_values()[:] = values + self._problem.evaluate(batch) + return batch + + def __call__(self, values: ObjectArray) -> torch.Tensor: + """ + Evaluate the solutions expressed by the ObjectArray-typed `values`. + + Args: + values: Decision values. Expected as an `ObjectArray`. + Returns: + The fitnesses, as a tensor. + """ + if not isinstance(values, ObjectArray): + raise TypeError( + "The positional argument `values` was expected as an `ObjectArray`." + f" However, an object of this type was encountered: {type(values)}." + ) + return self._prepare_evaluated_solution_batch(values).evals.as_subclass(torch.Tensor) diff --git a/src/evotorch/operators/functional.py b/src/evotorch/operators/functional.py index c089b6c4..ed3bb329 100644 --- a/src/evotorch/operators/functional.py +++ b/src/evotorch/operators/functional.py @@ -133,11 +133,12 @@ def f(x: torch.Tensor) -> torch.Tensor: """ -from typing import Iterable, Optional, Union +from typing import Iterable, NamedTuple, Optional, Union import torch from evotorch.decorators import expects_ndim +from evotorch.tools import ObjectArray def _index_comparison_matrices(n: int, *, device: Union[str, torch.dtype]) -> tuple: @@ -453,96 +454,451 @@ def pareto_utility(evals: torch.Tensor, *, objective_sense: list, crowdsort: boo return _pareto_utility(evals, objective_sense, crowdsort) -@expects_ndim(2, 1, 1, None, randomness="different") -def _pick_solution_via_tournament( +@expects_ndim(2, None, None, randomness="different") +def _generate_first_parent_candidate_indices( solutions: torch.Tensor, - evals: torch.Tensor, - indices: torch.Tensor, - objective_sense: str, + num_tournaments: int, + tournament_size: int, +) -> torch.Tensor: + # We are considering half of the given `num_tournaments`, because a second set of tournaments will later + # be executed to pick the second parents. This current operation is only for the first set (and therefore the + # first half) of the parents. + num_tournaments = int(num_tournaments) + if (num_tournaments % 2) != 0: + raise ValueError( + f"`num_tournaments` was expected as a number divisible by 2. However, its value is {num_tournaments}." + ) + half_num_tournaments = num_tournaments // 2 + + num_solutions, _ = solutions.shape + return torch.randint(0, num_solutions, (half_num_tournaments, tournament_size), device=solutions.device) + + +@expects_ndim(None, 1, 0, randomness="different") +def _generate_second_parent_candidate_indices( + num_solutions: int, + parent1_candidate_indices: torch.Tensor, + parent1_winner_index: torch.Tensor, +) -> torch.Tensor: + parent2_candidate_indices = torch.randint_like(parent1_candidate_indices, 0, num_solutions - 1) + parent2_candidate_indices = torch.where( + parent2_candidate_indices >= parent1_winner_index, + parent2_candidate_indices + 1, + parent2_candidate_indices, + ) + return parent2_candidate_indices + + +@expects_ndim(1, None, 1, randomness="different") +def _run_two_tournaments_using_utilities( + utilities: torch.Tensor, + higher_utility_is_better: bool, + parent1_candidate_indices: torch.Tensor, ) -> tuple: - """ - Run a single tournament among multiple solutions to pick the best. 
+ argbest = torch.argmax if higher_utility_is_better else torch.argmin + parent1_candidate_evals = torch.index_select(utilities, 0, parent1_candidate_indices) + winner1_indirect_index = argbest(parent1_candidate_evals) + winner1_index = torch.index_select(parent1_candidate_indices, 0, winner1_indirect_index.reshape(1))[0] + + [num_solutions] = utilities.shape + parent2_candidate_indices = _generate_second_parent_candidate_indices( + num_solutions, parent1_candidate_indices, winner1_index + ) - Args: - solutions: Decision values of the solutions, as a tensor of at least - 2 dimensions. Extra leftmost dimensions will be considered as - batch dimensions. - evals: Evaluation results (i.e. fitnesses) of the solutions, as a - tensor with at least 1 dimension. Extra leftmost dimensions will - be considered as batch dimensions. - indices: Indices of solutions that participate into the tournament, - as a tensor of integers with at least 1 dimension. Extra leftmost - dimensions will be considered as batch dimensions. - objective_sense: A string with value 'min' or 'max', representing the - goal of the optimization. - Returns: - A tuple of the form `(decision_values, eval_result)` where - `decision_values` is the tensor that contains the decision values - of the winning solution(s), and `eval_result` is a tensor that - contains the evaluation result(s) (i.e. fitness(es)) of the - winning solution(s). - """ - # Get the evaluation results of the solutions that participate into the tournament - competing_evals = torch.index_select(evals, 0, indices) - - if objective_sense == "max": - # If the objective sense is 'max', we are looking for the solution with the highest evaluation result - argbest = torch.argmax - elif objective_sense == "min": - # If the objective sense is 'min', we are looking for the solution with the lowest evaluation result - argbest = torch.argmin + parent2_candidate_evals = torch.index_select(utilities, 0, parent2_candidate_indices) + winner2_indirect_index = argbest(parent2_candidate_evals) + winner2_index = torch.index_select(parent2_candidate_indices, 0, winner2_indirect_index.reshape(1))[0] + + return winner1_index, winner2_index + + +class SelectedParentIndices(NamedTuple): + parent1_indices: torch.Tensor + parent2_indices: torch.Tensor + + +class SelectedParentValues(NamedTuple): + parent1_values: Union[torch.Tensor, ObjectArray] + parent2_values: Union[torch.Tensor, ObjectArray] + + +class SelectedParents(NamedTuple): + parent1_values: Union[torch.Tensor, ObjectArray] + parent1_evals: torch.Tensor + parent2_values: Union[torch.Tensor, ObjectArray] + parent2_evals: torch.Tensor + + +class SelectedAndStackedParents(NamedTuple): + parent_values: Union[torch.Tensor, ObjectArray] + parent_evals: torch.Tensor + + +def _undecorated_take_solutions( + solutions: torch.Tensor, + evals: torch.Tensor, + parent1_indices: torch.Tensor, + parent2_indices: torch.Tensor, + with_evals: bool, + split_results: bool, + multi_objective: bool, +) -> Union[torch.Tensor, tuple]: + parent1_values = solutions[parent1_indices] + parent2_values = solutions[parent2_indices] + + if with_evals: + if split_results: + return SelectedParents( + parent1_values=parent1_values, + parent1_evals=evals[parent1_indices], + parent2_values=parent2_values, + parent2_evals=evals[parent2_indices], + ) + else: + combine_evals_fn = torch.vstack if multi_objective else torch.cat + return SelectedAndStackedParents( + parent_values=torch.vstack([parent1_values, parent2_values]), + 
parent_evals=combine_evals_fn([evals[parent1_indices], evals[parent2_indices]]), + ) else: - raise ValueError( - "`objective_sense` was expected either as 'min' or as 'max'." - f" However, it was received as {repr(objective_sense)}." - ) + if split_results: + return SelectedParentValues( + parent1_values=parent1_values, + parent2_values=parent2_values, + ) + else: + return torch.vstack([parent1_values, parent2_values]) - # Among the competing solutions, which one is the best? - winner_competing_eval_index = argbest(competing_evals) - # Get the index (within the original `solutions`) of the winning solution - winner_solution_index = torch.index_select(indices, 0, winner_competing_eval_index.reshape(1)) +@expects_ndim(2, 1, 1, 1, None, None) +def _take_solutions_with_single_objective( + solutions: torch.Tensor, + evals: torch.Tensor, + parent1_indices: torch.Tensor, + parent2_indices: torch.Tensor, + with_evals: bool, + split_results: bool, +) -> Union[torch.Tensor, tuple]: + return _undecorated_take_solutions( + solutions, evals, parent1_indices, parent2_indices, with_evals, split_results, False + ) - # Get the decision values and the evaluation result of the winning solution - winner_solution = torch.squeeze(torch.index_select(solutions, 0, winner_solution_index), dim=0) - winner_eval = torch.squeeze(torch.index_select(evals, 0, winner_solution_index), dim=0) - # Return the winning solution's decision values and evaluation results - return winner_solution, winner_eval +@expects_ndim(2, 2, 1, 1, None, None) +def _take_solutions_with_multi_objective( + solutions: torch.Tensor, + evals: torch.Tensor, + parent1_indices: torch.Tensor, + parent2_indices: torch.Tensor, + with_evals: bool, + split_results: bool, +) -> Union[torch.Tensor, tuple]: + return _undecorated_take_solutions( + solutions, evals, parent1_indices, parent2_indices, with_evals, split_results, True + ) -@expects_ndim(2, 1, None, None, None, randomness="different") -def _single_objective_tournament( +@expects_ndim(2, 1, None, None, None, None, None, None, randomness="different") +def _pick_pairs_via_tournament_with_single_objective( solutions: torch.Tensor, evals: torch.Tensor, num_tournaments: int, tournament_size: int, objective_sense: str, + return_indices: bool, + with_evals: bool, + split_results: bool, +) -> Union[torch.Tensor, tuple]: + num_solutions, _ = solutions.shape + [num_evals] = evals.shape + if num_solutions != num_evals: + raise ValueError("Number of evaluation results does not match the number of solutions") + + if objective_sense == "min": + higher_utility_is_better = False + elif objective_sense == "max": + higher_utility_is_better = True + else: + raise ValueError(f"Unrecognized `objective_sense`: {repr(objective_sense)}") + + first_parent_indices = _generate_first_parent_candidate_indices(solutions, num_tournaments, tournament_size) + winner1_indices, winner2_indices = _run_two_tournaments_using_utilities( + evals, higher_utility_is_better, first_parent_indices + ) + + if return_indices: + if split_results: + return SelectedParentIndices(parent1_indices=winner1_indices, parent2_indices=winner2_indices) + else: + return torch.cat([winner1_indices, winner2_indices]) + else: + return _take_solutions_with_single_objective( + solutions, evals, winner1_indices, winner2_indices, with_evals, split_results + ) + + +@expects_ndim(2, 2, None, None, None, None, None, None, randomness="different") +def _pick_pairs_via_tournament_with_multi_objective( + solutions: torch.Tensor, + evals: torch.Tensor, + num_tournaments: 
int,
+    tournament_size: int,
+    objective_sense: list,
+    return_indices: bool,
+    with_evals: bool,
+    split_results: bool,
+) -> Union[torch.Tensor, tuple]:
+    num_solutions, _ = solutions.shape
+    num_evals, _ = evals.shape
+    if num_solutions != num_evals:
+        raise ValueError("Number of evaluation results does not match the number of solutions")
+
+    utils = pareto_utility(evals, objective_sense=objective_sense, crowdsort=False)
+    first_parent_indices = _generate_first_parent_candidate_indices(solutions, num_tournaments, tournament_size)
+    winner1_indices, winner2_indices = _run_two_tournaments_using_utilities(utils, True, first_parent_indices)
+
+    if return_indices:
+        if split_results:
+            return SelectedParentIndices(parent1_indices=winner1_indices, parent2_indices=winner2_indices)
+        else:
+            return torch.cat([winner1_indices, winner2_indices])
+    else:
+        return _take_solutions_with_multi_objective(
+            solutions, evals, winner1_indices, winner2_indices, with_evals, split_results
+        )
+
+
+def _pick_pairs_via_tournament_considering_objects(
+    solutions: ObjectArray,
+    evals: torch.Tensor,
+    num_tournaments: int,
+    tournament_size: int,
+    objective_sense: Union[list, str],
+    return_indices: bool,
+    with_evals: bool,
+    split_results: bool,
 ) -> tuple:
-    if tournament_size < 1:
+    from evotorch.tools import make_tensor
+
+    num_solutions = len(solutions)
+    if isinstance(objective_sense, str):
+        multi_objective = False
+        if evals.ndim != 1:
+            raise ValueError(
+                "In the case of single-objective optimization, `evals` was expected as a 1-dimensional tensor."
+                f" However, the shape of `evals` is {evals.shape}."
+            )
+        [num_evals] = evals.shape
+        utils = evals
+        if objective_sense == "min":
+            higher_utility_is_better = False
+        elif objective_sense == "max":
+            higher_utility_is_better = True
+        else:
+            raise ValueError(f"Unrecognized `objective_sense`: {repr(objective_sense)}")
+    elif isinstance(objective_sense, Iterable):
+        multi_objective = True
+        if evals.ndim != 2:
+            raise ValueError(
+                "In the case of multi-objective optimization, `evals` was expected as a 2-dimensional tensor."
+                f" However, the shape of `evals` is {evals.shape}."
+            )
+        num_evals, _ = evals.shape
+        utils = pareto_utility(evals, objective_sense=objective_sense, crowdsort=False)
+        higher_utility_is_better = True
+    else:
+        raise TypeError(f"Unrecognized `objective_sense`: {repr(objective_sense)}")
+
+    if num_solutions != num_evals:
+        raise ValueError("Number of evaluation results does not match the number of solutions")
+
+    num_tournaments = int(num_tournaments)
+    if (num_tournaments % 2) != 0:
         raise ValueError(
-            "The argument `tournament_size` was expected to be greater than or equal to 1."
-            f" However, it was encountered as {tournament_size}."
+            f"`num_tournaments` was expected as a number divisible by 2. However, its value is {num_tournaments}."
         )
-    popsize, _ = solutions.shape
-    indices_for_tournament = torch.randint_like(
-        solutions[:1, :1].expand(num_tournaments, tournament_size), 0, popsize, dtype=torch.int64
+    half_num_tournaments = num_tournaments // 2
+    first_parent_indices = torch.randint(0, num_solutions, (half_num_tournaments, tournament_size), device=evals.device)
+
+    winner1_indices, winner2_indices = _run_two_tournaments_using_utilities(
+        utils, higher_utility_is_better, first_parent_indices
     )
-    return _pick_solution_via_tournament(solutions, evals, indices_for_tournament, objective_sense)

+    if return_indices:
+        if split_results:
+            return SelectedParentIndices(parent1_indices=winner1_indices, parent2_indices=winner2_indices)
+        else:
+            return torch.cat([winner1_indices, winner2_indices])
+    else:
+        parent1_values = solutions[torch.as_tensor(winner1_indices, device="cpu")]
+        parent2_values = solutions[torch.as_tensor(winner2_indices, device="cpu")]

+        if split_results:
+            combined_values = None
+        else:
+            combined_values = make_tensor(
+                [*parent1_values, *parent2_values], read_only=solutions.is_read_only, dtype=object
+            )

-def _tournament(
-    solutions: torch.Tensor,
+        if with_evals:
+            if split_results:
+                return SelectedParents(
+                    parent1_values=parent1_values,
+                    parent1_evals=evals[winner1_indices],
+                    parent2_values=parent2_values,
+                    parent2_evals=evals[winner2_indices],
+                )
+            else:
+                evals_combiner_fn = torch.vstack if multi_objective else torch.cat
+                combined_evals = evals_combiner_fn([evals[winner1_indices], evals[winner2_indices]])
+                return SelectedAndStackedParents(parent_values=combined_values, parent_evals=combined_evals)
+        else:
+            if split_results:
+                return SelectedParentValues(
+                    parent1_values=parent1_values,
+                    parent2_values=parent2_values,
+                )
+            else:
+                return combined_values
+
+
+TournamentResult = Union[
+    SelectedParentIndices,
+    SelectedParentValues,
+    SelectedParents,
+    SelectedAndStackedParents,
+    torch.Tensor,
+    ObjectArray,
+]
+
+
+def tournament(
+    solutions: Union[torch.Tensor, ObjectArray],
     evals: torch.Tensor,
+    *,
     num_tournaments: int,
     tournament_size: int,
     objective_sense: Union[str, list],
-) -> tuple:
+    return_indices: bool = False,
+    with_evals: bool = False,
+    split_results: bool = False,
+) -> TournamentResult:
     """
-    Randomly pick solutions, put them into a tournament, pick the winners.
+    Randomly organize pairs of tournaments and pick the winning solutions.
+
+    The hyperparameters of tournament selection are `num_tournaments`
+    (the number of tournaments) and `tournament_size` (the size of each
+    tournament).
+
+    **How does each tournament work?**
+    `tournament_size` solutions are randomly sampled from the given
+    `solutions`, and then the best solution among the sampled ones is
+    declared the winner. In the case of single-objective optimization, the
+    best solution is the one with the best evaluation result (i.e. the best
+    fitness). In the case of multi-objective optimization, the best solution
+    is the one within the best pareto-front.
+
+    **How are multiple tournaments organized?**
+    Two sets of tournaments are organized, each set containing `n`
+    tournaments, where `n` is half of `num_tournaments`.
+    For example, let us assume that `num_tournaments` is 6. Then, we have:
+
+    ```text
+    First set of tournaments : tournamentA, tournamentB, tournamentC
+    Second set of tournaments : tournamentD, tournamentE, tournamentF
+    ```
+
+    In this organization of tournaments, the winner of tournamentA is meant
+    for cross-over with the winner of tournamentD; the winner of tournamentB
+    is meant for cross-over with the winner of tournamentE; and the winner of
+    tournamentC is meant for cross-over with the winner of tournamentF.
+
+    While sampling the participants for these tournaments, it is ensured that
+    the winner of tournamentA does not participate in tournamentD; the
+    winner of tournamentB does not participate in tournamentE; and the
+    winner of tournamentC does not participate in tournamentF. Therefore,
+    each cross-over operation is applied on two different parent solutions.
+
+    **How are the tournament results represented?**
+    The tournament results can be returned in various forms, as follows.
+
+    **Results in the form of decision values.**
+    This is the default form of results (with `return_indices=False`,
+    `with_evals=False`, `split_results=False`). Here, the results are
+    expressed as a single tensor (or `ObjectArray`) of decision values.
+    The first half of these decision values represents the first set of
+    parents, and the second half represents the second set of parents.
+    For example, let us assume that the number of tournaments
+    (`num_tournaments`) is configured as 6. In this case, the result is a
+    decision values tensor with 6 rows (or an `ObjectArray` of length 6).
+    In these results (let us call them `resulting_values`), the pairings
+    for the cross-over operations are as follows:
+    - `resulting_values[0]` and `resulting_values[3]`;
+    - `resulting_values[1]` and `resulting_values[4]`;
+    - `resulting_values[2]` and `resulting_values[5]`.
+
+    **Results in the form of indices.**
+    This form of results can be taken with arguments `return_indices=True`,
+    `with_evals=False`, `split_results=False`. Here, the results are
+    expressed as a single tensor of integers, each integer being the index
+    of a solution within `solutions`.
+    For example, let us assume that the number of tournaments
+    (`num_tournaments`) is configured as 6. In this case, the result is a
+    tensor of indices of length 6.
+    In these results (let us call them `resulting_indices`), the pairings
+    for the cross-over operations are as follows:
+    - `resulting_indices[0]` and `resulting_indices[3]`;
+    - `resulting_indices[1]` and `resulting_indices[4]`;
+    - `resulting_indices[2]` and `resulting_indices[5]`.
+
+    **Results in the form of decision values and evaluations.**
+    This form of results can be taken with arguments `return_indices=False`,
+    `with_evals=True`, `split_results=False`. Here, the results are expressed
+    via a named tuple in the form `(parent_values=..., parent_evals=...)`.
+    In this tuple, `parent_values` stores a tensor (or an `ObjectArray`)
+    representing the decision values of the picked solutions, and
+    `parent_evals` stores the evaluation results as a tensor.
+    For example, let us assume that the number of tournaments
+    (`num_tournaments`) is 6. With this assumption, in the returned named
+    tuple (let us call it `result`), the pairings for the cross-over
+    operations are as follows:
+    - `result.parent_values[0]` and `result.parent_values[3]`;
+    - `result.parent_values[1]` and `result.parent_values[4]`;
+    - `result.parent_values[2]` and `result.parent_values[5]`.
+    For any solution `result.parent_values[i]`, the evaluation result
+    is stored by `result.parent_evals[i]`.
+
+    **Results with split parent solutions.**
+    This form of results can be taken with arguments `return_indices=False`,
+    `with_evals=False`, `split_results=True`. The returned object is a
+    named tuple in the form `(parent1_values=..., parent2_values=...)`.
+    In the returned named tuple (let us call it `result`), the pairings for
+    the cross-over operations are as follows:
+    - `result.parent1_values[0]` and `result.parent2_values[0]`;
+    - `result.parent1_values[1]` and `result.parent2_values[1]`;
+    - `result.parent1_values[2]` and `result.parent2_values[2]`;
+    - and so on...
+
+    **Results with split parent solutions and evaluations.**
+    This form of results can be taken with arguments `return_indices=False`,
+    `with_evals=True`, `split_results=True`. The returned object is a
+    named tuple, its attributes being `parent1_values`, `parent1_evals`,
+    `parent2_values`, and `parent2_evals`.
+    In the returned named tuple (let us call it `result`), the pairings for
+    the cross-over operations are as follows:
+    - `result.parent1_values[0]` and `result.parent2_values[0]`;
+    - `result.parent1_values[1]` and `result.parent2_values[1]`;
+    - `result.parent1_values[2]` and `result.parent2_values[2]`;
+    - and so on...
+    For any solution `result.parent1_values[i]`, the evaluation result is
+    stored by `result.parent1_evals[i]`; likewise, for any solution
+    `result.parent2_values[i]`, the evaluation result is stored by
+    `result.parent2_evals[i]`.

     Args:
-        solutions: Decision values of the solutions.
+        solutions: Decision values of the solutions. Can be a tensor with
+            at least 2 dimensions (where extra leftmost dimensions are to be
+            interpreted as batch dimensions), or an `ObjectArray`.
         evals: Evaluation results of the solutions.
             In the single-objective case, this is expected as an
             at-least-1-dimensional tensor, the `i`-th item expressing
@@ -558,26 +914,37 @@ def _tournament(
         objective_sense: A string or a list of strings, where (each)
             string has either the value 'min' for minimization or 'max'
             for maximization.
+        return_indices: If this is given as True, indices of the selected
+            solutions will be returned, instead of their decision values.
+        with_evals: If this is given as True, evaluations of the selected
+            solutions will be returned in addition to their decision values.
+        split_results: If this is given as True, tournament results will be
+            split as first parents and second parents. If this is given as
+            False, results will be stacked vertically, in the sense that
+            the first half of the results are the first parents and the
+            second half of the results are the second parents.
     Returns:
-        A tuple of the form `(decision_values, eval_results)` where
-        `decision_values` is the tensor that contains the decision values
-        of the winning solutions, and `eval_result` is a tensor that
-        contains the evaluation results (i.e. fitnesses) of the
-        winning solutions.
+        Selected solutions (or their indices, with or without their
+        evaluation results).
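+
+    **Example.**
+    A minimal usage sketch (illustrative only; here, `population` is assumed
+    to be a 2-dimensional decision values tensor, and `evals` a matching
+    1-dimensional fitness tensor of a single-objective problem):
+
+    ```python
+    parents1, parents2 = tournament(
+        population,
+        evals,
+        num_tournaments=10,
+        tournament_size=4,
+        objective_sense="min",
+        split_results=True,
+    )
+    # Here, `parents1[i]` is paired with `parents2[i]` for cross-over.
+    ```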
""" - if isinstance(objective_sense, str): - pass # nothing to do - elif isinstance(objective_sense, Iterable): - objective_sense = list(objective_sense) - evals = pareto_utility(evals, objective_sense=objective_sense, crowdsort=False) - objective_sense = "max" - else: - raise TypeError( - "The argument `objective_sense` was expected as a string for the single-objective case," - " or as a list of strings for the multi-objective case." - f" However, the encountered `objective_sense` is {repr(objective_sense)}." + if return_indices and with_evals: + raise ValueError( + "When `return_indices` is given as True, `with_evals` must be False." + " However, `with_evals` was encountered as True." ) - return _single_objective_tournament(solutions, evals, num_tournaments, tournament_size, objective_sense) + + if isinstance(solutions, ObjectArray): + pick_fn = _pick_pairs_via_tournament_considering_objects + elif isinstance(solutions, torch.Tensor): + if isinstance(objective_sense, str): + pick_fn = _pick_pairs_via_tournament_with_single_objective + elif isinstance(objective_sense, Iterable): + pick_fn = _pick_pairs_via_tournament_with_multi_objective + else: + raise TypeError(f"Unrecognized `objective_sense`: {repr(objective_sense)}") + return pick_fn( + solutions, evals, num_tournaments, tournament_size, objective_sense, return_indices, with_evals, split_results + ) @expects_ndim(2, randomness="different") @@ -767,7 +1134,14 @@ def multi_point_cross_over( ) # Apply tournament selection on the original `parents` - parents, _ = _tournament(parents, evals, num_children, tournament_size, objective_sense) + parents = tournament( + parents, + evals, + num_tournaments=num_children, + tournament_size=tournament_size, + objective_sense=objective_sense, + with_evals=False, + ) # Apply the cross-over operation on `parents`, and return the recombined decision values tensor. return _do_cross_over(parents, num_points) @@ -1080,7 +1454,14 @@ def simulated_binary_cross_over( ) # Apply tournament selection on the original `parents` - parents, _ = _tournament(parents, evals, num_children, tournament_size, objective_sense) + parents = tournament( + parents, + evals, + num_tournaments=num_children, + tournament_size=tournament_size, + objective_sense=objective_sense, + with_evals=False, + ) return _do_sbx(parents, eta) @@ -1387,17 +1768,54 @@ def _combine_values_and_multiobjective_evals( return torch.vstack([values1, values2]), torch.vstack([evals1, evals2]) +def _combine_object_arrays(values1: ObjectArray, values2: ObjectArray) -> ObjectArray: + from evotorch.tools import make_tensor + + read_only = values1.is_read_only or values2.is_read_only + return make_tensor([*values1, *values2], dtype=object, read_only=read_only) + + +def _combine_object_arrays_and_evals( + values1: ObjectArray, evals1: torch.Tensor, values2: ObjectArray, evals2: torch.Tensor +) -> tuple: + eval_shapes_are_valid = (evals1.ndim == 1) and (evals2.ndim == 1) + if not eval_shapes_are_valid: + raise ValueError( + "Evaluation result tensors were expected to have only 1 dimension each." + f" However, their shapes are {evals1.shape} and {evals2.shape}." 
+ ) + return _combine_object_arrays(values1, values2), torch.cat([evals1, evals2]) + + +def _combine_object_arrays_and_multiobjective_evals( + values1: ObjectArray, evals1: torch.Tensor, values2: ObjectArray, evals2: torch.Tensor +) -> tuple: + eval_shapes_are_valid = (evals1.ndim == 2) and (evals2.ndim == 2) + if not eval_shapes_are_valid: + raise ValueError( + "Evaluation result tensors were expected to have 2 dimensions each." + f" However, their shapes are {evals1.shape} and {evals2.shape}." + ) + return _combine_object_arrays(values1, values2), torch.vstack([evals1, evals2]) + + +def _both_are_tensors(a, b) -> bool: + return isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor) + + +def _both_are_object_arrays(a, b) -> bool: + return isinstance(a, ObjectArray) and isinstance(b, ObjectArray) + + def combine( - a: Union[torch.Tensor, tuple], - b: Union[torch.Tensor, tuple], + a: Union[torch.Tensor, ObjectArray, tuple], + b: Union[torch.Tensor, ObjectArray, tuple], *, objective_sense: Optional[Union[str, Iterable]] = None, ) -> Union[torch.Tensor, tuple]: """ Combine two populations into one. - This function can be used in two forms. - **Usage 1: without evaluation results.** Let us assume that we have two decision values matrices, `values1` `values2`. The shapes of these matrices are (n1, L) and (n2, L) @@ -1451,23 +1869,35 @@ def combine( # `c_values` is shaped (n1+n2, L), and `c_evals` is shaped (n1+n2,). ``` + **Support for ObjectArray.** + This function supports decision values that are expressed via instances + of `ObjectArray`. + Args: - a: A decision values tensor with at least 2 dimensions, or a tuple - of the form `(values, evals)`, where `values` is an at least - 2-dimensional decision values tensor, and `evals` is an at least - 1-dimensional evaluation results tensor. - Extra leftmost dimensions are taken as batch dimensions. + a: A decision values tensor with at least 2 dimensions, or an + `ObjectArray` of decision values, or a tuple of the form + `(values, evals)` where `values` is the decision values + and `evals` is a tensor with at least 1 dimension. + Additional leftmost dimensions within tensors are interpreted + as batch dimensions. If this positional argument is a tensor, the second positional - argument must also be a tensor. If this positional argument is a - tuple, the second positional argument must also be a tuple. - b: A decision values tensor with at least 2 dimensions, or a tuple - of the form `(values, evals)`, where `values` is an at least - 2-dimensional decision values tensor, and `evals` is an at least - 1-dimensional evaluation results tensor. - Extra leftmost dimensions are taken as batch dimensions. + argument must also be a tensor. + If this positional argument is an `ObjectArray`, the second + positional argument must also be an `ObjectArray`. + If this positional argument is a tuple, the second positional + argument must also be a tuple. + b: A decision values tensor with at least 2 dimensions, or an + `ObjectArray` of decision values, or a tuple of the form + `(values, evals)` where `values` is the decision values + and `evals` is a tensor with at least 1 dimension. + Additional leftmost dimensions within tensors are interpreted + as batch dimensions. If this positional argument is a tensor, the first positional - argument must also be a tensor. If this positional argument is a - tuple, the first positional argument must also be a tuple. + argument must also be a tensor. 
+            If this positional argument is an `ObjectArray`, the first
+            positional argument must also be an `ObjectArray`.
+            If this positional argument is a tuple, the first positional
+            argument must also be a tuple.
         objective_sense: In the case of single-objective optimization,
             `objective_sense` can be left as None, or can be 'min' or 'max',
             representing the direction of the optimization.
@@ -1494,19 +1924,43 @@ def combine(
     )
     values2, evals2 = b
     if (objective_sense is None) or isinstance(objective_sense, str):
-        return _combine_values_and_evals(values1, evals1, values2, evals2)
+        if _both_are_tensors(values1, values2):
+            return _combine_values_and_evals(values1, evals1, values2, evals2)
+        elif _both_are_object_arrays(values1, values2):
+            return _combine_object_arrays_and_evals(values1, evals1, values2, evals2)
+        else:
+            raise TypeError(
+                "Both decision values arrays must be `Tensor`s or `ObjectArray`s."
+                f" However, their types are: {type(values1)}, {type(values2)}."
+            )
     elif isinstance(objective_sense, Iterable):
-        return _combine_values_and_multiobjective_evals(values1, evals1, values2, evals2)
+        if _both_are_tensors(values1, values2):
+            return _combine_values_and_multiobjective_evals(values1, evals1, values2, evals2)
+        elif _both_are_object_arrays(values1, values2):
+            return _combine_object_arrays_and_multiobjective_evals(values1, evals1, values2, evals2)
+        else:
+            raise TypeError(
+                "Both decision values arrays must be `Tensor`s or `ObjectArray`s."
+                f" However, their types are: {type(values1)}, {type(values2)}."
+            )
     else:
         raise TypeError(f"Unrecognized `objective_sense`: {repr(objective_sense)}")
-    elif isinstance(a, torch.Tensor):
-        if not isinstance(b, torch.Tensor):
+    elif isinstance(a, (torch.Tensor, ObjectArray)):
+        if not isinstance(b, (torch.Tensor, ObjectArray)):
             raise TypeError(
-                "The first positional argument was received as a tensor."
-                " Therefore, the second positional argument was also expected as a tensor."
+                "The first positional argument was received as a tensor or ObjectArray."
+                " Therefore, the second positional argument was also expected as a tensor or ObjectArray."
                 f" However, the second argument is {repr(b)} (of type {type(b)})."
             )
-        return _combine_values(a, b)
+        if _both_are_tensors(a, b):
+            return _combine_values(a, b)
+        elif _both_are_object_arrays(a, b):
+            return _combine_object_arrays(a, b)
+        else:
+            raise TypeError(
+                "Both decision values arrays must be `Tensor`s or `ObjectArray`s."
+                f" However, their types are: {type(a)}, {type(b)}."
+            )
     else:
         raise TypeError(
             "Expected both positional arguments as tensors, or as tuples."
@@ -1564,8 +2018,56 @@ def _take_multiple_best_with_multiobjective(
     return best_rows, best_evals


+def _take_best_considering_objects(
+    values: ObjectArray,
+    evals: torch.Tensor,
+    n: Optional[int],
+    objective_sense: Union[str, list],
+    crowdsort: bool,
+) -> tuple:
+    if isinstance(objective_sense, str):
+        if evals.ndim != 1:
+            raise ValueError(
+                "The given `objective_sense` implies that there is only one objective."
+                " In this case, `evals` was expected to have only one dimension."
+                f" However, the shape of `evals` is {evals.shape}."
+ ) + multi_objective = False + if objective_sense == "min": + descending = False + elif objective_sense == "max": + descending = True + else: + raise ValueError(f"Unrecognized `objective_sense`: {repr(objective_sense)}") + utils = evals + elif isinstance(objective_sense, Iterable): + if evals.ndim != 2: + raise ValueError( + "The given `objective_sense` implies that there are multiple objectives." + " In this case, `evals` was expected to have two dimensions." + f" However, the shape of `evals` is {evals.shape}." + ) + multi_objective = True + utils = pareto_utility(evals, objective_sense=objective_sense, crowdsort=crowdsort) + descending = True + else: + raise TypeError(f"Unrecognized `objective_sense`: {repr(objective_sense)}") + + if n is None: + if multi_objective: + raise ValueError("When there are multiple objectives, the number of solutions to take cannot be omitted.") + argbest = torch.argmax if descending else torch.argmin + best_index = argbest(utils) + return values[torch.as_tensor(best_index, device="cpu")], evals[best_index] + else: + indices_of_best = torch.argsort(utils, descending=descending)[:n] + best_rows = values[torch.as_tensor(indices_of_best, device="cpu")] + best_evals = torch.index_select(evals, 0, indices_of_best) + return best_rows, best_evals + + def take_best( - values: torch.Tensor, + values: Union[torch.Tensor, ObjectArray], evals: torch.Tensor, n: Optional[int] = None, *, @@ -1591,9 +2093,15 @@ def take_best( solutions to take. Like in the single-objective case, the decision values and the evaluation results of the taken solutions will be returned. + **Support for ObjectArray.** + This function supports decision values expressed via instances of + `ObjectArray`. + Args: - values: Decision values tensor, with at least 2 dimensions. - Extra leftmost dimensions will be taken as batch dimensions. + values: Decision values, expressed via a tensor with at least + 2 dimensions or via an `ObjectArray`. If given as a tensor, + extra leftmost dimensions will be interpreted as batch + dimensions. evals: Evaluation results tensor, with at least 1 dimension. Extra leftmost dimensions will be taken as batch dimensions. n: If left as None, the single best solution will be taken. @@ -1614,10 +2122,13 @@ def take_best( when deciding whether or not it is among the best `n` solutions. Returns: A tuple of the form `(decision_values, evaluation_results)`, where - `decision_values` is the decision values tensor for the taken - solution(s), and `evaluation_results` is the evaluation results tensor - for the taken solution(s). + `decision_values` is the decision values (as a tensor or as an + `ObjectArray`) for the taken solution(s), and `evaluation_results` + is the evaluation results tensor for the taken solution(s). 
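+
+    **Example.**
+    A minimal usage sketch (illustrative only; here, `population` is assumed
+    to be a decision values tensor, and `evals` a matching 1-dimensional
+    fitness tensor of a single-objective problem):
+
+    ```python
+    # Take the best 10 solutions, together with their evaluation results:
+    elite_values, elite_evals = take_best(population, evals, 10, objective_sense="min")
+
+    # Take the single best solution:
+    best_value, best_eval = take_best(population, evals, objective_sense="min")
+    ```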
""" + if isinstance(values, ObjectArray): + return _take_best_considering_objects(values, evals, n, objective_sense, crowdsort) + if isinstance(objective_sense, str): multi_objective = False elif isinstance(objective_sense, Iterable): diff --git a/src/evotorch/tools/immutable.py b/src/evotorch/tools/immutable.py index cccf5306..8efdc0f7 100644 --- a/src/evotorch/tools/immutable.py +++ b/src/evotorch/tools/immutable.py @@ -66,7 +66,12 @@ def as_immutable(x: Any, *, memo: Optional[dict] = None) -> Any: elif isinstance(x, torch.Tensor): result = x.clone().as_subclass(ReadOnlyTensor) elif isinstance(x, ObjectArray): - result = x.clone().get_read_only_view() + x = x.clone() + xlength = len(x) + result = ObjectArray(xlength) + for i in range(xlength): + result[i] = x[i] + result = result.get_read_only_view() elif isinstance(x, np.ndarray): if _numpy_array_stores_objects(x): result = ObjectArray(len(x)) @@ -76,7 +81,6 @@ def as_immutable(x: Any, *, memo: Optional[dict] = None) -> Any: else: result = x.copy() result.flags["WRITEABLE"] = False - result = result elif isinstance(x, Mapping): result = ImmutableDict(x, memo) elif isinstance(x, set): diff --git a/tests/test_func_ops.py b/tests/test_func_ops.py new file mode 100644 index 00000000..7831c2c4 --- /dev/null +++ b/tests/test_func_ops.py @@ -0,0 +1,740 @@ +# Copyright 2024 NNAISENSE SA +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Sequence, Union + +import pytest +import torch + +import evotorch.operators.functional as func_ops +from evotorch import testing +from evotorch.decorators import rowwise +from evotorch.tools import ObjectArray, make_tensor + + +def test_combine_single_obj(): + tolerance = 1e-8 + solution_length = 5 + batch_size = 3 + + popsizeA = 6 + populationA = ( + torch.arange(batch_size * popsizeA * solution_length) + .reshape(batch_size, popsizeA, solution_length) + .to(dtype=torch.float32) + ) + evalsA = populationA.sum(dim=-1) + + popsizeB = 3 + populationB = ( + torch.arange(batch_size * popsizeB * solution_length) + .reshape(batch_size, popsizeB, solution_length) + .to(dtype=torch.float32) + * -1 + ) + evalsB = populationB.sum(dim=-1) + + # test the unbatched case + unbatched_combined_pop, unbatched_combined_evals = func_ops.combine( + (populationA[0], evalsA[0]), (populationB[0], evalsB[0]) + ) + unbatched_expected_pop_shape = (popsizeA + popsizeB, solution_length) + unbatched_expected_evals_shape = popsizeA + popsizeB + testing.assert_shape_matches(unbatched_combined_pop, unbatched_expected_pop_shape) + testing.assert_shape_matches(unbatched_combined_evals, unbatched_expected_evals_shape) + testing.assert_allclose(unbatched_combined_evals, unbatched_combined_pop.sum(dim=-1), atol=tolerance) + + # test the batched case + expected_pop_shape = (batch_size, popsizeA + popsizeB, solution_length) + expected_evals_shape = ( + batch_size, + popsizeA + popsizeB, + ) + + combined_pop, combined_evals = func_ops.combine((populationA, evalsA), (populationB, evalsB)) + testing.assert_shape_matches(combined_pop, expected_pop_shape) + testing.assert_shape_matches(combined_evals, expected_evals_shape) + + for i_batch in range(batch_size): + testing.assert_allclose(combined_evals[i_batch], combined_pop[i_batch].sum(dim=-1), atol=tolerance) + + +def test_combine_multi_obj(): + tolerance = 1e-8 + solution_length = 5 + batch_size = 3 + + objective_sense = ["min", "min"] + num_objectives = len(objective_sense) + + @rowwise + def f(x: torch.Tensor) -> torch.Tensor: + return torch.cat([torch.sum(x).reshape(1), torch.min(x).reshape(1)]) + + popsizeA = 6 + populationA = ( + torch.arange(batch_size * popsizeA * solution_length) + .reshape(batch_size, popsizeA, solution_length) + .to(dtype=torch.float32) + ) + evalsA = f(populationA) + + popsizeB = 3 + populationB = ( + torch.arange(batch_size * popsizeB * solution_length) + .reshape(batch_size, popsizeB, solution_length) + .to(dtype=torch.float32) + * -1 + ) + evalsB = f(populationB) + + # test the unbatched case + unbatched_combined_pop, unbatched_combined_evals = func_ops.combine( + (populationA[0], evalsA[0]), (populationB[0], evalsB[0]), objective_sense=objective_sense + ) + unbatched_expected_pop_shape = (popsizeA + popsizeB, solution_length) + unbatched_expected_evals_shape = (popsizeA + popsizeB, num_objectives) + testing.assert_shape_matches(unbatched_combined_pop, unbatched_expected_pop_shape) + testing.assert_shape_matches(unbatched_combined_evals, unbatched_expected_evals_shape) + testing.assert_allclose(unbatched_combined_evals, f(unbatched_combined_pop), atol=tolerance) + + # test the batched case + expected_pop_shape = (batch_size, popsizeA + popsizeB, solution_length) + expected_evals_shape = (batch_size, popsizeA + popsizeB, num_objectives) + + combined_pop, combined_evals = func_ops.combine( + (populationA, evalsA), (populationB, evalsB), objective_sense=objective_sense + ) + testing.assert_shape_matches(combined_pop, 
expected_pop_shape) + testing.assert_shape_matches(combined_evals, expected_evals_shape) + + for i_batch in range(batch_size): + testing.assert_allclose(combined_evals[i_batch], f(combined_pop[i_batch]), atol=tolerance) + + +def test_combine_with_objects(): + tolerance = 1e-8 + + populationA = make_tensor( + [ + [1, 2], + [1, 3, 5], + [10, 20, 60, 40], + ], + dtype=object, + ) + + populationB = make_tensor( + [ + [-1, -2], + [-10, -20], + ], + dtype=object, + ) + + def f_single(x: ObjectArray) -> torch.Tensor: + return torch.as_tensor( + [sum(values) for values in x], + dtype=torch.float32, + ) + + def f_multi(x: ObjectArray) -> torch.Tensor: + return torch.as_tensor( + [[sum(values), min(values)] for values in x], + dtype=torch.float32, + ) + + for f, objective_sense in ((f_single, None), (f_multi, ["min", "min"])): + evalsA = f(populationA) + evalsB = f(populationB) + + combined_pop, combined_evals = func_ops.combine( + (populationA, evalsA), (populationB, evalsB), objective_sense=objective_sense + ) + + assert isinstance(combined_pop, ObjectArray) + assert isinstance(combined_evals, torch.Tensor) + + for i_solution in range(len(combined_pop)): + sln = combined_pop[i_solution : i_solution + 1] + sln_eval = combined_evals[i_solution : i_solution + 1] + sln_re_eval = f(sln) + testing.assert_allclose(sln_eval, sln_re_eval, atol=tolerance) + + +@pytest.mark.parametrize( + "population,desired_best_one,desired_best_two,obj_sense", + [ + # --- argument set --- + ( + # population + torch.FloatTensor( + [ + [1, 2, 3], + [100, 200, 300], + [10, 20, 30], + [-1, -2, -3], + ] + ), + # desired_best_one + torch.FloatTensor([100, 200, 300]), + # desired_best_two + torch.FloatTensor( + [ + [100, 200, 300], + [10, 20, 30], + ] + ), + # obj_sense + "max", + ), + # --- argument set --- + ( + # population + torch.FloatTensor( + [ + [1, 2, 3], + [100, 200, 300], + [10, 20, 30], + [-1, -2, -3], + ] + ), + # desired_best_one + torch.FloatTensor([-1, -2, -3]), + # desired_best_two + torch.FloatTensor( + [ + [-1, -2, -3], + [1, 2, 3], + ] + ), + # obj_sense + "min", + ), + # --- argument set --- + ( + # population + torch.FloatTensor( + [ + [ + [1, 2, 3], + [100, 200, 300], + [-1, -2, -3], + ], + [ + [5, 6, 7], + [8, 9, 10], + [20, 30, 40], + ], + ], + ), + # desired_best_one + torch.FloatTensor( + [ + [100, 200, 300], + [20, 30, 40], + ], + ), + # desired_best_two + torch.FloatTensor( + [ + [ + [100, 200, 300], + [1, 2, 3], + ], + [ + [20, 30, 40], + [8, 9, 10], + ], + ], + ), + # obj_sense + "max", + ), + # --- argument set --- + ( + # population + make_tensor( + [ + [1, 2, 3], + [100, 200, 300, 400], + [-1, -2], + ], + dtype=object, + ), + # desired_best_one + [100, 200, 300, 400], + # desired_best_two + [ + [100, 200, 300, 400], + [1, 2, 3], + ], + # obj_sense + "max", + ), + ], +) +def test_take_best(population, desired_best_one, desired_best_two, obj_sense): + tolerance = 1e-8 + + got_objects = isinstance(population, ObjectArray) + + if got_objects: + + def f(x: Union[ObjectArray, Sequence]) -> torch.Tensor: + if isinstance(x, ObjectArray): + return torch.as_tensor([sum(solution) for solution in x], dtype=torch.float32) + else: + return torch.as_tensor(sum(x), dtype=torch.float32) + + else: + + @rowwise + def f(x: torch.Tensor) -> torch.Tensor: + return torch.sum(x) + + evals = f(population) + best_one, best_one_eval = func_ops.take_best(population, evals, objective_sense=obj_sense) + testing.assert_allclose(best_one_eval, f(best_one), atol=tolerance) + if got_objects: + best_one = 
torch.as_tensor(list(best_one), dtype=torch.float32) + desired_best_one = torch.as_tensor(desired_best_one, dtype=torch.float32) + testing.assert_allclose(best_one, desired_best_one, atol=tolerance) + + best_two, best_two_evals = func_ops.take_best(population, evals, 2, objective_sense=obj_sense) + testing.assert_allclose(best_two_evals, f(best_two), atol=tolerance) + if got_objects: + for i in (0, 1): + row = torch.as_tensor(list(best_two[i]), dtype=torch.float32) + desired_row = torch.as_tensor(desired_best_two[i], dtype=torch.float32) + testing.assert_allclose(row, desired_row, atol=tolerance) + else: + testing.assert_allclose(best_two, desired_best_two, atol=tolerance) + + +@pytest.mark.parametrize( + "population,desired_best_two,obj_sense", + [ + # --- argument set --- + ( + # population + torch.FloatTensor( + [ + [1, 2, 3], + [100, 200, 300], + [-1, -9, 0], + [5, 6, 7], + [98, 200, 400], + [10, 11, 12], + ] + ), + # desired_best_two + torch.FloatTensor( + [ + [100, 200, 300], + [98, 200, 400], + ], + ), + # obj_sense + ["max", "max"], + ), + # --- argument set --- + ( + # population + torch.FloatTensor( + [ + [22, 4, -11], + [10, 5, 7], + [11, 6, 3], + [23, 4, -10], + [3, 6, 7], + ] + ), + # desired_best_two + torch.FloatTensor( + [ + [22, 4, -11], + [23, 4, -10], + ], + ), + # obj_sense + ["max", "min"], + ), + # --- argument set --- + ( + # population + torch.FloatTensor( + [ + [ + [1, 2, 3], + [100, 200, 300], + [-1, -9, 0], + [5, 6, 7], + [98, 200, 400], + [10, 11, 12], + ], + [ + [1, 2, 3], + [-100, -200, -300], + [-1, -9, 0], + [9, 11, 13], + [-98, -200, -400], + [10, 11, 12], + ], + ] + ), + # desired_best_two + torch.FloatTensor( + [ + [ + [100, 200, 300], + [98, 200, 400], + ], + [ + [9, 11, 13], + [10, 11, 12], + ], + ], + ), + # obj_sense + ["max", "max"], + ), + # --- argument set --- + ( + # population + make_tensor( + [ + [1, 50, 60, 2], + [23, 1, 2, 67], + [12, 5, 3, 8, 99], + [2, 55, 1], + [15, 17, 16, 23, 22], + ], + dtype=object, + ), + # desired_best_two + make_tensor( + [ + [1, 50, 60, 2], + [2, 55, 1], + ], + dtype=object, + ), + # obj_sense + ["min", "min"], + ), + ], +) +def test_take_best_with_multiobj(population, desired_best_two, obj_sense): + tolerance = 1e-8 + + if isinstance(population, ObjectArray): + + def f(solutions: ObjectArray) -> torch.Tensor: + num_solutions = len(solutions) + result = torch.empty(num_solutions, 2, dtype=torch.float32) + for i in range(num_solutions): + result[i, 0] = solutions[i][0] + result[i, 1] = solutions[i][-1] + return result + + def vertical_sum_of_decision_values(solutions: ObjectArray) -> torch.Tensor: + length = min([len(solution) for solution in solutions]) + result = torch.zeros(length, dtype=torch.float32) + for j in range(length): + result[j] = sum([solution[j] for solution in solutions]) + return result + + else: + + @rowwise + def f(x: torch.Tensor) -> torch.Tensor: + return torch.hstack([x[0].reshape(1), x[-1].reshape(1)]) + + def vertical_sum_of_decision_values(solutions: torch.Tensor) -> torch.Tensor: + return torch.sum(solutions, dim=-2) + + evals = f(population) + best_two, best_two_evals = func_ops.take_best(population, evals, 2, objective_sense=obj_sense) + best_two_total = vertical_sum_of_decision_values(best_two) + desired_best_two_total = vertical_sum_of_decision_values(desired_best_two) + + testing.assert_allclose(torch.sum(f(best_two), dim=-2), torch.sum(best_two_evals, dim=-2), atol=tolerance) + testing.assert_allclose(best_two_total, desired_best_two_total, atol=tolerance) + + 
+@pytest.mark.parametrize( + "population,num_tournaments,obj_sense", + [ + # --- argument set --- + ( + # population + torch.arange(200.0).reshape(20, 10), + # num_tournaments + 6, + # obj_sense + "max", + ), + # --- argument set --- + ( + # population + torch.arange(200.0).reshape(20, 10), + # num_tournaments + 6, + # obj_sense + ["max", "min"], + ), + # --- argument set --- + ( + # population + torch.arange(200.0).reshape(2, 10, 10), + # num_tournaments + 6, + # obj_sense + "min", + ), + # --- argument set --- + ( + # population + torch.arange(200.0).reshape(2, 10, 10), + # num_tournaments + 6, + # obj_sense + ["max", "min"], + ), + # --- argument set --- + ( + # population + make_tensor( + [ + [1, 2, 3, 4], + [5, 6, 7, 8], + [9, 10, 11, 12], + [13, 14, 15, 16], + [-1, 2, 3, 4], + [5, -6, 7, 8], + [9, 10, -11, 12], + [13, -14, 15, -16], + [100, 101, 102, 103], + [33, 44, 55, 66], + ], + dtype=object, + ), + # num_tournaments + 6, + # obj_sense + "min", + ), + # --- argument set --- + ( + # population + make_tensor( + [ + [1, 2, 3, 4], + [5, 6, 7, 8], + [9, 10, 11, 12], + [13, 14, 15, 16], + [-1, 2, 3, 4], + [5, -6, 7, 8], + [9, 10, -11, 12], + [13, -14, 15, -16], + [100, 101, 102, 103], + [33, 44, 55, 66], + ], + dtype=object, + ), + # num_tournaments + 6, + # obj_sense + ["max", "min"], + ), + ], +) +def test_tournament(population, num_tournaments, obj_sense): + got_objects = isinstance(population, ObjectArray) + + if isinstance(obj_sense, str): + multi_obj = False + num_objs = None + else: + multi_obj = True + num_objs = len(obj_sense) + + if multi_obj: + if got_objects: + + def f(x: ObjectArray) -> torch.Tensor: + n = len(x) + result = torch.empty(n, 2, dtype=torch.float32) + for i, solution in enumerate(x): + result[i, 0] = solution[0] + result[i, 1] = solution[1] + return result + + else: + + @rowwise + def f(x: torch.Tensor) -> torch.Tensor: + return torch.hstack([x[0].reshape(1), x[1].reshape(1)]) + + else: + if got_objects: + + def f(x: ObjectArray) -> torch.Tensor: + n = len(x) + result = torch.empty(n, dtype=torch.float32) + for i, solution in enumerate(x): + result[i] = sum([value**2 for value in solution]) + return result + + else: + + @rowwise + def f(x: torch.Tensor) -> torch.Tensor: + return torch.sum(x**2) + + if got_objects: + solution_length = len(population[0]) + batch_shape = tuple() + + def pop_shape(x: ObjectArray) -> tuple: + return len(x), len(x[0]) + + else: + solution_length = population.shape[-1] + batch_shape = population.shape[:-2] + + def pop_shape(x: torch.Tensor) -> tuple: + return x.shape + + evals = f(population) + + solutions = func_ops.tournament( + population, evals, num_tournaments=num_tournaments, tournament_size=2, objective_sense=obj_sense + ) + assert pop_shape(solutions) == tuple([*batch_shape, num_tournaments, solution_length]) + + solutions, sln_evals = func_ops.tournament( + population, + evals, + num_tournaments=num_tournaments, + tournament_size=2, + objective_sense=obj_sense, + with_evals=True, + ) + assert pop_shape(solutions) == tuple([*batch_shape, num_tournaments, solution_length]) + if multi_obj: + assert sln_evals.shape == tuple([*batch_shape, num_tournaments, num_objs]) + else: + assert sln_evals.shape == tuple([*batch_shape, num_tournaments]) + + parents1, parents2 = func_ops.tournament( + population, + evals, + num_tournaments=num_tournaments, + tournament_size=2, + objective_sense=obj_sense, + split_results=True, + ) + assert pop_shape(parents1) == tuple([*batch_shape, num_tournaments // 2, solution_length]) + assert 
pop_shape(parents2) == tuple([*batch_shape, num_tournaments // 2, solution_length]) + + parents1, parent_evals1, parents2, parent_evals2 = func_ops.tournament( + population, + evals, + num_tournaments=num_tournaments, + tournament_size=2, + objective_sense=obj_sense, + split_results=True, + with_evals=True, + ) + assert pop_shape(parents1) == tuple([*batch_shape, num_tournaments // 2, solution_length]) + assert pop_shape(parents2) == tuple([*batch_shape, num_tournaments // 2, solution_length]) + if multi_obj: + assert parent_evals1.shape == tuple([*batch_shape, num_tournaments // 2, num_objs]) + assert parent_evals2.shape == tuple([*batch_shape, num_tournaments // 2, num_objs]) + else: + assert parent_evals1.shape == tuple([*batch_shape, num_tournaments // 2]) + assert parent_evals2.shape == tuple([*batch_shape, num_tournaments // 2]) + + indices = func_ops.tournament( + population, + evals, + num_tournaments=num_tournaments, + tournament_size=2, + objective_sense=obj_sense, + return_indices=True, + ) + assert indices.shape == tuple([*batch_shape, num_tournaments]) + + indices1, indices2 = func_ops.tournament( + population, + evals, + num_tournaments=num_tournaments, + tournament_size=2, + objective_sense=obj_sense, + return_indices=True, + split_results=True, + ) + assert indices1.shape == tuple([*batch_shape, num_tournaments // 2]) + assert indices2.shape == tuple([*batch_shape, num_tournaments // 2]) + + +@pytest.mark.parametrize( + "input_shape,num_children,desired_output_shape", + [ + ((20, 30), 8, (8, 30)), + ((7, 20, 30), 8, (7, 8, 30)), + ((5, 7, 20, 30), 8, (5, 7, 8, 30)), + ((20, 30), None, (20, 30)), + ((7, 20, 30), None, (7, 20, 30)), + ((5, 7, 20, 30), None, (5, 7, 20, 30)), + ], +) +def test_cross_over(input_shape, num_children, desired_output_shape): + population = torch.randn(input_shape) + + @rowwise + def single_objective_f(x: torch.Tensor) -> torch.Tensor: + return torch.sum(x**2) + + @rowwise + def multi_objective_f(x: torch.Tensor) -> torch.Tensor: + return torch.cat([torch.sum(x**2).reshape(1), torch.sum(x).reshape(1)]) + + single_objective_evals = single_objective_f(population) + multi_objective_evals = multi_objective_f(population) + + for objective_sense, evals in [("min", single_objective_evals), (["min", "min"], multi_objective_evals)]: + cross_over_functions = [ + ("multi_point_cross_over", {"num_points": 1}), + ("one_point_cross_over", {}), + ("multi_point_cross_over", {"num_points": 2}), + ("two_point_cross_over", {}), + ("multi_point_cross_over", {"num_points": 3}), + ("simulated_binary_cross_over", {"eta": 10.0}), + ("simulated_binary_cross_over", {"eta": 20.0}), + ] + + for tournament_size in (2, 3, 4): + for cross_over_fn, cross_over_cfg in cross_over_functions: + output = getattr(func_ops, cross_over_fn)( + population, + evals, + tournament_size=tournament_size, + num_children=num_children, + objective_sense=objective_sense, + **cross_over_cfg, + ) + assert output.shape == desired_output_shape From cd792cc577258ac3860b596a76d672990db3087f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 19 Aug 2024 15:47:52 +0000 Subject: [PATCH 4/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../Functional_API/functional_ops.ipynb | 30 +++++------ .../Functional_API/multiobj_batched_ops.ipynb | 52 +++++++++---------- 2 files changed, 41 insertions(+), 41 deletions(-) diff --git a/examples/notebooks/Functional_API/functional_ops.ipynb 
b/examples/notebooks/Functional_API/functional_ops.ipynb index c7973583..27ffe760 100644 --- a/examples/notebooks/Functional_API/functional_ops.ipynb +++ b/examples/notebooks/Functional_API/functional_ops.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "05a384ac-fa4d-434e-a8c0-367350dec224", + "id": "0", "metadata": {}, "source": [ "# Genetic algorithm with the help of functional operators\n", @@ -32,7 +32,7 @@ }, { "cell_type": "markdown", - "id": "0f9f42f4-43ad-4766-84ee-8a4fe2a9fe2b", + "id": "1", "metadata": {}, "source": [ "---\n", @@ -47,7 +47,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3dd28c51-e08b-43b3-8cf9-efecdce49203", + "id": "2", "metadata": {}, "outputs": [], "source": [ @@ -59,7 +59,7 @@ }, { "cell_type": "markdown", - "id": "8379cbbb-084c-461e-889a-2b334d52c138", + "id": "3", "metadata": {}, "source": [ "Below, we have the implementations for the fitness functions `rastrigin` and `sphere`.\n", @@ -71,7 +71,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5f1d45aa-50a4-4bc8-90a9-ed75b1bb1fbd", + "id": "4", "metadata": {}, "outputs": [], "source": [ @@ -86,7 +86,7 @@ { "cell_type": "code", "execution_count": null, - "id": "ea8ddfeb-48de-4eab-9296-9e7862950d77", + "id": "5", "metadata": {}, "outputs": [], "source": [ @@ -97,7 +97,7 @@ }, { "cell_type": "markdown", - "id": "fd1fe365-ae1b-4898-91d0-cb7517ef5f84", + "id": "6", "metadata": {}, "source": [ "In this notebook, the variable `f` points to the fitness function whose value we want to minimize:" @@ -106,7 +106,7 @@ { "cell_type": "code", "execution_count": null, - "id": "fc6d3958-61e7-4838-a9c8-0c9c3d0e343c", + "id": "7", "metadata": {}, "outputs": [], "source": [ @@ -116,7 +116,7 @@ }, { "cell_type": "markdown", - "id": "da899fd2-5bfe-41bd-8c76-20d8b5c2bf71", + "id": "8", "metadata": {}, "source": [ "Various hyperparameters and problem settings:" @@ -125,7 +125,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6625c172-0ec2-492e-9851-0a60594d9086", + "id": "9", "metadata": {}, "outputs": [], "source": [ @@ -146,7 +146,7 @@ { "cell_type": "code", "execution_count": null, - "id": "36782df7-9728-44e2-ab24-312a0e83b0d2", + "id": "10", "metadata": {}, "outputs": [], "source": [ @@ -217,7 +217,7 @@ }, { "cell_type": "markdown", - "id": "84b24feb-1e84-4de6-ae24-ca277f5f5f4d", + "id": "11", "metadata": {}, "source": [ "Decision values of the final population:" @@ -226,7 +226,7 @@ { "cell_type": "code", "execution_count": null, - "id": "21898601-e05d-489f-8cc1-3beef8bdd33b", + "id": "12", "metadata": {}, "outputs": [], "source": [ @@ -235,7 +235,7 @@ }, { "cell_type": "markdown", - "id": "a067340a-5c87-4406-82f4-9cdcccaa3349", + "id": "13", "metadata": {}, "source": [ "Best solution of the final population:" @@ -244,7 +244,7 @@ { "cell_type": "code", "execution_count": null, - "id": "28fccbc1-c605-4b3f-9545-9e4de01b7f07", + "id": "14", "metadata": {}, "outputs": [], "source": [ diff --git a/examples/notebooks/Functional_API/multiobj_batched_ops.ipynb b/examples/notebooks/Functional_API/multiobj_batched_ops.ipynb index 3c9e6ddc..d1909893 100644 --- a/examples/notebooks/Functional_API/multiobj_batched_ops.ipynb +++ b/examples/notebooks/Functional_API/multiobj_batched_ops.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "3a0bd5c3-661c-4b5d-9960-6f679ac6e1c7", + "id": "0", "metadata": {}, "source": [ "# Multiobjective optimization via functional operators API\n", @@ -13,7 +13,7 @@ }, { "cell_type": "markdown", - "id": 
"f37a843a-5d3e-4bdd-a743-244c156d1408", + "id": "1", "metadata": {}, "source": [ "---\n", @@ -24,7 +24,7 @@ { "cell_type": "code", "execution_count": null, - "id": "fda7ae3a-970b-4803-a222-35cf0550bf19", + "id": "2", "metadata": {}, "outputs": [], "source": [ @@ -39,7 +39,7 @@ }, { "cell_type": "markdown", - "id": "3019da0f-4545-4cbb-bf28-ee0fc5ea2f62", + "id": "3", "metadata": {}, "source": [ "Below, we implement Kursawe's function.\n", @@ -50,7 +50,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0455f2df-c438-470c-a9e2-c35086c8eed3", + "id": "4", "metadata": {}, "outputs": [], "source": [ @@ -73,7 +73,7 @@ }, { "cell_type": "markdown", - "id": "9b904d3b-8008-4782-9d8d-b9be3f7ec770", + "id": "5", "metadata": {}, "source": [ "Below, we have the constants regarding the problem, and hyperparameters:" @@ -82,7 +82,7 @@ { "cell_type": "code", "execution_count": null, - "id": "9c634129-8b1e-4050-a618-d5cce7b6843f", + "id": "6", "metadata": {}, "outputs": [], "source": [ @@ -100,7 +100,7 @@ }, { "cell_type": "markdown", - "id": "b29688fc-73ae-4a07-8377-b2c7ddeb9640", + "id": "7", "metadata": {}, "source": [ "Initialize a population, and store it via the variable `population`:" @@ -109,7 +109,7 @@ { "cell_type": "code", "execution_count": null, - "id": "ab53d77a-c30a-4ed3-ab05-6969a9933280", + "id": "8", "metadata": {}, "outputs": [], "source": [ @@ -119,7 +119,7 @@ }, { "cell_type": "markdown", - "id": "4c7b7c69-d728-4a29-8b5e-6dbb26b4f7b8", + "id": "9", "metadata": {}, "source": [ "Evaluate the initial population, and store the evaluation results within the variable `evals`:" @@ -128,7 +128,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a3d63e29-6849-4593-bc18-49d200a1298a", + "id": "10", "metadata": {}, "outputs": [], "source": [ @@ -138,7 +138,7 @@ }, { "cell_type": "markdown", - "id": "8022fda7-0d73-432e-afe3-d5c1c2ca7d12", + "id": "11", "metadata": {}, "source": [ "Main loop of the optimization:" @@ -147,7 +147,7 @@ { "cell_type": "code", "execution_count": null, - "id": "378e59ee-31f7-41e0-b3e1-a9059e4c6aef", + "id": "12", "metadata": {}, "outputs": [], "source": [ @@ -206,7 +206,7 @@ }, { "cell_type": "markdown", - "id": "7fae1133-c0a5-4867-ad65-58bd4d25cb77", + "id": "13", "metadata": {}, "source": [ "Considering that `evals` now stores the evaluation results of the latest population, we can take the best solutions belonging to the best pareto-front as follows:" @@ -215,7 +215,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2a26ede1-490d-4925-914a-475c731b119c", + "id": "14", "metadata": {}, "outputs": [], "source": [ @@ -239,7 +239,7 @@ { "cell_type": "code", "execution_count": null, - "id": "4349eb26-1d33-403b-8f59-b9a3c506f4dc", + "id": "15", "metadata": {}, "outputs": [], "source": [ @@ -248,7 +248,7 @@ }, { "cell_type": "markdown", - "id": "6027aa8a-cc0a-433a-b0de-2956137becaa", + "id": "16", "metadata": {}, "source": [ "Plot the fitnesses of the best solutions:" @@ -257,7 +257,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f2e29903-a3e5-4f2a-9f35-2e73fd5988ec", + "id": "17", "metadata": {}, "outputs": [], "source": [ @@ -268,7 +268,7 @@ }, { "cell_type": "markdown", - "id": "6d5be444-bad7-42eb-be8f-3bdb57f358c0", + "id": "18", "metadata": {}, "source": [ "---\n", @@ -283,7 +283,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8f38b144-9bf3-4d4e-a266-14efbe909905", + "id": "19", "metadata": {}, "outputs": [], "source": [ @@ -305,7 +305,7 @@ { "cell_type": "code", "execution_count": null, - "id": 
"1c82431e-e9f4-4d6e-9dba-b4a711139480", + "id": "20", "metadata": {}, "outputs": [], "source": [ @@ -328,7 +328,7 @@ { "cell_type": "code", "execution_count": null, - "id": "24f5a201-7d33-4bee-a569-dbfa97d0efd3", + "id": "21", "metadata": {}, "outputs": [], "source": [ @@ -373,7 +373,7 @@ }, { "cell_type": "markdown", - "id": "66418331-30da-48b0-a2ea-425ba8b30730", + "id": "22", "metadata": {}, "source": [ "For each solution within each population, compute the domination count:" @@ -382,7 +382,7 @@ { "cell_type": "code", "execution_count": null, - "id": "4f7cd61d-0396-4e52-a4a2-e793791991c7", + "id": "23", "metadata": {}, "outputs": [], "source": [ @@ -392,7 +392,7 @@ }, { "cell_type": "markdown", - "id": "403568ab-5463-4c4a-bccc-3dd918eb60a5", + "id": "24", "metadata": {}, "source": [ "From each population, take the best pareto-front, and plot the fitnesses belonging to that pareto-front:" @@ -401,7 +401,7 @@ { "cell_type": "code", "execution_count": null, - "id": "ccc9ffcc-06ff-44b0-998f-5874468c2383", + "id": "25", "metadata": {}, "outputs": [], "source": [ From 8f3506ce93b5c0eed11f4f77b8c64a481ce14d3a Mon Sep 17 00:00:00 2001 From: Nihat Engin Toklu Date: Fri, 30 Aug 2024 18:13:22 +0200 Subject: [PATCH 5/6] Add an example for the functional API and a bugfix This commit adds a new example notebook which demonstrates how one can use the functional operators API to solve problems where a solution is expressed via objects. A bugfix is also introduced for the method `make_callable_evaluator` for problems where the dtype is set as object. --- examples/notebooks/Functional_API/README.md | 1 + .../notebooks/Functional_API/func_rl_ga.ipynb | 898 ++++++++++++++++++ src/evotorch/core.py | 5 +- 3 files changed, 903 insertions(+), 1 deletion(-) create mode 100644 examples/notebooks/Functional_API/func_rl_ga.ipynb diff --git a/examples/notebooks/Functional_API/README.md b/examples/notebooks/Functional_API/README.md index d315c6f0..697b7118 100644 --- a/examples/notebooks/Functional_API/README.md +++ b/examples/notebooks/Functional_API/README.md @@ -7,5 +7,6 @@ Here are the examples demonstrating various features of this functional API: - **[Maintaining a batch of populations using the functional EvoTorch API](batched_searches.ipynb)**: This notebook shows how one can efficiently run multiple searches simultaneously, each with its own population and hyperparameter configuration, by maintaining a batch of populations. - **[Functional genetic algorithm operators](functional_ops.ipynb)**: This notebook shows how one can implement a custom genetic algorithm by combining the genetic algorithm operator implementations provided by the functional API of EvoTorch. - **[Functional operators for multi-objective optimization](multiobj_batched_ops.ipynb)**: This notebook shows how one can use the functional operators of EvoTorch for multi-objective optimization. Additionally, batched optimization capabilities of these operators are demonstrated. +- **[Functional operators for solving problems with non-numeric solutions](func_rl_ga.ipynb)**: This notebook demonstrates how one can use the functional operators of EvoTorch for solving a problem where a solution is not expressed via a fixed-length numeric vector, but via objects (such as lists, dictionaries, etc.). In more details, this example focuses on a neuro-evolutionary reinforcement learning problem, where each policy is encoded via a dictionary. 
- **[Solving constrained optimization problems](constrained.ipynb)**: EvoTorch provides batching-friendly constraint penalization functions that can be used with both the object-oriented API and the functional API. In addition, these constraint penalization functions can be used with gradient-based optimization. This notebook demonstrates these features. - **[Solving reinforcement learning tasks using functional evolutionary algorithms](problem.ipynb)**: The functional evolutionary algorithm implementations of EvoTorch can be used to solve problems that are expressed using the object-oriented core API of EvoTorch. To demonstrate this, this notebook instantiates a `GymNE` problem for the reinforcement learning task "CartPole-v1", and solves it using the functional `pgpe` implementation. diff --git a/examples/notebooks/Functional_API/func_rl_ga.ipynb b/examples/notebooks/Functional_API/func_rl_ga.ipynb new file mode 100644 index 00000000..73fd3265 --- /dev/null +++ b/examples/notebooks/Functional_API/func_rl_ga.ipynb @@ -0,0 +1,898 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "f9dc9f5e-802f-4305-9919-bd2202f8a63b", + "metadata": {}, + "source": [ + "# Evolving objects using the functional operators API of EvoTorch\n", + "\n", + "In this notebook, we show how to use the functional operators API of EvoTorch for tackling a problem with non-numeric solutions.\n", + "\n", + "In the problem we consider, the goal is to evolve parameter tensors of a feed-forward neural network to make a simulated `Ant-v4` MuJoCo robot walk forward.\n", + "The feed-forward neural network policy has the following modules:\n", + "\n", + "- module 0: linear transformation (`torch.nn.Linear`) with a **weight** matrix and with a **bias** vector\n", + "- module 1: tanh (`torch.nn.Tanh`)\n", + "- module 2: linear transformation (`torch.nn.Linear`) with a **weight** matrix and with a **bias** vector\n", + "\n", + "In this problem, instead of a fixed-length vector consisting of real numbers, a solution is represented by a dictionary structured like this:\n", + "\n", + "```\n", + "{\n", + " \"0.weight\": [ ... list of seeds ... ],\n", + " \"0.bias\": [ ... list of seeds ... ],\n", + " \"2.weight\": [ ... list of seeds ... ],\n", + " \"2.bias\": [ ... list of seeds ... ],\n", + "}\n", + "```\n", + "\n", + "where each key is a name referring to a parameter tensor. Associated with each key is a list of integers (integers being random seeds). At the moment of decoding a solution, each parameter tensor (e.g. `\"0.weight\"`) is constructed by sampling Gaussian noise using each seed, and then summing those noise tensors (as was done in `[1]` and `[2]`).\n", + "\n", + "**Note 1:** Although this example is inspired by the studies `[1]` and `[2]`, it is not a faithful implementation of any of them. Instead, this notebook focuses on demonstrating various features of the functional operators API of EvoTorch.\n", + "\n", + "**Note 2:** For the sake of simplicity, the action space of `Ant-v4` is binned. With this simplification and with its default hyperparameters, the evolutionary algorithm in this example is able to find gaits for the ant robot with a relatively small population size, although the evolved gaits will not be very efficient (i.e. non-competitive cumulative rewards).\n", + "\n", + "---\n", + "\n", + "`[1]` Felipe Petroski Such, Vashisht Madhavan, Edoardo Conti, Joel Lehman, Kenneth O. Stanley, Jeff Clune (2017).
\"Deep neuroevolution: Genetic algorithms are a competitive alternative for training deep neural networks for reinforcement learning.\" arXiv preprint arXiv:1712.06567.\n", + "\n", + "`[2]` Risi, Sebastian, and Kenneth O. Stanley (2019). \"Deep neuroevolution of recurrent and discrete world models.\" Proceedings of the Genetic and Evolutionary Computation Conference.\n", + "\n", + "---" + ] + }, + { + "cell_type": "markdown", + "id": "8df623bd-3cc6-44d8-888f-fe16460ffa9a", + "metadata": {}, + "source": [ + "## Summary of the evolutionary algorithm\n", + "\n", + "We implement a simple, elitist genetic algorithm with tournament selection, cross-over, and mutation operators. The main ideas of this genetic algorithm are as follows.\n", + "\n", + "**Generation of a new solution:**\n", + "- Make a new dictionary.\n", + "- Associated with each key (parameter name) within the dictionary, make a single-element list of seeds, the seed within it being a random integer.\n", + "\n", + "**Cross-over between two solutions.**\n", + "- Make two children solutions (dictionaries).\n", + "- For each key (parameter name):\n", + " - Sample a real number $p$ between 0 and 1.\n", + " - If $p < 0.5$, the first child receives its list of seeds from the first parent, the second child receives its list of seeds from the second parent.\n", + " - Otherwise, the first child receives its list of seeds from the second parent, the second child receives its list of seeds from the first parent.\n", + "\n", + "**Mutation of an existing solution.**\n", + "- Pick a key (parameter name) within the solution (dictionary).\n", + "- Randomly sample a new integer, and add this integer into the list of seeds associated with the picked key." + ] + }, + { + "cell_type": "markdown", + "id": "510e620e-a815-48c8-901b-7eeba7781786", + "metadata": {}, + "source": [ + "## Implementation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9a2b0b8b-7bc6-49e3-ac27-21aba81cb2df", + "metadata": {}, + "outputs": [], + "source": [ + "from evotorch import Problem, Solution\n", + "from evotorch.tools import make_tensor, ObjectArray\n", + "import evotorch.operators.functional as func_ops\n", + "\n", + "import gymnasium as gym\n", + "import numpy as np\n", + "\n", + "import torch\n", + "from torch import nn\n", + "from torch.func import functional_call\n", + "\n", + "from typing import Iterable, Mapping, Optional, Union\n", + "import random\n", + "import os\n", + "from datetime import datetime\n", + "import pickle" + ] + }, + { + "cell_type": "markdown", + "id": "e0b522cd-4ce1-4fc5-a0a7-f066bb1a2a04", + "metadata": {}, + "source": [ + "The function below takes a series of seeds, and makes a tensor of real numbers out of them.\n", + "We will use this function at the moment of decoding a solution." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "210a3488-9370-4a87-bf8f-46f059379243", + "metadata": {}, + "outputs": [], + "source": [ + "def make_tensor_from_seeds(\n", + " like: torch.Tensor,\n", + " seeds: Iterable,\n", + " *,\n", + " mutation_power: float,\n", + " mutation_decay: float,\n", + " min_mutation_power: float,\n", + ") -> torch.Tensor:\n", + " \"\"\"\n", + " Take a series of seeds and compute a tensor out of them.\n", + "\n", + " Args:\n", + " like: A source tensor. 
The resulting tensor will have the same shape,\n", + " dtype, and device as this source tensor.\n", + " seeds: An iterable in which each item is an integer, each integer\n", + " being a random seed.\n", + " mutation_power: A multiplier for the Gaussian noise generated out of\n", + " a random seed.\n", + " mutation_decay: For each seed, the mutation power will be multiplied\n", + " by this factor. For example, if this multiplier is 0.9, the mutation\n", + " power will shrink with each successive seed, being multiplied\n", + " by 0.9 at each step.\n", + " min_mutation_power: To prevent the mutation power from getting too\n", + " close to 0, provide a lower bound multiplier via this argument.\n", + " Returns:\n", + " The tensor generated from the given seeds.\n", + " \"\"\"\n", + " from numpy.random import RandomState\n", + "\n", + " result = torch.zeros_like(like)\n", + " for i_seed, seed in enumerate(seeds):\n", + " multiplier = max(mutation_power * (mutation_decay ** i_seed), min_mutation_power)\n", + " result += (\n", + " multiplier * torch.as_tensor(RandomState(seed).randn(*(like.shape)), dtype=like.dtype, device=like.device)\n", + " )\n", + "\n", + " return result" + ] + }, + { + "cell_type": "markdown", + "id": "7cdb5a99-7372-43a3-a7a1-07ec892c6192", + "metadata": {}, + "source": [ + "Helper function to generate a random seed integer:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fe325334-8f42-41d8-98bb-a8db25980da4", + "metadata": {}, + "outputs": [], + "source": [ + "def sample_seed() -> int:\n", + " return random.randint(0, (2 ** 32) - 1)" + ] + }, + { + "cell_type": "markdown", + "id": "5808a486-5178-4409-9ba2-b0a34bf92dff", + "metadata": {}, + "source": [ + "**Observation normalization.**\n", + "Below, we have helper functions that will generate observation data for the reinforcement learning environment at hand.\n", + "The observation data will be used for normalizing the observations before passing them to the policy neural network."
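For reference, the arithmetic that this observation-normalization machinery boils down to is sketched below. This is an illustrative summary, not part of the notebook itself; the helper names `finalize_stats` and `normalize` are made up for the sketch, while the clamping constant `1e-2` matches the one used in the notebook's own code further down:

```python
import numpy as np

def finalize_stats(total: np.ndarray, total_sq: np.ndarray, count: int) -> tuple:
    # Given the accumulated sum and sum-of-squares over `count` observations,
    # the mean is E[x] and the variance is E[x^2] - (E[x])^2, clamped away
    # from zero so that the division during normalization stays stable.
    mean = total / count
    variance = np.maximum(total_sq / count - mean**2, 1e-2)
    return mean, np.sqrt(variance)

def normalize(obs: np.ndarray, mean: np.ndarray, stdev: np.ndarray) -> np.ndarray:
    # The policy network receives this shifted and scaled observation
    # instead of the raw one.
    return (obs - mean) / stdev
```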
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "06b496a8-9874-44d0-859f-b67086f62474", + "metadata": {}, + "outputs": [], + "source": [ + "def env_name_to_file_name(env_name: str) -> str:\n", + " \"\"\"\n", + " Convert the gymnasium environment ID to a more file-name-friendly counterpart.\n", + "\n", + " The character ':' in the input string will be replaced with '__colon__'.\n", + " Similarly, the character '/' in the input string will be replaced with '__slash__'.\n", + "\n", + " Args:\n", + " env_name: gymnasium environment ID\n", + " Returns:\n", + " File-name-friendly counterpart of the input string.\n", + " \"\"\"\n", + " result = env_name\n", + " result = result.replace(\":\", \"__colon__\")\n", + " result = result.replace(\"/\", \"__slash__\")\n", + " return result" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "abf2e4c9-abe7-4a8b-b844-a07729754458", + "metadata": {}, + "outputs": [], + "source": [ + "def create_obs_data(\n", + " *,\n", + " env_name: str,\n", + " num_timesteps: int,\n", + " report_interval: Union[int, float] = 5,\n", + " seed: int = 0,\n", + ") -> tuple:\n", + " \"\"\"\n", + " Create observation normalization data with the help of random actions.\n", + "\n", + " This function creates a gymnasium environment from the given `env_name`.\n", + " Then, it keeps sending random actions to this environment, and collects statistics from the observations.\n", + "\n", + " Args:\n", + " env_name: ID of the gymnasium environment\n", + " num_timesteps: The number of timesteps for which the function will operate on the environment\n", + " report_interval: Time interval, in seconds, for reporting the status\n", + " seed: A seed that will be used for regulating the randomness of both the environment\n", + " and of the random actions.\n", + " Returns:\n", + " A tuple of the form `(mean, stdev)`, where `mean` is the elementwise mean of the observation vectors,\n", + " and `stdev` is the elementwise standard deviation of the observation vectors.\n", + " \"\"\"\n", + " print(\"Creating observation data for\", env_name)\n", + "\n", + " class accumulated:\n", + " sum: Optional[np.ndarray] = None\n", + " sum_of_squares: Optional[np.ndarray] = None\n", + " count: int = 0\n", + "\n", + " def accumulate(obs: np.ndarray):\n", + " if accumulated.sum is None:\n", + " accumulated.sum = obs.copy()\n", + " else:\n", + " accumulated.sum += obs\n", + "\n", + " squared = obs ** 2\n", + " if accumulated.sum_of_squares is None:\n", + " accumulated.sum_of_squares = squared\n", + " else:\n", + " accumulated.sum_of_squares += squared\n", + "\n", + " accumulated.count += 1\n", + "\n", + " rndgen = np.random.RandomState(seed)\n", + "\n", + " env = gym.make(env_name)\n", + " assert isinstance(env.action_space, gym.spaces.Box), \"Can only work with Box action spaces\"\n", + "\n", + " def reset_env() -> tuple:\n", + " return env.reset(seed=rndgen.randint(2 ** 32))\n", + "\n", + " action_gap = env.action_space.high - env.action_space.low\n", + " def sample_action() -> np.ndarray:\n", + " return (rndgen.rand(*(env.action_space.shape)) * action_gap) + env.action_space.low\n", + "\n", + " observation, _ = reset_env()\n", + " accumulate(observation)\n", + "\n", + " last_report_time = datetime.now()\n", + "\n", + " for t in range(num_timesteps):\n", + " action = sample_action()\n", + " observation, _, terminated, truncated, _ = env.step(action)\n", + " accumulate(observation)\n", + "\n", + " done = terminated | truncated\n", + " if done:\n", + " observation, info = reset_env()\n", +
" accumulate(observation)\n", + "\n", + " tnow = datetime.now()\n", + " if (tnow - last_report_time).total_seconds() > report_interval:\n", + " print(\"Number of timesteps:\", t, \"/\", num_timesteps)\n", + " last_report_time = tnow\n", + "\n", + " E_x = accumulated.sum / accumulated.count\n", + " E_x2 = accumulated.sum_of_squares / accumulated.count\n", + "\n", + " mean = E_x\n", + " variance = np.maximum(E_x2 - ((E_x) ** 2), 1e-2)\n", + " stdev = np.sqrt(variance)\n", + "\n", + " print(\"Done.\")\n", + " \n", + " return mean, stdev" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3fde9767-b959-450b-896b-6a5e129f329f", + "metadata": {}, + "outputs": [], + "source": [ + "def get_obs_data(env_name: str, num_timesteps: int = 50000, seed: int = 0) -> tuple:\n", + " \"\"\"\n", + " Generate observation normalization data for the gymnasium environment whose name is given.\n", + "\n", + " If such normalization data was already generated and saved into a pickle file, that pickle file will be loaded.\n", + " Otherwise, new normalization data will be generated and saved into a new pickle file.\n", + "\n", + " Args:\n", + " env_name: ID of the gymnasium environment\n", + " num_timesteps: For how many timesteps will the observation collector operate on the environment\n", + " seed: A seed that will be used for regulating the randomness of both the environment\n", + " and of the random actions.\n", + " Returns:\n", + " A tuple of the form `(mean, stdev)`, where `mean` is the elementwise mean of the observation vectors,\n", + " and `stdev` is the elementwise standard deviation of the observation vectors.\n", + " \"\"\"\n", + " num_timesteps = int(num_timesteps)\n", + " envfname = env_name_to_file_name(env_name)\n", + " fname = f\"obs_seed{seed}_t{num_timesteps}_{envfname}.pickle\"\n", + " if os.path.isfile(fname):\n", + " with open(fname, \"rb\") as f:\n", + " return pickle.load(f)\n", + " else:\n", + " obsdata = create_obs_data(env_name=env_name, num_timesteps=num_timesteps, seed=seed)\n", + " with open(fname, \"wb\") as f:\n", + " pickle.dump(obsdata, f)\n", + " return obsdata" + ] + }, + { + "cell_type": "markdown", + "id": "8f782f05-93b6-4d80-9679-e3b6028428e3", + "metadata": {}, + "source": [ + "**Problem definition.**\n", + "Below is the problem definition for the considered reinforcement learning task.\n", + "We are defining the problem as a subclass of `evotorch.Problem`, so that we will be able to use ray-based parallelization capabilities of the base `Problem` class." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c8216b77-0f64-4f9d-8a3d-44cfe3deaec5", + "metadata": {}, + "outputs": [], + "source": [ + "class MyRLProblem(Problem):\n", + " def __init__(\n", + " self,\n", + " *,\n", + " env_name: str,\n", + " obs_mean: Optional[np.ndarray] = None,\n", + " obs_stdev: Optional[np.ndarray] = None,\n", + " mutation_power: float = 0.5,\n", + " mutation_decay: float = 0.9,\n", + " min_mutation_power: float = 0.05,\n", + " hidden_sizes: tuple = (64,),\n", + " bin: Optional[float] = 0.2,\n", + " num_episodes: int = 4,\n", + " episode_length: Optional[int] = None,\n", + " decrease_rewards_by: Optional[float] = 1.0,\n", + " num_actors: Optional[Union[int, str]] = \"max\"\n", + " ):\n", + " super().__init__(\n", + " objective_sense=\"max\",\n", + " dtype=object,\n", + " num_actors=num_actors,\n", + " )\n", + " self._env_name = str(env_name)\n", + " self._env = None\n", + " self._hidden_sizes = [int(hidden_size) for hidden_size in hidden_sizes]\n", + " self._policy = None\n", + "\n", + " self._obs_mean = None if obs_mean is None else np.asarray(obs_mean).astype(\"float32\")\n", + " self._obs_stdev = None if obs_stdev is None else np.asarray(obs_stdev).astype(\"float32\")\n", + " self._mutation_power = float(mutation_power)\n", + " self._mutation_decay = float(mutation_decay)\n", + " self._min_mutation_power = float(min_mutation_power)\n", + " self._bin = None if bin is None else float(bin)\n", + " self._num_episodes = int(num_episodes)\n", + " self._episode_length = None if episode_length is None else int(episode_length)\n", + " self._decrease_rewards_by = None if decrease_rewards_by is None else float(decrease_rewards_by)\n", + "\n", + " def _get_policy(self) -> nn.Module:\n", + " env = self._get_env()\n", + "\n", + " if not isinstance(env.observation_space, gym.spaces.Box):\n", + " raise TypeError(\n", + " f\"Only Box-typed environments are supported.
Encountered observation space is {env.observation_space}\"\n", + " )\n", + "\n", + " [obslen] = env.observation_space.shape\n", + " if isinstance(env.action_space, gym.spaces.Box):\n", + " [actlen] = env.action_space.shape\n", + " elif isinstance(env.action_space, gym.spaces.Discrete):\n", + " actlen = env.action_space.n\n", + " else:\n", + " raise TypeError(f\"Unrecognized action space: {env.action_space}\")\n", + "\n", + " all_sizes = [obslen]\n", + " all_sizes.extend(self._hidden_sizes)\n", + " all_sizes.append(actlen)\n", + "\n", + " last_size_index = len(all_sizes) - 1\n", + "\n", + " modules = []\n", + " for i in range(1, len(all_sizes)):\n", + " modules.append(nn.Linear(all_sizes[i - 1], all_sizes[i]))\n", + " if i < last_size_index:\n", + " modules.append(nn.Tanh())\n", + "\n", + " return nn.Sequential(*modules)\n", + "\n", + " def _get_env(self, visualize: bool = False) -> gym.Env:\n", + " if visualize:\n", + " return gym.make(self._env_name, render_mode=\"human\")\n", + "\n", + " if self._env is None:\n", + " self._env = gym.make(self._env_name)\n", + " return self._env\n", + "\n", + " def _generate_single_solution(self) -> dict:\n", + " policy = self._get_policy()\n", + " result = {}\n", + " for param_name, params in policy.named_parameters():\n", + " result[param_name] = [sample_seed()]\n", + " return result\n", + "\n", + " def generate_values(self, n: int) -> ObjectArray:\n", + " return make_tensor([self._generate_single_solution() for _ in range(n)], dtype=object)\n", + "\n", + " def run_solution(\n", + " self,\n", + " x: Union[Mapping, Solution],\n", + " *,\n", + " num_episodes: Optional[int] = None,\n", + " visualize: bool = False\n", + " ) -> float:\n", + " if num_episodes is None:\n", + " num_episodes = self._num_episodes\n", + "\n", + " if isinstance(x, Mapping):\n", + " sln = x\n", + " elif isinstance(x, Solution):\n", + " sln = x.values\n", + " else:\n", + " raise TypeError(f\"Expected a Mapping or a Solution, but got {repr(x)}\")\n", + "\n", + " policy = self._get_policy()\n", + "\n", + " params = {}\n", + " for param_name, param_values in policy.named_parameters():\n", + " param_seeds = sln[param_name]\n", + " params[param_name] = make_tensor_from_seeds(\n", + " param_values,\n", + " param_seeds,\n", + " mutation_power=self._mutation_power,\n", + " mutation_decay=self._mutation_decay,\n", + " min_mutation_power=self._min_mutation_power,\n", + " )\n", + "\n", + " env = self._get_env(visualize=visualize)\n", + "\n", + " def use_policy(policy_input: np.ndarray) -> Union[int, np.ndarray]:\n", + " if (self._obs_mean is not None) and (self._obs_stdev is not None):\n", + " policy_input = policy_input - self._obs_mean\n", + " policy_input = policy_input / self._obs_stdev\n", + "\n", + " result = functional_call(policy, params, torch.as_tensor(policy_input, dtype=torch.float32)).numpy()\n", + "\n", + " if isinstance(env.action_space, gym.spaces.Box):\n", + " if self._bin is not None:\n", + " result = np.sign(result) * self._bin\n", + " result = np.clip(result, env.action_space.low, env.action_space.high)\n", + " elif isinstance(env.action_space, gym.spaces.Discrete):\n", + " result = int(np.argmax(result))\n", + " else:\n", + " raise TypeError(f\"Unrecognized action space: {repr(env.action_space)}\")\n", + "\n", + " return result\n", + "\n", + " cumulative_reward = 0.0\n", + "\n", + " for _ in range(num_episodes):\n", + " timestep = 0\n", + " observation, info = env.reset()\n", + " while True:\n", + " action = use_policy(observation)\n", + " observation, reward, done1, done2, _ =
env.step(action)\n", + " timestep += 1\n", + " if (self._decrease_rewards_by is not None) and (not visualize):\n", + " reward = reward - self._decrease_rewards_by\n", + " cumulative_reward += reward\n", + " if (\n", + " done1\n", + " or done2\n", + " or (\n", + " (not visualize)\n", + " and (self._episode_length is not None)\n", + " and (timestep >= self._episode_length)\n", + " )\n", + " ):\n", + " break\n", + "\n", + " return cumulative_reward / num_episodes\n", + "\n", + " def visualize(self, x: Union[Solution, Mapping]) -> float:\n", + " return self.run_solution(x, num_episodes=1, visualize=True)\n", + " \n", + " def _evaluate(self, x: Solution):\n", + " x.set_evaluation(self.run_solution(x))" + ] + }, + { + "cell_type": "markdown", + "id": "394066c3-6eeb-44d1-baa0-fcae4e2101ef", + "metadata": {}, + "source": [ + "We now define our mutation and cross-over operators, via the functions `mutate` and `cross_over`.\n", + "Since the solutions are expressed via dictionary-like objects, we use `Mapping` for type annotations." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "676714c9-6288-459f-aeef-411849a2af2b", + "metadata": {}, + "outputs": [], + "source": [ + "def mutate(solution: Mapping) -> Mapping:\n", + " from evotorch.tools import as_immutable\n", + "\n", + " solution = {k: list(v) for k, v in solution.items()}\n", + "\n", + " keys = list(solution.keys())\n", + " chosen_key = random.choice(keys)\n", + " solution[chosen_key].append(sample_seed())\n", + "\n", + " return as_immutable(solution)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "62cf3114-83b0-48d3-be2e-3dd8387d8c32", + "metadata": {}, + "outputs": [], + "source": [ + "def cross_over(parent1: Mapping, parent2: Mapping) -> tuple:\n", + " from evotorch.tools import as_immutable\n", + "\n", + " keys = list(parent1.keys())\n", + "\n", + " child1 = {}\n", + " child2 = {}\n", + " for k in keys:\n", + " p = random.random()\n", + " if p < 0.5:\n", + " child1[k] = parent1[k]\n", + " child2[k] = parent2[k]\n", + " else:\n", + " child1[k] = parent2[k]\n", + " child2[k] = parent1[k]\n", + "\n", + " return as_immutable(child1), as_immutable(child2)" + ] + }, + { + "cell_type": "markdown", + "id": "51178388-1db5-4864-8012-607558f7a151", + "metadata": {}, + "source": [ + "ID of the considered reinforcement learning task:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "68f3b4f1-9120-4d0c-a314-cff8a6a1ad93", + "metadata": {}, + "outputs": [], + "source": [ + "ENV_NAME = \"Ant-v4\"" + ] + }, + { + "cell_type": "markdown", + "id": "f2f43b33-d6a1-4fd0-97bb-750a8e3c6667", + "metadata": {}, + "source": [ + "Generate or load observation data for the considered reinforcement learning environment:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "35d310cd-fca4-4887-8dd8-0b586efd7e48", + "metadata": {}, + "outputs": [], + "source": [ + "env_obs_mean, env_obs_stdev = get_obs_data(ENV_NAME)\n", + "env_obs_mean, env_obs_stdev" + ] + }, + { + "cell_type": "markdown", + "id": "6e6ef3aa-1f8f-4ee1-807c-0770dbb0312f", + "metadata": {}, + "source": [ + "Instantiate the problem object:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c89d035f-c24f-45cb-ae3f-af2ffd49377d", + "metadata": {}, + "outputs": [], + "source": [ + "problem = MyRLProblem(\n", + " env_name=ENV_NAME,\n", + " decrease_rewards_by=1.0,\n", + " episode_length=250,\n", + " bin=0.15,\n", + " obs_mean=env_obs_mean,\n", + " obs_stdev=env_obs_stdev,\n", + ")\n", + "\n", + "problem" 
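As a quick aside, the sketch below illustrates how the `cross_over` and `mutate` functions defined a few cells earlier behave on a pair of toy parents. The parent dictionaries and seed values are placeholders, and the sketch assumes that the notebook cells defining those two functions have already been run:

```python
import random
random.seed(0)  # pinned only so that this illustration is reproducible

parent1 = {"0.weight": [11], "0.bias": [22]}
parent2 = {"0.weight": [33], "0.bias": [44]}

# For each key, both children inherit that key's entire seed list from one
# parent or the other, decided by an independent fair coin flip per key.
child1, child2 = cross_over(parent1, parent2)

# Mutation picks one key at random and appends one freshly sampled seed
# to the list stored under that key.
mutated_child1 = mutate(child1)
```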
+ ] + }, + { + "cell_type": "markdown", + "id": "ff43c026-6588-4b38-841c-677bd4faf0e5", + "metadata": {}, + "source": [ + "Out of the instantiated problem object, we make a callable evaluator named `f`.\n", + "The resulting object `f` can be used as a fitness function." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bc9c6588-535a-4e3e-90f4-e994a730068b", + "metadata": {}, + "outputs": [], + "source": [ + "f = problem.make_callable_evaluator()\n", + "f" + ] + }, + { + "cell_type": "markdown", + "id": "78b7f831-1202-44d3-88de-dbad81f1d7df", + "metadata": {}, + "source": [ + "Helper function for converting a real number to a string.\n", + "We will use this while reporting the status of the evolution." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fbd919a7-ca9d-4c13-90ef-d6cddb23faef", + "metadata": {}, + "outputs": [], + "source": [ + "def number_to_str(x) -> str:\n", + " return \"%.2f\" % float(x)" + ] + }, + { + "cell_type": "markdown", + "id": "5f1cc2b1-7d2d-44d7-8741-e523c2a01050", + "metadata": {}, + "source": [ + "Hyperparameters and constants for the evolutionary algorithm:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b7c2b6eb-d1d4-4b54-a3db-7125dcb9725c", + "metadata": {}, + "outputs": [], + "source": [ + "popsize = 16\n", + "tournament_size = 4\n", + "objective_sense = problem.objective_sense\n", + "num_generations = 100" + ] + }, + { + "cell_type": "markdown", + "id": "ed62ae2d-87f5-4e5a-8dd8-57a90254dbd2", + "metadata": {}, + "source": [ + "We now prepare the initial population.\n", + "When we are dealing with non-numeric solutions, a population is represented via `evotorch.tools.ObjectArray`, instead of `torch.Tensor`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f80777a0-8b9b-492f-8306-67f8913a3596", + "metadata": {}, + "outputs": [], + "source": [ + "population = problem.generate_values(popsize)\n", + "population" + ] + }, + { + "cell_type": "markdown", + "id": "b513cd3d-01b5-4b3f-ad2d-74631f443a65", + "metadata": {}, + "source": [ + "Evaluate the fitnesses of the solutions within the initial population:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a6d075e7-27be-4019-a783-d0c024fa6cd0", + "metadata": {}, + "outputs": [], + "source": [ + "evals = f(population)\n", + "evals" + ] + }, + { + "cell_type": "markdown", + "id": "128ab3ce-30f8-43b3-b966-1a7c0a5099df", + "metadata": {}, + "source": [ + "Main loop of the evolutionary search:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c78a7d4f-9c12-4bd5-977e-668acb44af79", + "metadata": {}, + "outputs": [], + "source": [ + "for generation in range(1, 1 + num_generations):\n", + " t_begin = datetime.now()\n", + "\n", + " # Apply tournament selection on the population\n", + " parent1_indices, parent2_indices = func_ops.tournament(\n", + " population,\n", + " evals,\n", + " tournament_size=tournament_size,\n", + " num_tournaments=popsize,\n", + " split_results=True,\n", + " return_indices=True,\n", + " objective_sense=objective_sense,\n", + " )\n", + "\n", + " # The results of the tournament selection are stored within the integer\n", + " # tensors `parent1_indices` and `parent2_indices`.\n", + " # The pairs of solutions for the cross-over operator are:\n", + " # - `population[parent1_indices[0]]` and `population[parent2_indices[0]]`,\n", + " # - `population[parent1_indices[1]]` and `population[parent2_indices[1]]`,\n", + " # - `population[parent1_indices[2]]` and 
`population[parent2_indices[2]]`,\n", + " # - and so on...\n", + " num_pairs = len(parent1_indices)\n", + " children = []\n", + " for i in range(num_pairs):\n", + " parent1_index = int(parent1_indices[i])\n", + " parent2_index = int(parent2_indices[i])\n", + " child1, child2 = cross_over(population[parent1_index], population[parent2_index])\n", + " child1 = mutate(child1)\n", + " child2 = mutate(child2)\n", + " children.extend([child1, child2])\n", + "\n", + " # With the help of the function `evotorch.tools.make_tensor(...)`,\n", + " # we convert the list of child solutions to an ObjectArray, so that\n", + " # `children` can be treated as a population of solutions.\n", + " children = make_tensor(children, dtype=object)\n", + "\n", + " # Combine the original population with the population of children,\n", + " # forming an extended population.\n", + " extended_population = func_ops.combine(population, children)\n", + "\n", + " # Evaluate all the solutions within the extended population.\n", + " extended_evals = f(extended_population)\n", + "\n", + " # Take the best `popsize` number of solutions from the extended population.\n", + " population, evals = func_ops.take_best(\n", + " extended_population, extended_evals, popsize, objective_sense=objective_sense\n", + " )\n", + "\n", + " t_end = datetime.now()\n", + " time_taken = (t_end - t_begin).total_seconds()\n", + "\n", + " # Report the status of the evolutionary search.\n", + " print(\n", + " \"Generation:\", generation,\n", + " \" Mean eval:\", number_to_str(evals.mean()),\n", + " \" Pop best:\", number_to_str(evals.max()),\n", + " \" Time:\", number_to_str(time_taken)\n", + " )" + ] + }, + { + "cell_type": "markdown", + "id": "57a2765d-df97-4a48-a593-323cfff07116", + "metadata": {}, + "source": [ + "Take the index of the best solution within the last population:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2a6a6376-f296-498d-8ff0-d039d6c67099", + "metadata": {}, + "outputs": [], + "source": [ + "best_index = torch.argmax(evals)\n", + "best_index" + ] + }, + { + "cell_type": "markdown", + "id": "76691a35-b997-49d5-9585-0ecf850608fc", + "metadata": {}, + "source": [ + "Take the best solution within the last population:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b2f773b5-3b46-4046-828a-7ca60fd4be5a", + "metadata": {}, + "outputs": [], + "source": [ + "best_params = population[best_index]\n", + "best_params" + ] + }, + { + "cell_type": "markdown", + "id": "f63d773f-3059-462f-808a-fddf4ae721dc", + "metadata": {}, + "source": [ + "Visualize the gait of the population's best solution:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fe415cc7-b08b-4ffa-9548-78c3e890bfff", + "metadata": {}, + "outputs": [], + "source": [ + "problem.visualize(best_params)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.18" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/src/evotorch/core.py b/src/evotorch/core.py index 5acb6194..479ab4f2 100644 --- a/src/evotorch/core.py +++ b/src/evotorch/core.py @@ -5246,4 +5246,7 @@ def __call__(self, values: ObjectArray) -> torch.Tensor: "The positional argument `values` was expected as an 
`ObjectArray`." f" However, an object of this type was encountered: {type(values)}." ) - return self._prepare_evaluated_solution_batch(values).evals.as_subclass(torch.Tensor) + result = self._prepare_evaluated_solution_batch(values).evals.as_subclass(torch.Tensor) + if len(self._problem.senses) == 1: + result = result.reshape(-1) + return result From 6644480c74ffbf2433c2a220e23df016b60f7f56 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 30 Aug 2024 16:25:37 +0000 Subject: [PATCH 6/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../notebooks/Functional_API/func_rl_ga.ipynb | 82 +++++++++---------- 1 file changed, 41 insertions(+), 41 deletions(-) diff --git a/examples/notebooks/Functional_API/func_rl_ga.ipynb b/examples/notebooks/Functional_API/func_rl_ga.ipynb index 73fd3265..033cff4e 100644 --- a/examples/notebooks/Functional_API/func_rl_ga.ipynb +++ b/examples/notebooks/Functional_API/func_rl_ga.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "f9dc9f5e-802f-4305-9919-bd2202f8a63b", + "id": "0", "metadata": {}, "source": [ "# Evolving objects using the functional operators API of EvoTorch\n", @@ -44,7 +44,7 @@ }, { "cell_type": "markdown", - "id": "8df623bd-3cc6-44d8-888f-fe16460ffa9a", + "id": "1", "metadata": {}, "source": [ "## Summary of the evolutionary algorithm\n", @@ -69,7 +69,7 @@ }, { "cell_type": "markdown", - "id": "510e620e-a815-48c8-901b-7eeba7781786", + "id": "2", "metadata": {}, "source": [ "## Implementation" @@ -78,7 +78,7 @@ { "cell_type": "code", "execution_count": null, - "id": "9a2b0b8b-7bc6-49e3-ac27-21aba81cb2df", + "id": "3", "metadata": {}, "outputs": [], "source": [ @@ -102,7 +102,7 @@ }, { "cell_type": "markdown", - "id": "e0b522cd-4ce1-4fc5-a0a7-f066bb1a2a04", + "id": "4", "metadata": {}, "source": [ "The function below takes a series of seeds, and makes a tensor of real numbers out of them.\n", @@ -112,7 +112,7 @@ { "cell_type": "code", "execution_count": null, - "id": "210a3488-9370-4a87-bf8f-46f059379243", + "id": "5", "metadata": {}, "outputs": [], "source": [ @@ -157,7 +157,7 @@ }, { "cell_type": "markdown", - "id": "7cdb5a99-7372-43a3-a7a1-07ec892c6192", + "id": "6", "metadata": {}, "source": [ "Helper function to generate a random seed integer:" @@ -166,7 +166,7 @@ { "cell_type": "code", "execution_count": null, - "id": "fe325334-8f42-41d8-98bb-a8db25980da4", + "id": "7", "metadata": {}, "outputs": [], "source": [ @@ -176,7 +176,7 @@ }, { "cell_type": "markdown", - "id": "5808a486-5178-4409-9ba2-b0a34bf92dff", + "id": "8", "metadata": {}, "source": [ "**Observation normalization.**\n", @@ -187,7 +187,7 @@ { "cell_type": "code", "execution_count": null, - "id": "06b496a8-9874-44d0-859f-b67086f62474", + "id": "9", "metadata": {}, "outputs": [], "source": [ @@ -212,7 +212,7 @@ { "cell_type": "code", "execution_count": null, - "id": "abf2e4c9-abe7-4a8b-b844-a07729754458", + "id": "10", "metadata": {}, "outputs": [], "source": [ @@ -307,7 +307,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3fde9767-b959-450b-896b-6a5e129f329f", + "id": "11", "metadata": {}, "outputs": [], "source": [ @@ -342,7 +342,7 @@ }, { "cell_type": "markdown", - "id": "8f782f05-93b6-4d80-9679-e3b6028428e3", + "id": "12", "metadata": {}, "source": [ "**Problem definition.**\n", @@ -353,7 +353,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c8216b77-0f64-4f9d-8a3d-44cfe3deaec5", + "id": "13", "metadata": {}, 
"outputs": [], "source": [ @@ -526,7 +526,7 @@ }, { "cell_type": "markdown", - "id": "394066c3-6eeb-44d1-baa0-fcae4e2101ef", + "id": "14", "metadata": {}, "source": [ "We now define our mutation and cross-over operators, via the functions `mutate` and `cross_over`.\n", @@ -536,7 +536,7 @@ { "cell_type": "code", "execution_count": null, - "id": "676714c9-6288-459f-aeef-411849a2af2b", + "id": "15", "metadata": {}, "outputs": [], "source": [ @@ -555,7 +555,7 @@ { "cell_type": "code", "execution_count": null, - "id": "62cf3114-83b0-48d3-be2e-3dd8387d8c32", + "id": "16", "metadata": {}, "outputs": [], "source": [ @@ -580,7 +580,7 @@ }, { "cell_type": "markdown", - "id": "51178388-1db5-4864-8012-607558f7a151", + "id": "17", "metadata": {}, "source": [ "ID of the considered reinforcement learning task:" @@ -589,7 +589,7 @@ { "cell_type": "code", "execution_count": null, - "id": "68f3b4f1-9120-4d0c-a314-cff8a6a1ad93", + "id": "18", "metadata": {}, "outputs": [], "source": [ @@ -598,7 +598,7 @@ }, { "cell_type": "markdown", - "id": "f2f43b33-d6a1-4fd0-97bb-750a8e3c6667", + "id": "19", "metadata": {}, "source": [ "Generate or load observation data for the considered reinforcement learning environment:" @@ -607,7 +607,7 @@ { "cell_type": "code", "execution_count": null, - "id": "35d310cd-fca4-4887-8dd8-0b586efd7e48", + "id": "20", "metadata": {}, "outputs": [], "source": [ @@ -617,7 +617,7 @@ }, { "cell_type": "markdown", - "id": "6e6ef3aa-1f8f-4ee1-807c-0770dbb0312f", + "id": "21", "metadata": {}, "source": [ "Instantiate the problem object:" @@ -626,7 +626,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c89d035f-c24f-45cb-ae3f-af2ffd49377d", + "id": "22", "metadata": {}, "outputs": [], "source": [ @@ -644,7 +644,7 @@ }, { "cell_type": "markdown", - "id": "ff43c026-6588-4b38-841c-677bd4faf0e5", + "id": "23", "metadata": {}, "source": [ "Out of the instantiated problem object, we make a callable evaluator named `f`.\n", @@ -654,7 +654,7 @@ { "cell_type": "code", "execution_count": null, - "id": "bc9c6588-535a-4e3e-90f4-e994a730068b", + "id": "24", "metadata": {}, "outputs": [], "source": [ @@ -664,7 +664,7 @@ }, { "cell_type": "markdown", - "id": "78b7f831-1202-44d3-88de-dbad81f1d7df", + "id": "25", "metadata": {}, "source": [ "Helper function for converting a real number to a string.\n", @@ -674,7 +674,7 @@ { "cell_type": "code", "execution_count": null, - "id": "fbd919a7-ca9d-4c13-90ef-d6cddb23faef", + "id": "26", "metadata": {}, "outputs": [], "source": [ @@ -684,7 +684,7 @@ }, { "cell_type": "markdown", - "id": "5f1cc2b1-7d2d-44d7-8741-e523c2a01050", + "id": "27", "metadata": {}, "source": [ "Hyperparameters and constants for the evolutionary algorithm:" @@ -693,7 +693,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b7c2b6eb-d1d4-4b54-a3db-7125dcb9725c", + "id": "28", "metadata": {}, "outputs": [], "source": [ @@ -705,7 +705,7 @@ }, { "cell_type": "markdown", - "id": "ed62ae2d-87f5-4e5a-8dd8-57a90254dbd2", + "id": "29", "metadata": {}, "source": [ "We now prepare the initial population.\n", @@ -715,7 +715,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f80777a0-8b9b-492f-8306-67f8913a3596", + "id": "30", "metadata": {}, "outputs": [], "source": [ @@ -725,7 +725,7 @@ }, { "cell_type": "markdown", - "id": "b513cd3d-01b5-4b3f-ad2d-74631f443a65", + "id": "31", "metadata": {}, "source": [ "Evaluate the fitnesses of the solutions within the initial population:" @@ -734,7 +734,7 @@ { "cell_type": "code", "execution_count": null, - "id": 
"a6d075e7-27be-4019-a783-d0c024fa6cd0", + "id": "32", "metadata": {}, "outputs": [], "source": [ @@ -744,7 +744,7 @@ }, { "cell_type": "markdown", - "id": "128ab3ce-30f8-43b3-b966-1a7c0a5099df", + "id": "33", "metadata": {}, "source": [ "Main loop of the evolutionary search:" @@ -753,7 +753,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c78a7d4f-9c12-4bd5-977e-668acb44af79", + "id": "34", "metadata": {}, "outputs": [], "source": [ @@ -819,7 +819,7 @@ }, { "cell_type": "markdown", - "id": "57a2765d-df97-4a48-a593-323cfff07116", + "id": "35", "metadata": {}, "source": [ "Take the index of the best solution within the last population:" @@ -828,7 +828,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2a6a6376-f296-498d-8ff0-d039d6c67099", + "id": "36", "metadata": {}, "outputs": [], "source": [ @@ -838,7 +838,7 @@ }, { "cell_type": "markdown", - "id": "76691a35-b997-49d5-9585-0ecf850608fc", + "id": "37", "metadata": {}, "source": [ "Take the best solution within the last population:" @@ -847,7 +847,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b2f773b5-3b46-4046-828a-7ca60fd4be5a", + "id": "38", "metadata": {}, "outputs": [], "source": [ @@ -857,7 +857,7 @@ }, { "cell_type": "markdown", - "id": "f63d773f-3059-462f-808a-fddf4ae721dc", + "id": "39", "metadata": {}, "source": [ "Visualize the gait of the population's best solution:" @@ -866,7 +866,7 @@ { "cell_type": "code", "execution_count": null, - "id": "fe415cc7-b08b-4ffa-9548-78c3e890bfff", + "id": "40", "metadata": {}, "outputs": [], "source": [