Optimization Examples

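Fitting cuboid polarization

This example fits the polarization vector of a single cuboid to target B-field values at two fixed observer positions, using plain gradient descent on a mean-squared-error loss.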

import jax
import jax.numpy as jnp
import magpylib_jax as mpj

obs = jnp.array([[0.2, 0.1, 0.4], [0.5, 0.0, 0.7]])
target = jnp.array([[2.0e-4, 0.0, 3.0e-4], [1.0e-4, 0.0, 2.0e-4]])


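def loss_fn(pol):
    # Mean squared error between the predicted and target B-field at the observers.
    src = mpj.magnet.Cuboid(dimension=(1.0, 0.8, 1.2), polarization=pol)
    pred = src.getB(obs)
    return jnp.mean((pred - target) ** 2)


# Build the gradient function once instead of on every iteration.
grad_fn = jax.grad(loss_fn)

pol = jnp.array([0.05, -0.02, 0.08])  # initial guess
for _ in range(50):
    pol = pol - 1e-1 * grad_fn(pol)  # gradient-descent step, learning rate 0.1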
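For a fixed geometry, the field of a uniformly magnetized body is linear in its polarization, so a fit like the one above can be cross-checked against a direct linear least-squares solve. The sketch below illustrates the idea with JAX alone. The `dipole_field` function is a hypothetical stand-in for the cuboid model (a point dipole with the constant prefactor omitted), and the target is synthetic, generated from a known `true_moment`; none of these names come from `magpylib_jax`.

```python
import jax
import jax.numpy as jnp

obs = jnp.array([[0.2, 0.1, 0.4], [0.5, 0.0, 0.7]])


def dipole_field(moment, points):
    # Hypothetical stand-in model: point dipole at the origin, prefactor omitted.
    r = jnp.linalg.norm(points, axis=-1, keepdims=True)
    rhat = points / r
    along = jnp.sum(rhat * moment, axis=-1, keepdims=True)
    return (3.0 * rhat * along - moment) / r**3


# Synthetic target generated from a known moment, so the answer can be checked.
true_moment = jnp.array([0.05, -0.02, 0.08])
target = dipole_field(true_moment, obs)


def residual(moment):
    # Flattened difference between prediction and target, shape (6,).
    return (dipole_field(moment, obs) - target).ravel()


# The model is linear in `moment`, so the Jacobian is the same everywhere.
zero = jnp.zeros(3)
jac = jax.jacobian(residual)(zero)  # shape (6, 3)

# Solve jac @ moment = -residual(0) in the least-squares sense.
fit, *_ = jnp.linalg.lstsq(jac, -residual(zero))
print(fit)
```

A gradient-descent result that disagrees noticeably with such a direct solve usually points to a step size or iteration count that needs tuning rather than to a problem with the model.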

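Fitting multiple source parameters

This example packs the polarization and position of two cuboids into one flat 12-element parameter vector and fits all of them jointly against four observer readings.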

import jax
import jax.numpy as jnp
import magpylib_jax as mpj

obs = jnp.array(
    [
        [0.2, 0.1, 0.4],
        [0.5, 0.0, 0.7],
        [-0.1, 0.3, 0.2],
        [0.3, -0.2, 0.6],
    ]
)
target = jnp.array(
    [
        [2.0e-4, 0.0, 3.0e-4],
        [1.0e-4, 0.0, 2.0e-4],
        [1.5e-4, 0.5e-4, 2.2e-4],
        [0.8e-4, -0.2e-4, 1.7e-4],
    ]
)


def loss_fn(params):
    # Unpack the flat vector: polarization and position for each cuboid.
    pol1 = params[0:3]
    pos1 = params[3:6]
    pol2 = params[6:9]
    pos2 = params[9:12]

    src1 = mpj.magnet.Cuboid(dimension=(1.0, 0.8, 1.2), polarization=pol1, position=pos1)
    src2 = mpj.magnet.Cuboid(dimension=(0.6, 0.6, 0.6), polarization=pol2, position=pos2)
    # The collection returns the combined field of both sources at the observers.
    pred = mpj.Collection(src1, src2).getB(obs)
    return jnp.mean((pred - target) ** 2)


# Initial guess, ordered as [pol1, pos1, pol2, pos2].
params = jnp.array([0.05, -0.02, 0.08, 0.0, 0.0, 0.0, 0.03, 0.01, 0.04, 0.2, 0.1, -0.1])
grad_fn = jax.grad(loss_fn)  # build once, reuse in the loop
for _ in range(80):
    params = params - 5e-2 * grad_fn(params)
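Slicing a flat vector by index gets error-prone as the number of sources grows. One alternative is to keep the parameters in a pytree (here a list of dicts) and let `jax.grad` return gradients with the same structure. The sketch below shows the pattern with JAX alone. The `dipole_field` function is a hypothetical stand-in for the cuboid model, and the observers, the synthetic target, and the step size are illustrative assumptions rather than values from the example above.

```python
import jax
import jax.numpy as jnp

# Illustrative observers, kept about one length unit away from the sources.
obs = jnp.array(
    [[0.2, 0.1, 1.0], [0.5, 0.0, 1.2], [-0.1, 0.3, 1.1], [0.3, -0.2, 1.3]]
)


def dipole_field(moment, position, points):
    # Hypothetical stand-in model: point dipole at `position`, prefactor omitted.
    d = points - position
    r = jnp.linalg.norm(d, axis=-1, keepdims=True)
    rhat = d / r
    along = jnp.sum(rhat * moment, axis=-1, keepdims=True)
    return (3.0 * rhat * along - moment) / r**3


def model(params):
    # Superpose the fields of all sources in the pytree.
    return sum(dipole_field(s["moment"], s["position"], obs) for s in params)


# Synthetic target from known parameters.
true_params = [
    {"moment": jnp.array([0.08, -0.03, 0.10]), "position": jnp.array([0.05, 0.0, 0.0])},
    {"moment": jnp.array([0.02, 0.03, 0.06]), "position": jnp.array([0.25, 0.1, -0.1])},
]
target = model(true_params)


def loss_fn(params):
    return jnp.mean((model(params) - target) ** 2)


init_params = [
    {"moment": jnp.array([0.05, -0.02, 0.08]), "position": jnp.array([0.0, 0.0, 0.0])},
    {"moment": jnp.array([0.03, 0.01, 0.04]), "position": jnp.array([0.2, 0.1, -0.1])},
]

grad_fn = jax.grad(loss_fn)  # gradients mirror the structure of `params`
learning_rate = 1e-3  # illustrative value, chosen conservatively
params = init_params
for _ in range(200):
    grads = grad_fn(params)
    # Apply the update leaf by leaf across the whole pytree.
    params = jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, grads)

print(loss_fn(init_params), loss_fn(params))
```

Named fields make it easier to freeze a subset of parameters, for example by leaving the positions out of the pytree that is differentiated.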

Practical note

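For large optimization loops:

  • prefer 64-bit precision (x64); JAX computes in 32-bit by default, so enable it with jax.config.update("jax_enable_x64", True) before creating any arrays,

  • keep observer layouts (array shapes and dtypes) static when possible, because a change forces jax.jit to retrace and recompile,

  • reuse object graphs if the optimization variables can be isolated cleanly,

  • profile compile time and steady-state runtime separately.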
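The points on precision and profiling can be sketched as follows, using JAX alone. The `dipole_field` model is a hypothetical stand-in for a real source, and the iteration count is arbitrary. The sketch enables x64 before any array is created, compiles the gradient once with `jax.jit`, and times the first call (tracing plus compilation plus execution) separately from later calls.

```python
import time

import jax

# Must run before any array is created, otherwise values stay 32-bit.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

obs = jnp.array([[0.2, 0.1, 0.4], [0.5, 0.0, 0.7]])
target = jnp.array([[2.0e-4, 0.0, 3.0e-4], [1.0e-4, 0.0, 2.0e-4]])


def dipole_field(moment, points):
    # Hypothetical stand-in model: point dipole at the origin, prefactor omitted.
    r = jnp.linalg.norm(points, axis=-1, keepdims=True)
    rhat = points / r
    along = jnp.sum(rhat * moment, axis=-1, keepdims=True)
    return (3.0 * rhat * along - moment) / r**3


def loss_fn(moment):
    return jnp.mean((dipole_field(moment, obs) - target) ** 2)


grad_fn = jax.jit(jax.grad(loss_fn))
moment = jnp.array([0.05, -0.02, 0.08])

# First call: includes tracing and compilation.
t0 = time.perf_counter()
grad_fn(moment).block_until_ready()
first_call = time.perf_counter() - t0

# Later calls reuse the compiled function as long as shapes and dtypes match.
n_calls = 100
t0 = time.perf_counter()
for _ in range(n_calls):
    g = grad_fn(moment)
g.block_until_ready()  # wait for asynchronous dispatch to finish
steady = (time.perf_counter() - t0) / n_calls

print(f"first call: {first_call:.4f} s, steady state: {steady:.6f} s per call")
```

Calling `block_until_ready()` matters here because JAX dispatches work asynchronously; without it the timer can stop before the computation has finished.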