Example for evolutionary regression with parametrized nodes

Example demonstrating the use of Cartesian genetic programming for a regression task that requires fine-tuning of constants in parametrized nodes. This is achieved by introducing a new node, “ParametrizedAdd”, which produces a scaled and shifted version of the sum of its inputs.

# The docopt str is added explicitly to ensure compatibility with
# sphinx-gallery.
docopt_str = """
  Usage:
    example_parametrized_nodes.py [--max-generations=<N>]

  Options:
    -h --help
    --max-generations=<N>  Maximum number of generations [default: 500]
"""

import functools
import math

import matplotlib.pyplot as plt
import numpy as np
import scipy.constants
import torch
from docopt import docopt

import cgp

args = docopt(docopt_str)

We first define a new node that adds the values of its two inputs, then scales and finally shifts the result. The scale (“w”) and shift (“b”) factors are parameters that are adapted by local search. We need to define the arity of the node, callables returning the initial values of the parameters, and the operation of the node as a string. In this string, parameters are enclosed in angle brackets and inputs are denoted by “x_i”, with i representing the corresponding input index.

class ParametrizedAdd(cgp.OperatorNode):
    """A node that adds its two inputs.

    The result of the addition is scaled by w and shifted by b. Both of
    these parameters can be adapted via local search and are passed on
    from parents to their offspring.

    """

    _arity = 2
    _initial_values = {"<w>": lambda: 1.0, "<b>": lambda: 0.0}
    _def_output = "<w> * (x_0 + x_1) + <b>"
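
As a brief aside (not part of the original example), the following plain-Python snippet spells out what the output string means: with the default parameter values w = 1.0 and b = 0.0, the node reduces to an ordinary addition of its two inputs.

# Illustration only: the output string "<w> * (x_0 + x_1) + <b>" evaluated
# by hand with the default parameter values.
w, b = 1.0, 0.0        # defaults returned by the callables above
x_0, x_1 = 2.0, 3.0    # example inputs
assert w * (x_0 + x_1) + b == 5.0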

We define a target function which contains numerical constants that are not available as constants for the search and hence need to be found by local search on the parametrized nodes.

def f_target(x):
    return math.pi * (x[:, 0] + x[:, 1]) + math.e
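
As a quick illustration (not part of the original example, reusing the numpy import from the top of the script), f_target expects a two-column array of input points and returns one target value per row; the constants to be recovered are pi and e.

# Illustration only: evaluate the target function on two example points.
x_example = np.array([[1.0, 2.0], [0.0, 0.0]])
print(f_target(x_example))  # approximately [12.143  2.718]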

Then we define a differentiable(!) inner objective function for the evolution. This function accepts a torch module (the individual compiled via its to_torch method) as a parameter. It returns the mean-squared error between the output of the forward pass of this module and the target function evaluated on a set of random points. This inner objective is then used by the actual objective function to determine the fitness of the individual.

def inner_objective(f, seed):
    torch.manual_seed(seed)
    batch_size = 500
    x = torch.DoubleTensor(batch_size, 2).uniform_(-5, 5)
    y = f(x)
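    # the torch module returns an output of shape (batch_size, n_outputs);
    # compare its single output column against the target values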
    return torch.nn.MSELoss()(f_target(x), y[:, 0])


def objective(individual, seed):
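    # skip evaluation if the fitness of this individual was already computed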
    if not individual.fitness_is_None():
        return individual

    f = individual.to_torch()
    loss = inner_objective(f, seed)
    individual.fitness = -loss.item()

    return individual
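
To make the role of differentiability concrete, here is a minimal stand-alone sketch (independent of cgp and of the local search implementation used below, reusing the math and torch imports from the top of the script) of the kind of gradient step that the local search applies to the node parameters w and b.

# Illustration only: a single gradient-descent step on w and b for the fixed
# candidate expression w * (x_0 + x_1) + b, using torch autograd.
w = torch.tensor(1.0, dtype=torch.double, requires_grad=True)
b = torch.tensor(0.0, dtype=torch.double, requires_grad=True)

x = torch.DoubleTensor(500, 2).uniform_(-5, 5)
y_pred = w * (x[:, 0] + x[:, 1]) + b
loss = torch.nn.MSELoss()(f_target(x), y_pred)
loss.backward()

with torch.no_grad():
    w -= 1e-3 * w.grad  # learning rate chosen to match local_search_params["lr"] below
    b -= 1e-3 * b.grad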

Next, we define the parameters for the genome of individuals, the evolutionary algorithm, and the local search. Note that we add the custom node defined above as a primitive.

seed = 1234

genome_params = {
    "n_inputs": 2,
    "n_columns": 5,
    "primitives": (ParametrizedAdd, cgp.Add, cgp.Sub, cgp.Mul),
}

evolve_params = {"max_generations": int(args["--max-generations"]), "termination_fitness": 0.0}

local_search_params = {"lr": 1e-3, "gradient_steps": 9}

We then create a Population instance and instantiate the local search and evolutionary algorithm.

pop = cgp.Population(genome_params=genome_params)

local_search = functools.partial(
    cgp.local_search.gradient_based,
    objective=functools.partial(inner_objective, seed=seed),
    **local_search_params,
)

ea = cgp.ea.MuPlusLambda(local_search=local_search)

We define a callback closure that records the progression of the evolution for later inspection.

history = {}
history["fitness_champion"] = []
history["expr_champion"] = []


def recording_callback(pop):
    history["fitness_champion"].append(pop.champion.fitness)
    history["expr_champion"].append(pop.champion.to_sympy())

We fix the seed for the objective function to make sure results are comparable across individuals and, finally, we call the evolve method to perform the evolutionary search.

obj = functools.partial(objective, seed=seed)

pop = cgp.evolve(obj, pop, ea, **evolve_params, print_progress=True, callback=recording_callback)

Out:

[2/500] max fitness: -29.296308252293986
[3/500] max fitness: -23.70581555633807
[4/500] max fitness: -22.373383329751597
[5/500] max fitness: -21.929178550932026
[6/500] max fitness: -21.674922439308162
[7/500] max fitness: -21.465755988900582
[8/500] max fitness: -21.271468469608145
[9/500] max fitness: -21.085605715652054
[10/500] max fitness: -20.90664166714249
[11/500] max fitness: -20.73407781232532
[12/500] max fitness: -20.567634950821294
[13/500] max fitness: -20.40708553905934
[14/500] max fitness: -20.252218740296733
[15/500] max fitness: -20.102832972172592
[16/500] max fitness: -18.751269599871726
[17/500] max fitness: -12.063681252671143
[18/500] max fitness: -8.35619033157912
[19/500] max fitness: -6.273925762349376
[20/500] max fitness: -5.078967613213775
[21/500] max fitness: -4.369379077028278
[22/500] max fitness: -3.926196365875168
[23/500] max fitness: -3.6301021293612643
[24/500] max fitness: -3.4160472387282828
[25/500] max fitness: -3.2485385547455725
[26/500] max fitness: -3.1081810126350837
[27/500] max fitness: -2.984348603716845
[28/500] max fitness: -2.8711930442554285
[29/500] max fitness: -2.7654700901548672
[30/500] max fitness: -2.6653557109016166
[31/500] max fitness: -2.569801325334066
[32/500] max fitness: -2.478182603177361
[33/500] max fitness: -2.3901081396213333
[34/500] max fitness: -2.3053151962600475
[35/500] max fitness: -2.2236128591296844
[36/500] max fitness: -2.14485102150343
[37/500] max fitness: -2.0689034325716227
[38/500] max fitness: -1.9956584082645896
[39/500] max fitness: -1.9250137167829566
[40/500] max fitness: -1.8568737395701236
[41/500] max fitness: -1.791147873352144
[42/500] max fitness: -1.7277496098762388
[43/500] max fitness: -1.666595986482696
[44/500] max fitness: -1.6076072403343167
[45/500] max fitness: -1.550706575202033
[46/500] max fitness: -1.4958199911362748
[47/500] max fitness: -1.4428761499181828
[48/500] max fitness: -1.3918062614750406
[49/500] max fitness: -1.3425439831392456
[50/500] max fitness: -1.2950253272779682
[51/500] max fitness: -1.2491885748088931
[52/500] max fitness: -1.204974193202052
[53/500] max fitness: -1.1623247581599876
[54/500] max fitness: -1.1211848784925864
[55/500] max fitness: -1.0815011238811154
[56/500] max fitness: -1.0432219553243258
[57/500] max fitness: -1.0062976581146288
[58/500] max fitness: -0.9706802772237402
[59/500] max fitness: -0.9363235549955334
[60/500] max fitness: -0.903182871055257
[61/500] max fitness: -0.8712151843516398
[62/500] max fitness: -0.8403789772536476
[63/500] max fitness: -0.8106342016276896
[64/500] max fitness: -0.781942226824364
[65/500] max fitness: -0.7542657895067096
[66/500] max fitness: -0.7275689452545611
[67/500] max fitness: -0.7018170218819759
[68/500] max fitness: -0.6769765744070783
[69/500] max fitness: -0.6530153416157455
[70/500] max fitness: -0.6299022041627423
[71/500] max fitness: -0.607607144155846
[72/500] max fitness: -0.5861012061704887
[73/500] max fitness: -0.5653564596442547
[74/500] max fitness: -0.5453459626024281
[75/500] max fitness: -0.5260437266674328
[76/500] max fitness: -0.5074246833067669
[77/500] max fitness: -0.4894646512755637
[78/500] max fitness: -0.47214030521150735
[79/500] max fitness: -0.45542914534131707
[80/500] max fitness: -0.4393094682594517
[81/500] max fitness: -0.42376033874109076
[82/500] max fitness: -0.4087615625527799
[83/500] max fitness: -0.39429366022542356
[84/500] max fitness: -0.38033784175557755
[85/500] max fitness: -0.3668759822021694
[86/500] max fitness: -0.35389059814696366
[87/500] max fitness: -0.34136482498819537
[88/500] max fitness: -0.3292823950378824
[89/500] max fitness: -0.31762761639437664
[90/500] max fitness: -0.3063853525627044
[91/500] max fitness: -0.295541002796238
[92/500] max fitness: -0.2850804831341608
[93/500] max fitness: -0.2749902081101045
[94/500] max fitness: -0.26525707310819835
[95/500] max fitness: -0.25586843734361525
[96/500] max fitness: -0.2468121074455151
[97/500] max fitness: -0.23807632162106007
[98/500] max fitness: -0.2296497343799347
[99/500] max fitness: -0.2215214017995363
[100/500] max fitness: -0.21368076731169583
[101/500] max fitness: -0.20611764799246948
[102/500] max fitness: -0.19882222133719454
[103/500] max fitness: -0.191785012503639
[104/500] max fitness: -0.18499688200666992
[105/500] max fitness: -0.17844901384846282
[106/500] max fitness: -0.17213290406884243
[107/500] max fitness: -0.16604034970086487
[108/500] max fitness: -0.1601634381173252
[109/500] max fitness: -0.1544945367543311
[110/500] max fitness: -0.149026283198609
[111/500] max fitness: -0.14375157562566274
[112/500] max fitness: -0.1386635635763712
[113/500] max fitness: -0.13375563906004143
[114/500] max fitness: -0.12902142797236407
[115/500] max fitness: -0.12445478181712724
[116/500] max fitness: -0.12004976972093682
[117/500] max fitness: -0.11580067073056924
[118/500] max fitness: -0.11170196638295596
[119/500] max fitness: -0.10774833353815111
[120/500] max fitness: -0.10393463746597206
[121/500] max fitness: -0.10025592517733185
[122/500] max fitness: -0.09670741899160841
[123/500] max fitness: -0.09328451033169581
[124/500] max fitness: -0.08998275373866967
[125/500] max fitness: -0.08679786109830687
[126/500] max fitness: -0.08372569607194964
[127/500] max fitness: -0.08076226872449109
[128/500] max fitness: -0.0779037303424953
[129/500] max fitness: -0.07514636843573262
[130/500] max fitness: -0.0724866019156276
[131/500] max fitness: -0.06992097644436795
[132/500] max fitness: -0.06744615994862688
[133/500] max fitness: -0.06505893829207501
[134/500] max fitness: -0.06275621110106212
[135/500] max fitness: -0.042233661573409534
[136/500] max fitness: -0.028438783004635728
[137/500] max fitness: -0.01915161947762114
[138/500] max fitness: -0.012897972328620583
[139/500] max fitness: -0.008686600682183142
[140/500] max fitness: -0.005850396533949138
[141/500] max fitness: -0.00394025643361544
[142/500] max fitness: -0.0026537838829984875
[143/500] max fitness: -0.0017873417504784329
[144/500] max fitness: -0.0012037890110839311
[145/500] max fitness: -0.000810763251125896
[146/500] max fitness: -0.0005460584431755792
[147/500] max fitness: -0.00036777846011057635
[148/500] max fitness: -0.0002477059376820576
[149/500] max fitness: -0.00016683628844257838
[150/500] max fitness: -0.00011236983413594339
[151/500] max fitness: -7.568596440373674e-05
[152/500] max fitness: -5.0978700115355825e-05
[153/500] max fitness: -3.433772571639737e-05
[154/500] max fitness: -2.31294547452853e-05
[155/500] max fitness: -1.5580173680538555e-05
[156/500] max fitness: -1.0495288703189141e-05
[157/500] max fitness: -7.070239165567859e-06
[158/500] max fitness: -4.763149007448779e-06
[159/500] max fitness: -3.2090586318525766e-06
[160/500] max fitness: -2.1621609904425323e-06
[161/500] max fitness: -1.456898238819804e-06
[162/500] max fitness: -9.81760750602782e-07
[163/500] max fitness: -6.616409867954462e-07
[164/500] max fitness: -4.459489292411943e-07
[165/500] max fitness: -3.006078407183301e-07
[166/500] max fitness: -2.0266337221567036e-07
[167/500] max fitness: -1.3665273091677727e-07
[168/500] max fitness: -9.215923685739751e-08
[169/500] max fitness: -6.216523695407847e-08
[170/500] max fitness: -4.19427169481387e-08
[171/500] max fitness: -2.8306060291748006e-08
[172/500] max fitness: -1.910872169120322e-08
[173/500] max fitness: -1.2904186004523732e-08
[174/500] max fitness: -8.717581627333424e-09
[175/500] max fitness: -5.891827126370599e-09
[176/500] max fitness: -3.983982478385584e-09
[177/500] max fitness: -2.6954185970864777e-09
[178/500] max fitness: -1.8247680425765078e-09
[179/500] max fitness: -1.2362226148000097e-09
[180/500] max fitness: -8.381701235979647e-10
[181/500] max fitness: -5.68797055206648e-10
[182/500] max fitness: -3.863847365059541e-10
[183/500] max fitness: -2.6276813635810336e-10
[184/500] max fitness: -1.7892593959647881e-10
[185/500] max fitness: -1.220070102788369e-10
[186/500] max fitness: -8.332502815236628e-11

After the evolutionary search has ended, we print the expression with the highest fitness and plot the progression of the search.

print(f"Final expression {pop.champion.to_sympy()} with fitness {pop.champion.fitness}")

print("Best performing expression per generation (for fitness increase > 0.5):")
old_fitness = -np.inf
for i, (fitness, expr) in enumerate(zip(history["fitness_champion"], history["expr_champion"])):
    delta_fitness = fitness - old_fitness
    if delta_fitness > 0.5:
        print(f"{i:3d}: {fitness}, {expr}")
        old_fitness = fitness
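# print the champion of the final generation regardless of the fitness threshold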
print(f"{i:3d}: {fitness}, {expr}")

width = 9.0

fig = plt.figure(figsize=(width, width / scipy.constants.golden))

ax_fitness = fig.add_subplot(111)
ax_fitness.set_xlabel("Generation")
ax_fitness.set_ylabel("Fitness")
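# fitness values are negative and span many orders of magnitude, so use a symmetric log scale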
ax_fitness.set_yscale("symlog")

ax_fitness.axhline(0.0, color="k")
ax_fitness.plot(history["fitness_champion"], lw=2)

plt.savefig("example_parametrized_nodes.pdf", dpi=300)
[Figure: progression of the champion's fitness over generations]

Out:

Final expression 3.141591581530465*x_0 + 3.1415928389502914*x_1 + 2.718273160795456 with fitness -8.332502815236628e-11
Best performing expression per generation (for fitness increase > 0.5):
  0: -57.02725920814872, 2.0*x_0 + 1.0*x_1
  1: -29.296308252293986, 1.480962760718721*x_0 + 2.961925521437442*x_1 + 0.04710656595408981
  2: -23.70581555633807, 1.6998291931463108*x_0 + 3.3996583862926216*x_1 + 0.09298198289560022
  3: -22.373383329751597, 1.7994065154975194*x_0 + 3.5988130309950388*x_1 + 0.13786076395943922
  5: -21.674922439308162, 1.8652668897025921*x_0 + 3.7305337794051842*x_1 + 0.22503177123698184
  8: -21.085605715652054, 1.880691620715323*x_0 + 3.761383241430646*x_1 + 0.3499164450476753
 11: -20.567634950821294, 1.8819819923768122*x_0 + 3.7639639847536244*x_1 + 0.46820991723300653
 15: -18.751269599871726, 2.2120079552914897*x_0 + 2.2120079552914897*x_1 + 0.6184430991872557
 16: -12.063681252671143, 2.455578327740484*x_0 + 2.455578327740484*x_1 + 0.6559303097239474
 17: -8.35619033157912, 2.6353233187093275*x_0 + 2.6353233187093275*x_1 + 0.6927504598420281
 18: -6.273925762349376, 2.767967871678684*x_0 + 2.767967871678684*x_1 + 0.7289148488062044
 19: -5.078967613213775, 2.865854287549723*x_0 + 2.865854287549723*x_1 + 0.7644347345395887
 20: -4.369379077028278, 2.9380906623934124*x_0 + 2.9380906623934124*x_1 + 0.7993212923236193
 22: -3.6301021293612643, 3.0307375595447175*x_0 + 3.0307375595447175*x_1 + 0.8672385433370949
 25: -3.1081810126350837, 3.0970027480565787*x_0 + 3.0970027480565787*x_1 + 0.964636439195826
 30: -2.569801325334066, 3.131786035064894*x_0 + 3.131786035064894*x_1 + 1.1157138304969136
 36: -2.0689034325716227, 3.1399624915439723*x_0 + 3.1399624915439723*x_1 + 1.2799278409193942
 44: -1.550706575202033, 3.141405328696915*x_0 + 3.141405328696915*x_1 + 1.4730082526554913
 55: -1.0432219553243258, 3.1415476398125395*x_0 + 3.1415476398125395*x_1 + 1.696899442652304
 74: -0.5260437266674328, 3.1415641478877396*x_0 + 3.1415641478877396*x_1 + 1.9929931007668715
136: -0.01915161947762114, 3.1412926038084588*x_0 + 3.1427315735574535*x_1 + 2.5798427188837497
185: -8.332502815236628e-11, 3.141591581530465*x_0 + 3.1415928389502914*x_1 + 2.718273160795456

Total running time of the script: ( 0 minutes 19.194 seconds)

Gallery generated by Sphinx-Gallery