"""JAX-native differentiable magnetic field kernels."""
from __future__ import annotations
import jax
import jax.numpy as jnp
from jax import lax
from magpylib_jax._types import ArrayLike
from magpylib_jax.constants import MU0
from magpylib_jax.core.elliptic import cel, ellipe, ellipk, ellippi
from magpylib_jax.core.geometry import cart_to_cyl, cyl_field_to_cart, ensure_observers
# 4*pi: the divisor used by the dipole and current-loop prefactors in this module.
_FOUR_PI = 4.0 * jnp.pi
def _broadcast_vector(vector: jnp.ndarray, target_shape: tuple[int, ...]) -> jnp.ndarray:
if vector.ndim == 1:
return jnp.broadcast_to(vector[None, :], target_shape)
return jnp.broadcast_to(vector, target_shape)
def _cel_iter(
qc: jnp.ndarray,
p: jnp.ndarray,
g: jnp.ndarray,
cc: jnp.ndarray,
ss: jnp.ndarray,
em: jnp.ndarray,
kk: jnp.ndarray,
) -> jnp.ndarray:
"""Vectorized Bulirsch CEL iteration in JAX."""
def body_fn(_: int, state: tuple[jnp.ndarray, ...]) -> tuple[jnp.ndarray, ...]:
qc_, p_, g_, cc_, ss_, em_, kk_ = state
mask = jnp.abs(g_ - qc_) >= qc_ * 1e-8
qc_new = 2.0 * jnp.sqrt(kk_)
kk_new = qc_new * em_
f = cc_
cc_new = cc_ + ss_ / p_
g_new = kk_new / p_
ss_new = 2.0 * (ss_ + f * g_new)
p_new = p_ + g_new
g_store = em_
em_new = em_ + qc_new
qc_out = jnp.where(mask, qc_new, qc_)
p_out = jnp.where(mask, p_new, p_)
g_out = jnp.where(mask, g_store, g_)
cc_out = jnp.where(mask, cc_new, cc_)
ss_out = jnp.where(mask, ss_new, ss_)
em_out = jnp.where(mask, em_new, em_)
kk_out = jnp.where(mask, kk_new, kk_)
return qc_out, p_out, g_out, cc_out, ss_out, em_out, kk_out
qc, p, _, cc, ss, em, _ = lax.fori_loop(0, 32, body_fn, (qc, p, g, cc, ss, em, kk))
return 0.5 * jnp.pi * (ss + cc * em) / (em * (em + p))
[docs]
@jax.jit
def dipole_hfield(observers: ArrayLike, moments: ArrayLike) -> jnp.ndarray:
    """H-field of dipole moments located at the origin.

    ``moments`` is broadcast to the shape of the observer array. An observer
    sitting exactly on the dipole gets, per component, ``0`` where the moment
    component is zero and ``sign(moment) * inf`` otherwise.
    """
    # NOTE(review): the row indexing below assumes ensure_observers returns a
    # 2-D (n, 3) array -- confirm against the helper.
    positions = ensure_observers(observers)
    dipoles = _broadcast_vector(jnp.asarray(moments, dtype=jnp.float64), positions.shape)

    dist_sq = jnp.sum(positions * positions, axis=-1)
    off_origin = dist_sq > 0.0
    inv_cube = jnp.where(off_origin, dist_sq ** (-1.5), jnp.inf)
    inv_quint = jnp.where(off_origin, dist_sq ** (-2.5), jnp.inf)
    projection = jnp.sum(dipoles * positions, axis=-1)

    radial_part = 3.0 * projection[:, None] * positions * inv_quint[:, None]
    moment_part = dipoles * inv_cube[:, None]
    field = (radial_part - moment_part) / _FOUR_PI

    # The closed form is indeterminate at r = 0; substitute the signed limit.
    at_origin = dist_sq == 0.0
    diverged = jnp.where(dipoles == 0.0, 0.0, jnp.sign(dipoles) * jnp.inf)
    return jnp.where(at_origin[:, None], diverged, field)
[docs]
@jax.jit
def current_circle_hfield(
    observers: ArrayLike,
    diameter: ArrayLike,
    current: ArrayLike,
    *,
    singular_tol: float = 1e-15,
) -> jnp.ndarray:
    """H-field of circular current loops centered at the origin in the xy plane.

    Coordinates are divided by the loop radius before the Bulirsch iteration
    is evaluated. The field is set to zero for loops of zero radius and for
    observers lying on the wire itself, i.e. ``z == 0`` and
    ``|r - radius| < singular_tol * radius``.
    """
    positions = ensure_observers(observers)
    rho, phi, height = cart_to_cyl(positions)

    loop_radius = jnp.broadcast_to(
        jnp.abs(jnp.asarray(diameter, dtype=jnp.float64) / 2.0), rho.shape
    )
    loop_current = jnp.broadcast_to(jnp.asarray(current, dtype=jnp.float64), rho.shape)

    degenerate = loop_radius == 0.0
    on_wire = (jnp.abs(rho - loop_radius) < singular_tol * loop_radius) & (height == 0.0)
    regular = ~degenerate & ~on_wire

    # Use a unit scale on excluded elements so the arithmetic below stays finite.
    scale = jnp.where(regular, loop_radius, 1.0)
    rho_n = rho / scale
    z_n = height / scale
    z_sq = z_n * z_n
    denom = z_sq + (rho_n + 1.0) ** 2
    k_sq = 4.0 * rho_n / denom
    q_sq = jnp.where(regular, (z_sq + (rho_n - 1.0) ** 2) / denom, 1.0)
    q = jnp.sqrt(q_sq)
    q_plus = 1.0 + q
    prefactor = loop_current / (_FOUR_PI * scale * jnp.sqrt(denom) * q_sq)
    unit = jnp.ones_like(q)

    def _bulirsch(c_coef: jnp.ndarray, s_coef: jnp.ndarray) -> jnp.ndarray:
        # Both field components share every iteration argument except cc/ss.
        return _cel_iter(q, q_plus, unit, c_coef, s_coef, q_plus, q)

    # Radial component.
    radial_c = k_sq * 4.0 * z_n / denom
    radial_s = 2.0 * radial_c * q / q_plus
    h_rho = prefactor * _bulirsch(radial_c, radial_s)

    # Axial component.
    k_quad = k_sq * k_sq
    four_over_denom = 4.0 / denom
    axial_c = k_quad - (q_sq + 1.0) * four_over_denom
    axial_s = 2.0 * q * (k_quad / q_plus - four_over_denom * q_plus)
    h_z = -prefactor * _bulirsch(axial_c, axial_s)

    h_rho = jnp.where(regular, h_rho, 0.0)
    h_z = jnp.where(regular, h_z, 0.0)
    # NOTE(review): cyl_field_to_cart is called with three arguments here but
    # with four in the cylinder kernels -- confirm the helper's signature.
    return cyl_field_to_cart(phi, h_rho, h_z)
[docs]
@jax.jit
def magnet_cuboid_bfield(
    observers: ArrayLike,
    dimensions: ArrayLike,
    polarizations: ArrayLike,
) -> jnp.ndarray:
    """B-field of homogeneously polarized cuboids centered at the origin.

    ``dimensions`` and ``polarizations`` are broadcast to the shape of the
    observer array. The closed-form expression is evaluated without any
    special-case masking; the result is a stack of the three Cartesian
    components along the last axis.
    """
    # NOTE(review): the ``.T`` unpacking and ``obs.shape[0]`` below assume
    # ensure_observers returns a 2-D (n, 3) array -- confirm against the helper.
    obs = ensure_observers(observers)
    dim = _broadcast_vector(jnp.asarray(dimensions, dtype=jnp.float64), obs.shape)
    pol = _broadcast_vector(jnp.asarray(polarizations, dtype=jnp.float64), obs.shape)
    pol_x, pol_y, pol_z = pol.T
    # a, b, c are the half side lengths.
    a, b, c = (dim / 2.0).T
    x, y, z = obs.T
    # Fold every observer into the octant x >= 0, y <= 0, z <= 0 and remember
    # which coordinates were mirrored.
    maskx = x < 0.0
    masky = y > 0.0
    maskz = z > 0.0
    x = jnp.where(maskx, -x, x)
    y = jnp.where(masky, -y, y)
    z = jnp.where(maskz, -z, z)
    # Per-observer 3x3 sign table, indexed [polarization axis, field axis];
    # each mirrored coordinate multiplies in its own flip pattern.
    qsigns = jnp.ones((obs.shape[0], 3, 3), dtype=jnp.float64)
    qs_flipx = jnp.array([[1, -1, -1], [-1, 1, 1], [-1, 1, 1]], dtype=jnp.float64)
    qs_flipy = jnp.array([[1, -1, 1], [-1, 1, -1], [1, -1, 1]], dtype=jnp.float64)
    qs_flipz = jnp.array([[1, 1, -1], [1, 1, -1], [-1, -1, 1]], dtype=jnp.float64)
    qsigns = qsigns * jnp.where(maskx[:, None, None], qs_flipx, 1.0)
    qsigns = qsigns * jnp.where(masky[:, None, None], qs_flipy, 1.0)
    qsigns = qsigns * jnp.where(maskz[:, None, None], qs_flipz, 1.0)
    # Offsets of the folded observer from the +/- faces along each axis.
    xma, xpa = x - a, x + a
    ymb, ypb = y - b, y + b
    zmc, zpc = z - c, z + c
    xma2, xpa2 = xma * xma, xpa * xpa
    ymb2, ypb2 = ymb * ymb, ypb * ypb
    zmc2, zpc2 = zmc * zmc, zpc * zpc
    # Euclidean norms of the eight offset combinations; each letter of the
    # name gives the sign (m = minus, p = plus) used for x, y, z in turn.
    mmm = jnp.sqrt(xma2 + ymb2 + zmc2)
    pmp = jnp.sqrt(xpa2 + ymb2 + zpc2)
    pmm = jnp.sqrt(xpa2 + ymb2 + zmc2)
    mmp = jnp.sqrt(xma2 + ymb2 + zpc2)
    mpm = jnp.sqrt(xma2 + ypb2 + zmc2)
    ppp = jnp.sqrt(xpa2 + ypb2 + zpc2)
    ppm = jnp.sqrt(xpa2 + ypb2 + zmc2)
    mpp = jnp.sqrt(xma2 + ypb2 + zpc2)
    # Logarithmic terms: these feed the off-diagonal (cross) contributions.
    ff2x = jnp.log((xma + mmm) * (xpa + ppm) * (xpa + pmp) * (xma + mpp))
    ff2x = ff2x - jnp.log((xpa + pmm) * (xma + mpm) * (xma + mmp) * (xpa + ppp))
    ff2y = jnp.log((-ymb + mmm) * (-ypb + ppm) * (-ymb + pmp) * (-ypb + mpp))
    ff2y = ff2y - jnp.log((-ymb + pmm) * (-ypb + mpm) * (ymb - mmp) * (ypb - ppp))
    ff2z = jnp.log((-zmc + mmm) * (-zmc + ppm) * (-zpc + pmp) * (-zpc + mpp))
    ff2z = ff2z - jnp.log((-zmc + pmm) * (zmc - mpm) * (-zpc + mmp) * (zpc - ppp))
    # Arctangent terms: these feed the diagonal contributions (field
    # component parallel to the polarization component).
    ff1x = (
        jnp.arctan2(ymb * zmc, xma * mmm)
        - jnp.arctan2(ymb * zmc, xpa * pmm)
        - jnp.arctan2(ypb * zmc, xma * mpm)
        + jnp.arctan2(ypb * zmc, xpa * ppm)
        - jnp.arctan2(ymb * zpc, xma * mmp)
        + jnp.arctan2(ymb * zpc, xpa * pmp)
        + jnp.arctan2(ypb * zpc, xma * mpp)
        - jnp.arctan2(ypb * zpc, xpa * ppp)
    )
    ff1y = (
        jnp.arctan2(xma * zmc, ymb * mmm)
        - jnp.arctan2(xpa * zmc, ymb * pmm)
        - jnp.arctan2(xma * zmc, ypb * mpm)
        + jnp.arctan2(xpa * zmc, ypb * ppm)
        - jnp.arctan2(xma * zpc, ymb * mmp)
        + jnp.arctan2(xpa * zpc, ymb * pmp)
        + jnp.arctan2(xma * zpc, ypb * mpp)
        - jnp.arctan2(xpa * zpc, ypb * ppp)
    )
    ff1z = (
        jnp.arctan2(xma * ymb, zmc * mmm)
        - jnp.arctan2(xpa * ymb, zmc * pmm)
        - jnp.arctan2(xma * ypb, zmc * mpm)
        + jnp.arctan2(xpa * ypb, zmc * ppm)
        - jnp.arctan2(xma * ymb, zpc * mmp)
        + jnp.arctan2(xpa * ymb, zpc * pmp)
        + jnp.arctan2(xma * ypb, zpc * mpp)
        - jnp.arctan2(xpa * ypb, zpc * ppp)
    )
    # Contributions of the x-polarization component.
    bx_pol_x = pol_x * ff1x * qsigns[:, 0, 0]
    by_pol_x = pol_x * ff2z * qsigns[:, 0, 1]
    bz_pol_x = pol_x * ff2y * qsigns[:, 0, 2]
    # Contributions of the y-polarization component.
    bx_pol_y = pol_y * ff2z * qsigns[:, 1, 0]
    by_pol_y = pol_y * ff1y * qsigns[:, 1, 1]
    bz_pol_y = -pol_y * ff2x * qsigns[:, 1, 2]
    # Contributions of the z-polarization component.
    bx_pol_z = pol_z * ff2y * qsigns[:, 2, 0]
    by_pol_z = -pol_z * ff2x * qsigns[:, 2, 1]
    bz_pol_z = pol_z * ff1z * qsigns[:, 2, 2]
    # Sum the three polarization contributions per field component.
    bx_tot = bx_pol_x + bx_pol_y + bx_pol_z
    by_tot = by_pol_x + by_pol_y + by_pol_z
    bz_tot = bz_pol_x + bz_pol_y + bz_pol_z
    return jnp.stack((bx_tot, by_tot, bz_tot), axis=-1) / (4.0 * jnp.pi)
@jax.jit
def _cuboid_masks(
observers: jnp.ndarray,
dimensions: jnp.ndarray,
polarizations: jnp.ndarray,
rtol_surface: float = 1e-15,
) -> tuple[jnp.ndarray, jnp.ndarray]:
x, y, z = observers.T
a, b, c = jnp.abs(dimensions.T) / 2.0
pol_x, pol_y, pol_z = polarizations.T
mask_pol_not_null = ~((pol_x == 0.0) & (pol_y == 0.0) & (pol_z == 0.0))
mask_dim_not_null = (a * b * c) != 0.0
x_dist = jnp.abs(x) - a
y_dist = jnp.abs(y) - b
z_dist = jnp.abs(z) - c
mask_surf_x = jnp.abs(x_dist) < rtol_surface * a
mask_surf_y = jnp.abs(y_dist) < rtol_surface * b
mask_surf_z = jnp.abs(z_dist) < rtol_surface * c
mask_inside_x = x_dist < rtol_surface * a
mask_inside_y = y_dist < rtol_surface * b
mask_inside_z = z_dist < rtol_surface * c
mask_inside = mask_inside_x & mask_inside_y & mask_inside_z
mask_xedge = mask_surf_y & mask_surf_z & mask_inside_x
mask_yedge = mask_surf_x & mask_surf_z & mask_inside_y
mask_zedge = mask_surf_x & mask_surf_y & mask_inside_z
mask_not_edge = ~(mask_xedge | mask_yedge | mask_zedge)
mask_gen = mask_pol_not_null & mask_dim_not_null & mask_not_edge
return mask_inside, mask_gen
[docs]
@jax.jit
def magnet_cuboid_jfield(
    observers: ArrayLike,
    dimensions: ArrayLike,
    polarizations: ArrayLike,
) -> jnp.ndarray:
    """J-field for homogeneously polarized cuboids.

    Returns the polarization vector for observers inside the cuboid and zero
    for all others.
    """
    positions = ensure_observers(observers)
    shape = positions.shape
    sizes = _broadcast_vector(jnp.asarray(dimensions, dtype=jnp.float64), shape)
    pol = _broadcast_vector(jnp.asarray(polarizations, dtype=jnp.float64), shape)
    # Only the inside mask is needed here; the general-case mask is discarded.
    inside = _cuboid_masks(positions, sizes, pol)[0]
    return jnp.where(inside[:, None], pol, 0.0)
[docs]
@jax.jit
def magnet_cuboid_mfield(
    observers: ArrayLike,
    dimensions: ArrayLike,
    polarizations: ArrayLike,
) -> jnp.ndarray:
    """M-field for homogeneously polarized cuboids.

    Computed as the cuboid J-field divided by ``MU0``.
    """
    j_field = magnet_cuboid_jfield(observers, dimensions, polarizations)
    return j_field / MU0
[docs]
@jax.jit
def magnet_cuboid_hfield(
    observers: ArrayLike,
    dimensions: ArrayLike,
    polarizations: ArrayLike,
) -> jnp.ndarray:
    """H-field for homogeneously polarized cuboids.

    Takes the cuboid B-field (zeroed outside the general case), subtracts the
    polarization for observers inside the cuboid, and divides by ``MU0``.
    """
    positions = ensure_observers(observers)
    shape = positions.shape
    sizes = _broadcast_vector(jnp.asarray(dimensions, dtype=jnp.float64), shape)
    pol = _broadcast_vector(jnp.asarray(polarizations, dtype=jnp.float64), shape)

    inside, regular = _cuboid_masks(positions, sizes, pol)
    b_raw = magnet_cuboid_bfield(positions, sizes, pol)
    # Special cases (null polarization, null size, edges) report zero B.
    b_field = jnp.where(regular[:, None], b_raw, 0.0)
    j_field = jnp.where(inside[:, None], pol, 0.0)
    return (b_field - j_field) / MU0
[docs]
@jax.jit
def magnet_cylinder_axial_bfield(z0: jnp.ndarray, r: jnp.ndarray, z: jnp.ndarray) -> jnp.ndarray:
    """B-field in cylindrical coordinates for axially polarized cylinders.

    The expressions use ``1 + r`` and ``1 - r``, i.e. a unit cylinder radius,
    so callers pass lengths already divided by the radius. The result stacks
    ``(B_r, 0, B_z)`` along the last axis.
    """
    z_top = z + z0
    z_bot = z - z0
    r_plus = 1.0 + r
    r_minus = 1.0 - r

    top_den = z_top * z_top + r_plus * r_plus
    bot_den = z_bot * z_bot + r_plus * r_plus
    norm_bot = jnp.sqrt(bot_den)
    norm_top = jnp.sqrt(top_den)

    # Moduli handed to the cel helper for the two end faces.
    kc_top = jnp.sqrt((z_top * z_top + r_minus * r_minus) / top_den)
    kc_bot = jnp.sqrt((z_bot * z_bot + r_minus * r_minus) / bot_den)
    gamma = r_minus / r_plus
    unit = jnp.ones_like(z0)

    radial = (
        cel(kc_top, unit, unit, -unit) / norm_top
        - cel(kc_bot, unit, unit, -unit) / norm_bot
    ) / jnp.pi

    gamma_sq = gamma * gamma
    axial = (z_top * cel(kc_top, gamma_sq, unit, gamma) / norm_top) - (
        z_bot * cel(kc_bot, gamma_sq, unit, gamma) / norm_bot
    )
    axial = axial / (r_plus * jnp.pi)

    return jnp.stack((radial, jnp.zeros_like(radial), axial), axis=-1)
[docs]
@jax.jit
def magnet_cylinder_diametral_hfield(
    z0: jnp.ndarray,
    r: jnp.ndarray,
    z: jnp.ndarray,
    phi: jnp.ndarray,
) -> jnp.ndarray:
    """H-field in cylindrical coordinates for diametral polarization.

    The expressions use ``r + 1`` and ``r - 1``, i.e. a unit cylinder radius,
    so callers pass lengths already divided by the radius. Two formulas are
    evaluated for every element and selected afterwards with ``jnp.where``:
    a power series in ``r`` for ``r < 0.05`` and an elliptic-integral
    expression otherwise. The result stacks ``(H_r, H_phi, H_z)`` along the
    last axis.
    """
    # Axial distances to the two end faces.
    zp = z + z0
    zm = z - z0
    zp2 = zp * zp
    zm2 = zm * zm
    r2 = r * r
    # Elements that take the small-r series instead of the general formula.
    mask_small_r = r < 0.05
    # ---- small-r branch: series in powers of r ----
    zpp = zp2 + 1.0
    zmm = zm2 + 1.0
    sqrt_p = jnp.sqrt(zpp)
    sqrt_m = jnp.sqrt(zmm)
    frac1 = zp / sqrt_p
    frac2 = zm / sqrt_m
    r3 = r2 * r
    r4 = r3 * r
    r5 = r4 * r
    # Orders r^0, r^2 and r^4 shared by the radial and azimuthal components.
    term1 = frac1 - frac2
    term2 = (frac1 / zpp**2 - frac2 / zmm**2) * r2 / 8.0
    term3 = (3.0 - 4.0 * zp2) * frac1 / zpp**4 - (3.0 - 4.0 * zm2) * frac2 / zmm**4
    term3 = term3 * r4 / 64.0
    hr_small = -jnp.cos(phi) / 4.0 * (term1 + 9.0 * term2 + 25.0 * term3)
    hphi_small = jnp.sin(phi) / 4.0 * (term1 + 3.0 * term2 + 5.0 * term3)
    # Axial component: orders r, r^3 and r^5.
    hz_small = r * (1.0 / zpp / sqrt_p - 1.0 / zmm / sqrt_m)
    hz_small = hz_small + (3.0 / 8.0) * r3 * (
        (1.0 - 4.0 * zp2) / zpp**3 / sqrt_p - (1.0 - 4.0 * zm2) / zmm**3 / sqrt_m
    )
    hz_small = hz_small + (15.0 / 64.0) * r5 * (
        (1.0 - 12.0 * zp2 + 8.0 * zp2 * zp2) / zpp**5 / sqrt_p
        - (1.0 - 12.0 * zm2 + 8.0 * zm2 * zm2) / zmm**5 / sqrt_m
    )
    hz_small = -jnp.cos(phi) / 4.0 * hz_small
    # ---- general branch: elliptic-integral expressions ----
    rp = r + 1.0
    rm = r - 1.0
    rp2 = rp * rp
    rm2 = rm * rm
    ap2 = zp2 + rm2
    am2 = zm2 + rm2
    ap = jnp.sqrt(ap2)
    am = jnp.sqrt(am2)
    argp = -4.0 * r / ap2
    argm = -4.0 * r / am2
    # r == 1 makes rm zero: use a large finite stand-in for the ellippi
    # characteristic and zero for 1/rm instead of dividing by zero.
    mask_special = rm == 0.0
    argc = jnp.where(mask_special, 1e16, -4.0 * r / rm2)
    one_over_rm = jnp.where(mask_special, 0.0, 1.0 / rm)
    elle_p = ellipe(argp)
    elle_m = ellipe(argm)
    ellk_p = ellipk(argp)
    ellk_m = ellipk(argm)
    ellpi_p = ellippi(argc, argp)
    ellpi_m = ellippi(argc, argm)
    # Avoid dividing by r == 0 below; those elements are taken from the
    # small-r branch anyway.
    safe_r = jnp.where(r == 0.0, 1.0, r)
    safe_r2 = safe_r * safe_r
    hr_general = (
        -jnp.cos(phi)
        / (4.0 * jnp.pi * safe_r2)
        * (
            -zm * am * elle_m
            + zp * ap * elle_p
            + zm / am * (2.0 + zm2) * ellk_m
            - zp / ap * (2.0 + zp2) * ellk_p
            + (zm / am * ellpi_m - zp / ap * ellpi_p) * rp * (r2 + 1.0) * one_over_rm
        )
    )
    hphi_general = (
        jnp.sin(phi)
        / (4.0 * jnp.pi * safe_r2)
        * (
            +zm * am * elle_m
            - zp * ap * elle_p
            - zm / am * (2.0 + zm2 + 2.0 * r2) * ellk_m
            + zp / ap * (2.0 + zp2 + 2.0 * r2) * ellk_p
            + zm / am * rp2 * ellpi_m
            - zp / ap * rp2 * ellpi_p
        )
    )
    hz_general = (
        -jnp.cos(phi)
        / (2.0 * jnp.pi * safe_r)
        * (
            +am * elle_m
            - ap * elle_p
            - (1.0 + zm2 + r2) / am * ellk_m
            + (1.0 + zp2 + r2) / ap * ellk_p
        )
    )
    # Pick the series result near the axis and the general result elsewhere.
    hr = jnp.where(mask_small_r, hr_small, hr_general)
    hphi = jnp.where(mask_small_r, hphi_small, hphi_general)
    hz = jnp.where(mask_small_r, hz_small, hz_general)
    return jnp.stack((hr, hphi, hz), axis=-1)
@jax.jit
def _cylinder_masks(
r: jnp.ndarray,
z: jnp.ndarray,
z0: jnp.ndarray,
r0: jnp.ndarray,
polarization: jnp.ndarray,
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
pol_x, pol_y, pol_z = polarization.T
mask_dim_not_null = (r0 != 0.0) & (z0 != 0.0)
mask_between_bases = jnp.abs(z) <= z0
mask_inside_hull = r <= 1.0
mask_inside = mask_between_bases & mask_inside_hull & mask_dim_not_null
mask_on_hull = jnp.isclose(r, 1.0, rtol=1e-15, atol=0.0)
mask_on_bases = jnp.isclose(jnp.abs(z), z0, rtol=1e-15, atol=0.0)
mask_not_on_edge = ~(mask_on_hull & mask_on_bases)
mask_pol_not_null = ~((pol_x == 0.0) & (pol_y == 0.0) & (pol_z == 0.0))
mask_gen = mask_pol_not_null & mask_not_on_edge & mask_dim_not_null
mask_pol_tv = ((pol_x != 0.0) | (pol_y != 0.0)) & mask_gen
mask_pol_ax = (pol_z != 0.0) & mask_gen
mask_inside_gen = mask_inside & mask_gen
return mask_pol_tv, mask_pol_ax, mask_inside_gen, mask_dim_not_null
[docs]
@jax.jit
def magnet_cylinder_jfield(
    observers: ArrayLike,
    dimensions: ArrayLike,
    polarizations: ArrayLike,
) -> jnp.ndarray:
    """J-field for homogeneously polarized cylinders.

    Returns the polarization vector for every observer inside the cylinder
    (hull, bases and rim included) and zero elsewhere. Cylinders with a zero
    diameter or zero height have no interior and yield zero everywhere.

    ``dimensions`` holds two values per cylinder; after halving they are used
    as the radius and the half height, in that order.
    """
    obs = ensure_observers(observers)
    dim = _broadcast_vector(jnp.asarray(dimensions, dtype=jnp.float64), (obs.shape[0], 2))
    pol = _broadcast_vector(jnp.asarray(polarizations, dtype=jnp.float64), obs.shape)
    r, _, z = cart_to_cyl(obs)
    r0, z0 = (dim / 2.0).T

    has_volume = (r0 != 0.0) & (z0 != 0.0)
    # Scale lengths by the radius; a unit scale keeps degenerate rows finite.
    safe_r0 = jnp.where(r0 == 0.0, 1.0, r0)
    rs = r / safe_r0
    zs = z / safe_r0
    z0s = z0 / safe_r0

    # Use the plain inside test. The "general case" mask excludes the rim and
    # null polarization because the B/H formulas are singular there, but the
    # polarization itself is well defined at those points, matching the
    # behavior of the cuboid J-field.
    mask_inside = (jnp.abs(zs) <= z0s) & (rs <= 1.0) & has_volume
    return jnp.where(mask_inside[:, None], pol, 0.0)
[docs]
@jax.jit
def magnet_cylinder_mfield(
    observers: ArrayLike,
    dimensions: ArrayLike,
    polarizations: ArrayLike,
) -> jnp.ndarray:
    """M-field for homogeneously polarized cylinders.

    Computed as the cylinder J-field divided by ``MU0``.
    """
    j_field = magnet_cylinder_jfield(observers, dimensions, polarizations)
    return j_field / MU0
[docs]
@jax.jit
def magnet_cylinder_bfield(
    observers: ArrayLike,
    dimensions: ArrayLike,
    polarizations: ArrayLike,
) -> jnp.ndarray:
    """B-field of homogeneously polarized cylinders centered at the origin.

    The transverse polarization part is evaluated with the diametral H-field
    kernel and the axial part with the axial B-field kernel. After conversion
    to Cartesian components, the transverse polarization is added for
    observers inside the cylinder.
    """
    positions = ensure_observers(observers)
    n_obs = positions.shape[0]
    sizes = _broadcast_vector(jnp.asarray(dimensions, dtype=jnp.float64), (n_obs, 2))
    pol = _broadcast_vector(jnp.asarray(polarizations, dtype=jnp.float64), positions.shape)

    rho, phi, height = cart_to_cyl(positions)
    r0, z0 = (sizes / 2.0).T
    # Scale lengths by the radius; zero radii use a unit scale instead.
    scale = jnp.where(r0 == 0.0, 1.0, r0)
    rho_s = rho / scale
    height_s = height / scale
    z0_s = z0 / scale

    transverse, axial, inside, _ = _cylinder_masks(rho_s, height_s, z0_s, r0, pol)

    pol_x, pol_y, pol_z = pol.T
    pol_xy = jnp.sqrt(pol_x * pol_x + pol_y * pol_y)
    # Azimuth of the transverse polarization; the kernel is called with the
    # observer angle measured relative to it.
    theta = jnp.arctan2(pol_y, pol_x)

    tv_part = (
        magnet_cylinder_diametral_hfield(z0_s, rho_s, height_s, phi - theta)
        * pol_xy[:, None]
    )
    tv_part = jnp.where(transverse[:, None], tv_part, 0.0)

    ax_part = magnet_cylinder_axial_bfield(z0_s, rho_s, height_s) * pol_z[:, None]
    ax_part = jnp.where(axial[:, None], ax_part, 0.0)

    cyl_total = tv_part + ax_part
    cart = cyl_field_to_cart(phi, cyl_total[:, 0], cyl_total[:, 1], cyl_total[:, 2])

    # The transverse kernel yields H; add the polarization inside to get B.
    add_pol = transverse & inside
    bx = cart[:, 0] + jnp.where(add_pol, pol_x, 0.0)
    by = cart[:, 1] + jnp.where(add_pol, pol_y, 0.0)
    bz = cart[:, 2]
    return jnp.stack((bx, by, bz), axis=-1)
[docs]
@jax.jit
def magnet_cylinder_hfield(
    observers: ArrayLike,
    dimensions: ArrayLike,
    polarizations: ArrayLike,
) -> jnp.ndarray:
    """H-field of homogeneously polarized cylinders centered at the origin.

    The transverse polarization part is evaluated with the diametral H-field
    kernel and the axial part with the axial B-field kernel. After conversion
    to Cartesian components, the axial polarization is subtracted for
    observers inside the cylinder and the result is divided by ``MU0``.
    """
    positions = ensure_observers(observers)
    n_obs = positions.shape[0]
    sizes = _broadcast_vector(jnp.asarray(dimensions, dtype=jnp.float64), (n_obs, 2))
    pol = _broadcast_vector(jnp.asarray(polarizations, dtype=jnp.float64), positions.shape)

    rho, phi, height = cart_to_cyl(positions)
    r0, z0 = (sizes / 2.0).T
    # Scale lengths by the radius; zero radii use a unit scale instead.
    scale = jnp.where(r0 == 0.0, 1.0, r0)
    rho_s = rho / scale
    height_s = height / scale
    z0_s = z0 / scale

    transverse, axial, inside, _ = _cylinder_masks(rho_s, height_s, z0_s, r0, pol)

    pol_x, pol_y, pol_z = pol.T
    pol_xy = jnp.sqrt(pol_x * pol_x + pol_y * pol_y)
    # Azimuth of the transverse polarization; the kernel is called with the
    # observer angle measured relative to it.
    theta = jnp.arctan2(pol_y, pol_x)

    tv_part = (
        magnet_cylinder_diametral_hfield(z0_s, rho_s, height_s, phi - theta)
        * pol_xy[:, None]
    )
    tv_part = jnp.where(transverse[:, None], tv_part, 0.0)

    ax_part = magnet_cylinder_axial_bfield(z0_s, rho_s, height_s) * pol_z[:, None]
    ax_part = jnp.where(axial[:, None], ax_part, 0.0)

    cyl_total = tv_part + ax_part
    cart = cyl_field_to_cart(phi, cyl_total[:, 0], cyl_total[:, 1], cyl_total[:, 2])

    # The axial kernel yields B; remove the polarization inside to get mu0*H.
    drop_pol = axial & inside
    hz = cart[:, 2] - jnp.where(drop_pol, pol_z, 0.0)
    return jnp.stack((cart[:, 0], cart[:, 1], hz), axis=-1) / MU0
[docs]
def dipole_bfield(observers: ArrayLike, moments: ArrayLike) -> jnp.ndarray:
    """B-field of a dipole (Tesla).

    Computed as ``MU0`` times the dipole H-field.
    """
    h_field = dipole_hfield(observers, moments)
    return jnp.asarray(MU0 * h_field, dtype=jnp.float64)
[docs]
def current_circle_bfield(
    observers: ArrayLike,
    diameter: ArrayLike,
    current: ArrayLike,
) -> jnp.ndarray:
    """B-field of a current circle (Tesla).

    Computed as ``MU0`` times the current-circle H-field, evaluated with the
    default singularity tolerance.
    """
    h_field = current_circle_hfield(observers, diameter, current)
    return jnp.asarray(MU0 * h_field, dtype=jnp.float64)