"""Extended kernels for additional source families."""
from __future__ import annotations
import jax
import jax.numpy as jnp
from magpylib_jax._types import ArrayLike
from magpylib_jax.constants import MU0
from magpylib_jax.core.geometry import ensure_observers
# Shared 4*pi denominator used by the field kernels below.
_FOUR_PI = 4.0 * jnp.pi

# Vertex-index triples for the four faces of a tetrahedron; used to split a
# chirality-corrected tetrahedron into its triangular surface faces.
_TETRA_FACES = jnp.array(
    ((0, 2, 1), (0, 1, 3), (1, 2, 3), (0, 3, 2)),
    dtype=jnp.int32,
)

# Accepted ``in_out`` mode strings and the integer flags they map to. The
# flags are passed to kernels as static arguments or ``lax.switch`` indices.
_IN_OUT_FLAGS = dict(auto=0, inside=1, outside=2)

# Module-level caches of ``jax.jit`` wrappers, keyed by kernel name plus the
# observer/face/segment counts and flags they were requested for. Nothing in
# this module evicts entries.
_JIT_KERNEL_CACHE: dict[tuple[str, int, int], object] = dict()
_JIT_SIMPLE_CACHE: dict[tuple[str, int], object] = dict()
_JIT_MESH_CACHE: dict[tuple[str, int, int, int], object] = dict()
_JIT_SEGMENT_CACHE: dict[tuple[str, int, int], object] = dict()
def _broadcast_vec3(arr: jnp.ndarray, n: int) -> jnp.ndarray:
    """Broadcast a single 3-vector, or a stack of them, to shape ``(n, 3)``.

    A 1-D input first gets a leading axis; any other input goes straight to
    ``jnp.broadcast_to``, which raises if the shape is incompatible.
    """
    expanded = arr[None, :] if arr.ndim == 1 else arr
    return jnp.broadcast_to(expanded, (n, 3))
def _safe_norm(v: jnp.ndarray, axis: int = -1, keepdims: bool = False) -> jnp.ndarray:
    """Euclidean norm of *v* along *axis*, never smaller than 1e-15.

    The squared norm is clamped to at least 1e-30 before the square root, so the
    result is finite and non-zero even for zero-length vectors.
    """
    squared = jnp.sum(v * v, axis=axis, keepdims=keepdims)
    clamped = jnp.maximum(squared, 1e-30)
    return jnp.sqrt(clamped)
[docs]
def magnet_sphere_bfield(
    observers: ArrayLike,
    diameters: ArrayLike,
    polarizations: ArrayLike,
) -> jnp.ndarray:
    """B-field of homogeneously polarized spheres centered at the origin.

    Args:
        observers: observer positions, normalised by ``ensure_observers``.
        diameters: a scalar or one diameter per observer; the sign is ignored.
        polarizations: one polarization vector or one per observer.

    Returns:
        ``(n, 3)`` array. For observers strictly outside the sphere this is the
        dipole field ``R^3/3 * (3 (J.r) r - J r^2) / r^5``; elsewhere it is ``2/3 * J``.
    """
    obs = ensure_observers(observers)
    n_obs = obs.shape[0]
    # ravel + broadcast handles scalar and 1-D diameters alike.
    dia = jnp.broadcast_to(jnp.ravel(jnp.asarray(diameters, dtype=jnp.float64)), (n_obs,))
    pol = _broadcast_vec3(jnp.asarray(polarizations, dtype=jnp.float64), n_obs)

    radius = jnp.abs(dia) / 2.0
    dist = _safe_norm(obs, axis=1)
    is_outside = (dist > radius)[:, None]

    j_dot_r = jnp.sum(pol * obs, axis=1)[:, None]
    dipole = (
        (3.0 * j_dot_r * obs - pol * (dist * dist)[:, None])
        * (radius**3 / 3.0)[:, None]
        / (dist**5)[:, None]
    )
    return jnp.where(is_outside, dipole, (2.0 / 3.0) * pol)
def magnet_sphere_hfield(
    observers: ArrayLike,
    diameters: ArrayLike,
    polarizations: ArrayLike,
) -> jnp.ndarray:
    """H-field of homogeneously polarized spheres centered at the origin.

    Computed as ``(B - J_in) / MU0``. ``J_in`` is the polarization wherever the
    observer is not strictly outside the sphere, and zero elsewhere.
    """
    obs = ensure_observers(observers)
    n_obs = obs.shape[0]
    dia = jnp.broadcast_to(jnp.ravel(jnp.asarray(diameters, dtype=jnp.float64)), (n_obs,))
    pol = _broadcast_vec3(jnp.asarray(polarizations, dtype=jnp.float64), n_obs)

    radius = jnp.abs(dia) / 2.0
    # ``not (r > R)`` (rather than ``r <= R``) mirrors the B-field's outside test.
    not_outside = jnp.logical_not(_safe_norm(obs, axis=1) > radius)

    b_field = magnet_sphere_bfield(obs, dia, pol)
    polarization_inside = jnp.where(not_outside[:, None], pol, 0.0)
    return (b_field - polarization_inside) / MU0
def magnet_sphere_jfield(
    observers: ArrayLike,
    diameters: ArrayLike,
    polarizations: ArrayLike,
) -> jnp.ndarray:
    """Polarization field of spheres: ``J`` where ``|r| <= R``, zero elsewhere."""
    obs = ensure_observers(observers)
    n_obs = obs.shape[0]
    dia = jnp.broadcast_to(jnp.ravel(jnp.asarray(diameters, dtype=jnp.float64)), (n_obs,))
    pol = _broadcast_vec3(jnp.asarray(polarizations, dtype=jnp.float64), n_obs)

    within = _safe_norm(obs, axis=1) <= jnp.abs(dia) / 2.0
    return jnp.where(within[:, None], pol, 0.0)
def magnet_sphere_mfield(
    observers: ArrayLike,
    diameters: ArrayLike,
    polarizations: ArrayLike,
) -> jnp.ndarray:
    """Magnetization of spheres: the polarization field divided by ``MU0``."""
    polarization = magnet_sphere_jfield(observers, diameters, polarizations)
    return polarization / MU0
def _current_segment_hfield(
    observers: jnp.ndarray,
    segment_start: jnp.ndarray,
    segment_end: jnp.ndarray,
    current: jnp.ndarray,
) -> jnp.ndarray:
    """H-field of a single straight current segment at every observer.

    All lengths are made dimensionless by the segment length before the
    closed-form finite-segment Biot-Savart expression is evaluated. The
    following contribute zero field:

    * observers on the segment's line (distance < 1e-15 in scaled units),
    * zero-length segments,
    * segments with non-finite end points.

    Args:
        observers: observer positions, normalised by ``ensure_observers``.
        segment_start: start point, a 3-vector or one per observer.
        segment_end: end point, same convention as ``segment_start``.
        current: scalar current or one value per observer.

    Returns:
        ``(n_obs, 3)`` H-field.
    """
    obs = ensure_observers(observers)
    n_obs = obs.shape[0]
    p1 = _broadcast_vec3(segment_start, n_obs)
    p2 = _broadcast_vec3(segment_end, n_obs)
    cur = jnp.asarray(current, dtype=jnp.float64)
    if cur.ndim == 0:
        cur = jnp.broadcast_to(cur, (n_obs,))
    else:
        cur = jnp.broadcast_to(cur.reshape((-1,)), (n_obs,))

    seg = p1 - p2
    norm12 = _safe_norm(seg, axis=1)
    # Degeneracy tests use raw squared lengths. ``_safe_norm`` floors its result
    # at 1e-15, so comparing its output against 1e-15 cannot tell a vanishing
    # length from a tiny one.
    valid_seg = jnp.sum(seg * seg, axis=1) > 1e-30

    # Dimensionless coordinates: the segment has unit length.
    p1s = p1 / norm12[:, None]
    p2s = p2 / norm12[:, None]
    pos = obs / norm12[:, None]

    # p4 = orthogonal projection of the observer onto the segment's line.
    t = jnp.sum((pos - p1s) * (p1s - p2s), axis=1)
    p4 = p1s + t[:, None] * (p1s - p2s)
    o4 = pos - p4
    norm_o4 = _safe_norm(o4, axis=1)
    # On-line observers get zero field. The floored ``norm_o4`` is always
    # >= 1e-15, so testing it would mark every observer as off-line. Rounding
    # noise in ``o4`` for on-line observers would then yield very large
    # spurious fields.
    off_line = jnp.sum(o4 * o4, axis=1) >= 1e-30

    # Field direction: unit vector along (p2 - p1) x (obs - p4).
    cros = jnp.cross(p2s - p1s, o4)
    eB = cros / _safe_norm(cros, axis=1)[:, None]

    # Sines of the angles seen from the observer towards the two end points.
    norm_o1 = _safe_norm(pos - p1s, axis=1)
    norm_o2 = _safe_norm(pos - p2s, axis=1)
    norm_41 = _safe_norm(p4 - p1s, axis=1)
    norm_42 = _safe_norm(p4 - p2s, axis=1)
    sin1 = norm_41 / norm_o1
    sin2 = norm_42 / norm_o2
    # If p4 lies beyond one end of the segment, the sines subtract; otherwise
    # (p4 between the end points) they add.
    p4_beyond_p2 = (norm_41 > 1.0) & (norm_41 > norm_42)
    p4_beyond_p1 = (norm_42 > 1.0) & (norm_42 > norm_41)
    delta = jnp.where(p4_beyond_p2, jnp.abs(sin1 - sin2), jnp.abs(sin1 + sin2))
    delta = jnp.where(p4_beyond_p1, jnp.abs(sin2 - sin1), delta)

    h = (delta / norm_o4)[:, None] * eB / norm12[:, None] * cur[:, None] / _FOUR_PI
    valid = (
        valid_seg
        & off_line
        & jnp.all(jnp.isfinite(p1), axis=1)
        & jnp.all(jnp.isfinite(p2), axis=1)
    )
    return jnp.where(valid[:, None], h, 0.0)
[docs]
def current_polyline_hfield(
    observers: ArrayLike,
    segments_start: ArrayLike,
    segments_end: ArrayLike,
    currents: ArrayLike,
) -> jnp.ndarray:
    """H-field of straight current segments, summed over all segments.

    Args:
        observers: observer positions, normalised by ``ensure_observers``.
        segments_start: segment start points ``(n_seg, 3)``, or one ``(3,)`` point.
        segments_end: segment end points, same shape as ``segments_start``.
        currents: scalar or one current per segment.

    Raises:
        ValueError: if start/end shapes differ or the last axis is not 3.
    """
    obs = ensure_observers(observers)
    starts = jnp.asarray(segments_start, dtype=jnp.float64)
    ends = jnp.asarray(segments_end, dtype=jnp.float64)
    # A lone segment given as two 3-vectors needs no summation.
    if starts.ndim == 1:
        return _current_segment_hfield(obs, starts, ends, currents)

    shapes_ok = ends.shape == starts.shape and starts.shape[-1] == 3
    if not shapes_ok:
        raise ValueError("Polyline segments must have shape (n,3).")

    n_seg = starts.shape[0]
    seg_currents = jnp.broadcast_to(
        jnp.ravel(jnp.asarray(currents, dtype=jnp.float64)), (n_seg,)
    )

    def one_segment(start, end, current):
        return _current_segment_hfield(obs, start, end, current)

    per_segment = jax.vmap(one_segment)(starts, ends, seg_currents)
    return jnp.sum(per_segment, axis=0)
def current_polyline_bfield(
    observers: ArrayLike,
    segments_start: ArrayLike,
    segments_end: ArrayLike,
    currents: ArrayLike,
) -> jnp.ndarray:
    """B-field of straight current segments: ``MU0`` times their summed H-field."""
    h_field = current_polyline_hfield(observers, segments_start, segments_end, currents)
    return MU0 * h_field
[docs]
def current_polyline_bfield_masked(
    observers: ArrayLike,
    segments_start: ArrayLike,
    segments_end: ArrayLike,
    currents: ArrayLike,
    segment_mask: ArrayLike,
) -> jnp.ndarray:
    """B-field of current segments with segment masking.

    Each segment's H-field is multiplied by its ``segment_mask`` entry before the
    contributions are summed and converted to B with ``MU0``. A mask value of 0
    drops the segment, 1 keeps it, and other values scale it.

    Args:
        observers: observer positions, normalised by ``ensure_observers``.
        segments_start: segment start points ``(n_seg, 3)``; a single segment
            may also be given as a ``(3,)`` point.
        segments_end: segment end points, same shape as ``segments_start``.
        currents: scalar or one current per segment.
        segment_mask: one weight per segment (cast to float).

    Raises:
        ValueError: if start/end shapes differ or the last axis is not 3.
    """
    obs = ensure_observers(observers)
    # Promote a single (3,) segment to (1, 3) so that vmap iterates over
    # segments rather than over the three coordinates of one segment.
    p1 = jnp.atleast_2d(jnp.asarray(segments_start, dtype=jnp.float64))
    p2 = jnp.atleast_2d(jnp.asarray(segments_end, dtype=jnp.float64))
    if p2.shape != p1.shape or p1.shape[-1] != 3:
        raise ValueError("Polyline segments must have shape (n,3).")
    cur = jnp.asarray(currents, dtype=jnp.float64)
    if cur.ndim == 0:
        cur = jnp.broadcast_to(cur, (p1.shape[0],))
    else:
        cur = jnp.broadcast_to(cur.reshape((-1,)), (p1.shape[0],))
    mask = jnp.asarray(segment_mask, dtype=jnp.float64).reshape((-1,))
    h_segments = jax.vmap(lambda a, b, c: _current_segment_hfield(obs, a, b, c))(p1, p2, cur)
    h_segments = h_segments * mask[:, None, None]
    return MU0 * jnp.sum(h_segments, axis=0)
def _current_polyline_bfield_segments_impl(
    observers: jnp.ndarray,
    segments_start: jnp.ndarray,
    segments_end: jnp.ndarray,
    currents: jnp.ndarray,
    *,
    n_segments: int,
) -> jnp.ndarray:
    """Jittable wrapper around ``current_polyline_bfield``.

    ``n_segments`` takes no part in the computation. It is accepted so the JIT
    wrapper built by ``_jit_kernel_segments`` can receive it as a static argument.
    """
    del n_segments  # part of the static signature only
    return current_polyline_bfield(observers, segments_start, segments_end, currents)
[docs]
def current_polyline_bfield_jit(
    observers: ArrayLike,
    segments_start: ArrayLike,
    segments_end: ArrayLike,
    currents: ArrayLike,
) -> jnp.ndarray:
    """JIT-specialized polyline B-field for fixed observer + segment counts.

    The jitted kernel is cached per (observer count, segment count).
    """
    obs = ensure_observers(observers)
    starts = jnp.asarray(segments_start, dtype=jnp.float64)
    ends = jnp.asarray(segments_end, dtype=jnp.float64)
    # A lone (3,) segment counts as one; otherwise the leading axis is the count.
    n_segments = 1 if starts.ndim == 1 else int(starts.shape[0])
    kernel = _jit_kernel_segments(
        "polyline_bfield", _current_polyline_bfield_segments_impl, obs.shape[0], n_segments
    )
    cur = jnp.asarray(currents, dtype=jnp.float64)
    return kernel(obs, starts, ends, cur, n_segments=n_segments)
[docs]
def current_circle_bfield_jit(
    observers: ArrayLike,
    diameter: ArrayLike,
    current: ArrayLike,
) -> jnp.ndarray:
    """JIT-specialized circle B-field for fixed observer counts."""
    # Function-scope import kept from the original; presumably avoids an
    # import cycle with the core kernels module (TODO confirm).
    from magpylib_jax.core.kernels import current_circle_bfield

    obs = ensure_observers(observers)
    kernel = _jit_kernel_simple("circle_bfield", current_circle_bfield, obs.shape[0])
    return kernel(
        obs,
        jnp.asarray(diameter, dtype=jnp.float64),
        jnp.asarray(current, dtype=jnp.float64),
    )
def _triangle_norm_vector(vertices: jnp.ndarray) -> jnp.ndarray:
    """Unit normals ``(n, 3)`` of triangles given as ``(n, 3, 3)`` vertex stacks.

    The normal is ``(v1 - v0) x (v2 - v0)`` divided by its (floored) length.
    """
    v0, v1, v2 = vertices[:, 0], vertices[:, 1], vertices[:, 2]
    normal = jnp.cross(v1 - v0, v2 - v0)
    return normal / _safe_norm(normal, axis=1)[:, None]
def _triangle_geom_terms(
    tri: jnp.ndarray,
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Precompute triangle normals and edge terms for reuse.

    Accepts any leading batch shape ``(..., 3, 3)``.

    Returns:
        A tuple ``(nvec, L, l1, l2)``:

        * ``nvec``: unit normals ``(..., 3)``;
        * ``L``: edge vectors ``(..., 3, 3)``, where row k is ``v[(k+1) % 3] - v[k]``;
        * ``l1``: edge lengths ``(..., 3)``;
        * ``l2``: squared edge lengths ``(..., 3)``.
    """
    v0 = tri[..., 0, :]
    normal = jnp.cross(tri[..., 1, :] - v0, tri[..., 2, :] - v0)
    nvec = normal / _safe_norm(normal, axis=-1, keepdims=True)
    # roll rather than fancy indexing, to avoid advanced-indexing
    # inconsistencies under JIT tracing.
    edges = jnp.roll(tri, shift=-1, axis=-2) - tri
    edge_len_sq = jnp.sum(edges * edges, axis=-1)
    return nvec, edges, jnp.sqrt(edge_len_sq), edge_len_sq
def _triangle_bfield_const_precomp(
    obs: jnp.ndarray,
    tri: jnp.ndarray,
    pol: jnp.ndarray,
    nvec: jnp.ndarray,
    L: jnp.ndarray,
    l1: jnp.ndarray,
    l2: jnp.ndarray,
) -> jnp.ndarray:
    """B-field for constant triangle using precomputed geometry.

    Field of one triangular surface with charge density ``sigma = pol . nvec``,
    evaluated at every observer.

    Args:
        obs: observer positions, ``(n, 3)``.
        tri: vertices of the single triangle, ``(3, 3)``.
        pol: polarization per observer, ``(n, 3)``.
        nvec: the triangle's unit normal, as returned by
            ``_triangle_geom_terms`` for one triangle.
        L: edge vectors, from the same source.
        l1: edge lengths, from the same source.
        l2: squared edge lengths, from the same source.

    Returns:
        ``(n, 3)`` B-field; NaN entries are replaced by 0.
    """
    # Vertex-minus-observer vectors (n, 3 vertices, 3) and their lengths.
    R = tri[None, :, :] - obs[:, None, :]
    r2 = jnp.sum(R * R, axis=-1)
    r = jnp.sqrt(r2)
    # Per-edge term proportional to the line integral of 1/|x - obs| along the edge.
    b = jnp.sum(R * L[None, :, :], axis=-1)
    bl = b / l1
    ind = jnp.abs(r + bl)
    integ1 = 1.0 / l1 * jnp.log((jnp.sqrt(l2 + 2.0 * b + r2) + l1 + bl) / ind)
    # Alternative form used where ``r + bl`` ~ 0 and ``integ1`` would be singular.
    integ2 = -(1.0 / l1) * jnp.log(jnp.abs(l1 - r) / r)
    integ = jnp.where(ind > 1e-12, integ1, integ2)
    PQR = jnp.sum(integ[:, :, None] * L[None, :, :], axis=1)
    sigma = jnp.sum(pol * nvec[None, :], axis=1)
    # Normal part from the solid angle, in-plane part from the edge integrals.
    B = sigma[:, None] * (nvec[None, :] * _solid_angle(R, r)[:, None] - jnp.cross(nvec, PQR))
    B = B / (_FOUR_PI)
    # Degenerate configurations (e.g. observer on a vertex) yield NaN; report 0 there.
    return jnp.nan_to_num(B, nan=0.0)
def _in_out_flag(in_out: str) -> int:
    """Translate an ``in_out`` mode string into its integer flag.

    Raises:
        ValueError: if ``in_out`` is not one of the keys of ``_IN_OUT_FLAGS``.
    """
    flag = _IN_OUT_FLAGS.get(in_out)
    if flag is None:
        raise ValueError(f"in_out must be one of {sorted(_IN_OUT_FLAGS)}, got {in_out!r}.")
    return flag
def _jit_kernel(name: str, fn, n_obs: int, in_out_flag: int):
    """Return a cached ``jax.jit`` wrapper of *fn* with static ``in_out_flag``.

    The cache key is (name, observer count, flag).
    """
    cache_key = (name, int(n_obs), int(in_out_flag))
    compiled = _JIT_KERNEL_CACHE.get(cache_key)
    if compiled is None:
        compiled = jax.jit(fn, static_argnames=("in_out_flag",))
        _JIT_KERNEL_CACHE[cache_key] = compiled
    return compiled
def _jit_kernel_simple(name: str, fn, n_obs: int):
    """Return a cached ``jax.jit`` wrapper of *fn*, keyed by (name, observer count)."""
    cache_key = (name, int(n_obs))
    compiled = _JIT_SIMPLE_CACHE.get(cache_key)
    if compiled is None:
        compiled = jax.jit(fn)
        _JIT_SIMPLE_CACHE[cache_key] = compiled
    return compiled
def _jit_kernel_mesh(name: str, fn, n_obs: int, n_faces: int, in_out_flag: int):
    """Return a cached ``jax.jit`` wrapper of *fn* with static ``in_out_flag`` and ``n_faces``.

    The cache key is (name, observer count, face count, flag).
    """
    cache_key = (name, int(n_obs), int(n_faces), int(in_out_flag))
    compiled = _JIT_MESH_CACHE.get(cache_key)
    if compiled is None:
        compiled = jax.jit(fn, static_argnames=("in_out_flag", "n_faces"))
        _JIT_MESH_CACHE[cache_key] = compiled
    return compiled
def _jit_kernel_segments(name: str, fn, n_obs: int, n_segments: int):
    """Return a cached ``jax.jit`` wrapper of *fn* with static ``n_segments``.

    The cache key is (name, observer count, segment count).
    """
    cache_key = (name, int(n_obs), int(n_segments))
    compiled = _JIT_SEGMENT_CACHE.get(cache_key)
    if compiled is None:
        compiled = jax.jit(fn, static_argnames=("n_segments",))
        _JIT_SEGMENT_CACHE[cache_key] = compiled
    return compiled
def _triangle_bfield_const_impl(
    obs: jnp.ndarray,
    tri: jnp.ndarray,
    pol: jnp.ndarray,
) -> jnp.ndarray:
    """B-field of one fixed triangle ``(3, 3)`` at all observers (jittable)."""
    batched_terms = _triangle_geom_terms(tri[None, :, :])
    nvec, edges, lengths, lengths_sq = (term[0] for term in batched_terms)
    return _triangle_bfield_const_precomp(obs, tri, pol, nvec, edges, lengths, lengths_sq)
def _solid_angle(R: jnp.ndarray, r: jnp.ndarray) -> jnp.ndarray:
    """Solid angle for vectors R with shape (n,3,3) and norms r with shape (n,3).

    Uses the oriented form ``tan(omega / 2) = N / D`` (Van Oosterom & Strackee).
    Results whose magnitude exceeds 2*pi are reported as 0.
    """
    a, b, c = R[:, 0], R[:, 1], R[:, 2]
    ra, rb, rc = r[:, 0], r[:, 1], r[:, 2]

    def dot(x, y):
        return jnp.sum(x * y, axis=1)

    numerator = dot(c, jnp.cross(b, a))
    denominator = ra * rb * rc + dot(c, b) * ra + dot(c, a) * rb + dot(b, a) * rc
    angle = 2.0 * jnp.arctan2(numerator, denominator)
    return jnp.where(jnp.abs(angle) > 6.2831853, 0.0, angle)
[docs]
def triangle_bfield(
    observers: ArrayLike,
    vertices: ArrayLike,
    polarizations: ArrayLike,
) -> jnp.ndarray:
    """B-field of magnetically charged triangular surfaces.

    Each triangle carries the surface charge density ``sigma = pol . n``, where
    ``n`` is the normalised ``(v1 - v0) x (v2 - v0)``.

    Args:
        observers: observer positions, normalised by ``ensure_observers``.
        vertices: either one triangle ``(3, 3)`` shared by all observers, or
            triangles broadcastable to ``(n, 3, 3)`` (one per observer).
        polarizations: one polarization vector or one per observer.

    Returns:
        ``(n, 3)`` B-field; NaN entries are replaced by 0.
    """
    obs = ensure_observers(observers)
    n = obs.shape[0]
    tri = jnp.asarray(vertices, dtype=jnp.float64)
    pol = _broadcast_vec3(jnp.asarray(polarizations, dtype=jnp.float64), n)
    if tri.ndim == 2:
        # Single shared triangle: geometry is computed once and broadcast.
        tri_const = tri
        nvec_const = _triangle_norm_vector(tri_const[None, :, :])[0]
        sigma = jnp.sum(pol * nvec_const, axis=1)
        R = tri_const[None, :, :] - obs[:, None, :]
        L = jnp.stack(
            (
                tri_const[1] - tri_const[0],
                tri_const[2] - tri_const[1],
                tri_const[0] - tri_const[2],
            ),
            axis=0,
        )
        l2 = jnp.sum(L * L, axis=-1)
        l1 = jnp.sqrt(l2)
        nvec = jnp.broadcast_to(nvec_const[None, :], (n, 3))
    else:
        # One triangle per observer.
        tri = jnp.broadcast_to(tri, (n, 3, 3))
        nvec = _triangle_norm_vector(tri)
        sigma = jnp.sum(nvec * pol, axis=1)
        R = tri - obs[:, None, :]
        L = tri[:, (1, 2, 0)] - tri[:, (0, 1, 2)]
        l2 = jnp.sum(L * L, axis=-1)
        l1 = jnp.sqrt(l2)
    # Vertex-minus-observer distances, (n, 3 vertices).
    r2 = jnp.sum(R * R, axis=-1)
    r = jnp.sqrt(r2)
    # Per-edge term proportional to the line integral of 1/|x - obs| along the edge.
    b = jnp.sum(R * L, axis=-1)
    bl = b / l1
    ind = jnp.abs(r + bl)
    integ1 = 1.0 / l1 * jnp.log((jnp.sqrt(l2 + 2.0 * b + r2) + l1 + bl) / ind)
    # Alternative form used where ``r + bl`` ~ 0 and ``integ1`` would be singular.
    integ2 = -(1.0 / l1) * jnp.log(jnp.abs(l1 - r) / r)
    integ = jnp.where(ind > 1e-12, integ1, integ2)
    PQR = jnp.sum(integ[:, :, None] * L, axis=1)
    # Normal part from the solid angle, in-plane part from the edge integrals.
    B = sigma[:, None] * (nvec * _solid_angle(R, r)[:, None] - jnp.cross(nvec, PQR))
    B = B / (_FOUR_PI)
    # Degenerate configurations yield NaN; report 0 there.
    return jnp.nan_to_num(B, nan=0.0)
def triangle_hfield(
    observers: ArrayLike,
    vertices: ArrayLike,
    polarizations: ArrayLike,
) -> jnp.ndarray:
    """H-field of charged triangles: their B-field divided by ``MU0``."""
    b_field = triangle_bfield(observers, vertices, polarizations)
    return b_field / MU0
def triangle_jfield(
    observers: ArrayLike,
    vertices: ArrayLike,
    polarizations: ArrayLike,
) -> jnp.ndarray:
    """Polarization field of triangles: identically zero.

    ``vertices`` and ``polarizations`` are accepted for a uniform signature but unused.
    """
    return jnp.zeros_like(ensure_observers(observers))
def triangle_mfield(
    observers: ArrayLike,
    vertices: ArrayLike,
    polarizations: ArrayLike,
) -> jnp.ndarray:
    """Magnetization field of triangles: identically zero.

    ``vertices`` and ``polarizations`` are accepted for a uniform signature but unused.
    """
    return jnp.zeros_like(ensure_observers(observers))
[docs]
def triangle_bfield_jit(
    observers: ArrayLike,
    vertices: ArrayLike,
    polarizations: ArrayLike,
) -> jnp.ndarray:
    """JIT-specialized triangle B-field for fixed observer counts.

    Only a single shared triangle ``(3, 3)`` uses the cached jitted kernel.
    Per-observer triangles fall back to ``triangle_bfield``.
    """
    obs = ensure_observers(observers)
    n_obs = obs.shape[0]
    tri = jnp.asarray(vertices, dtype=jnp.float64)
    pol = _broadcast_vec3(jnp.asarray(polarizations, dtype=jnp.float64), n_obs)
    if tri.ndim == 2:
        kernel = _jit_kernel_simple("triangle_bfield", _triangle_bfield_const_impl, n_obs)
        return kernel(obs, tri, pol)
    return triangle_bfield(obs, tri, pol)
def _check_tetra_chirality(vertices: jnp.ndarray) -> jnp.ndarray:
    """Reorder tetrahedron vertices ``(n, 4, 3)`` to positive orientation.

    For each tetrahedron, the determinant of the matrix with columns
    ``v1 - v0``, ``v2 - v0``, ``v3 - v0`` is computed. Where it is negative,
    vertices 2 and 3 are swapped.
    """
    origin = vertices[:, 0]
    edge_cols = jnp.stack([vertices[:, k] - origin for k in (1, 2, 3)], axis=-1)
    flip = jnp.linalg.det(edge_cols) < 0
    swapped = vertices[:, jnp.array([0, 1, 3, 2])]
    return jnp.where(flip[:, None, None], swapped, vertices)
def _points_inside_tetra(points: jnp.ndarray, vertices: jnp.ndarray) -> jnp.ndarray:
    """Test each point ``(n, 3)`` against its own tetrahedron ``(n, 4, 3)``.

    Works in barycentric-style coordinates relative to vertex 0. Points on the
    boundary count as inside.
    """
    base = vertices[:, 0]
    # Columns of each 3x3 matrix are the edge vectors v_k - v0.
    edge_matrix = jnp.swapaxes(vertices[:, 1:] - base[:, None, :], 1, 2)
    coords = jnp.matmul(jnp.linalg.inv(edge_matrix), (points - base)[:, :, None])[..., 0]
    in_unit_cube = jnp.all((coords >= 0.0) & (coords <= 1.0), axis=1)
    return in_unit_cube & (jnp.sum(coords, axis=1) <= 1.0)
def _points_inside_tetra_single(points: jnp.ndarray, vertices: jnp.ndarray) -> jnp.ndarray:
    """Test all points ``(n, 3)`` against one tetrahedron ``(4, 3)``.

    Boundary points count as inside.
    """
    base = vertices[0]
    edge_matrix = jnp.transpose(vertices[1:] - base)
    coords = (points - base) @ jnp.linalg.inv(edge_matrix).T
    in_unit_cube = jnp.all((coords >= 0.0) & (coords <= 1.0), axis=1)
    return in_unit_cube & (jnp.sum(coords, axis=1) <= 1.0)
def _tetrahedron_bfield_const_impl(
    obs: jnp.ndarray,
    tet_const: jnp.ndarray,
    pol: jnp.ndarray,
    *,
    in_out_flag: int,
) -> jnp.ndarray:
    """B-field of one fixed tetrahedron ``(4, 3)`` at all observers (jittable).

    The result is the sum of the charged-surface fields of the four faces, plus
    ``pol`` wherever the observer is inside. ``in_out_flag`` (see
    ``_IN_OUT_FLAGS``) either forces inside/outside or requests a geometric test.
    """
    oriented = _check_tetra_chirality(tet_const[None, :, :])[0]
    faces = oriented[_TETRA_FACES]
    geometry = _triangle_geom_terms(faces)
    per_face = jax.vmap(
        _triangle_bfield_const_precomp,
        in_axes=(None, 0, None, 0, 0, 0, 0),
    )(obs, faces, pol, *geometry)
    field = jnp.sum(per_face, axis=0)

    n_obs = obs.shape[0]
    if in_out_flag == _IN_OUT_FLAGS["inside"]:
        inside = jnp.ones((n_obs,), dtype=bool)
    elif in_out_flag == _IN_OUT_FLAGS["outside"]:
        inside = jnp.zeros((n_obs,), dtype=bool)
    else:
        inside = _points_inside_tetra_single(obs, oriented)
    return field + jnp.where(inside[:, None], pol, 0.0)
def tetrahedron_bfield(
    observers: ArrayLike,
    vertices: ArrayLike,
    polarizations: ArrayLike,
    in_out: str = "auto",
) -> jnp.ndarray:
    """B-field of homogeneously polarized tetrahedra.

    The result is the sum of the charged-surface fields of the four faces, plus
    the polarization wherever the observer is inside.

    Args:
        observers: observer positions, normalised by ``ensure_observers``.
        vertices: one tetrahedron ``(4, 3)`` or ``(1, 4, 3)`` shared by all
            observers, or tetrahedra broadcastable to ``(n, 4, 3)``.
        polarizations: one polarization vector or one per observer.
        in_out: "auto" tests each observer geometrically; "inside"/"outside"
            force the result for all observers.

    Returns:
        ``(n, 3)`` B-field.

    Raises:
        ValueError: if ``in_out`` is not one of the accepted modes. This matches
            ``tetrahedron_bfield_jit``; previously unknown modes were silently
            treated as "auto".
    """
    flag = _in_out_flag(in_out)
    obs = ensure_observers(observers)
    n = obs.shape[0]
    tet = jnp.asarray(vertices, dtype=jnp.float64)
    pol = _broadcast_vec3(jnp.asarray(polarizations, dtype=jnp.float64), n)
    if tet.ndim == 2 or (tet.ndim == 3 and tet.shape[0] == 1):
        # Shared tetrahedron: reuse the kernel that the JIT entry point compiles.
        tet_const = tet if tet.ndim == 2 else tet[0]
        return _tetrahedron_bfield_const_impl(obs, tet_const, pol, in_out_flag=flag)

    # One tetrahedron per observer: evaluate each face as a batched triangle.
    tet = _check_tetra_chirality(jnp.broadcast_to(tet, (n, 4, 3)))
    faces = tet[:, _TETRA_FACES, :]
    b = jnp.sum(
        jax.vmap(lambda tri: triangle_bfield(obs, tri, pol))(faces.swapaxes(0, 1)),
        axis=0,
    )
    if flag == _IN_OUT_FLAGS["inside"]:
        inside = jnp.ones((n,), dtype=bool)
    elif flag == _IN_OUT_FLAGS["outside"]:
        inside = jnp.zeros((n,), dtype=bool)
    else:
        inside = _points_inside_tetra(obs, tet)
    return b + jnp.where(inside[:, None], pol, 0.0)
[docs]
def tetrahedron_bfield_jit(
    observers: ArrayLike,
    vertices: ArrayLike,
    polarizations: ArrayLike,
    in_out: str = "auto",
) -> jnp.ndarray:
    """JIT-specialized tetrahedron B-field for fixed observer counts.

    A single tetrahedron (``(4, 3)`` or ``(1, 4, 3)``) uses a cached jitted
    kernel. Per-observer tetrahedra fall back to ``tetrahedron_bfield``.
    """
    obs = ensure_observers(observers)
    tet = jnp.asarray(vertices, dtype=jnp.float64)
    if tet.ndim == 3 and tet.shape[0] == 1:
        tet = tet[0]
    if tet.ndim == 2:
        n_obs = obs.shape[0]
        pol = _broadcast_vec3(jnp.asarray(polarizations, dtype=jnp.float64), n_obs)
        flag = _in_out_flag(in_out)
        kernel = _jit_kernel(
            "tetrahedron_bfield", _tetrahedron_bfield_const_impl, n_obs, flag
        )
        return kernel(obs, tet, pol, in_out_flag=flag)
    return tetrahedron_bfield(obs, tet, polarizations, in_out=in_out)
def tetrahedron_hfield(
    observers: ArrayLike,
    vertices: ArrayLike,
    polarizations: ArrayLike,
    in_out: str = "auto",
) -> jnp.ndarray:
    """H-field of tetrahedra: ``(B - J) / MU0``."""
    b_field = tetrahedron_bfield(observers, vertices, polarizations, in_out=in_out)
    polarization = tetrahedron_jfield(observers, vertices, polarizations, in_out=in_out)
    return (b_field - polarization) / MU0
def tetrahedron_jfield(
    observers: ArrayLike,
    vertices: ArrayLike,
    polarizations: ArrayLike,
    in_out: str = "auto",
) -> jnp.ndarray:
    """Polarization field of tetrahedra: ``pol`` at inside observers, zero elsewhere.

    Args:
        observers: observer positions, normalised by ``ensure_observers``.
        vertices: one tetrahedron ``(4, 3)`` or ``(1, 4, 3)``, or tetrahedra
            broadcastable to ``(n, 4, 3)``.
        polarizations: one polarization vector or one per observer.
        in_out: "auto" tests each observer geometrically; "inside"/"outside"
            force the result for all observers.

    Raises:
        ValueError: if ``in_out`` is not one of the accepted modes. This matches
            ``tetrahedron_bfield_jit``; previously unknown modes were silently
            treated as "auto".
    """
    _in_out_flag(in_out)  # validation only
    obs = ensure_observers(observers)
    n = obs.shape[0]
    tet = jnp.asarray(vertices, dtype=jnp.float64)
    pol = _broadcast_vec3(jnp.asarray(polarizations, dtype=jnp.float64), n)
    if tet.ndim == 2 or (tet.ndim == 3 and tet.shape[0] == 1):
        tet_const = tet if tet.ndim == 2 else tet[0]
        tet_const = _check_tetra_chirality(tet_const[None, :, :])[0]
        if in_out == "inside":
            inside = jnp.ones((n,), dtype=bool)
        elif in_out == "outside":
            inside = jnp.zeros((n,), dtype=bool)
        else:
            inside = _points_inside_tetra_single(obs, tet_const)
        return jnp.where(inside[:, None], pol, 0.0)
    tet = jnp.broadcast_to(tet, (n, 4, 3))
    if in_out == "inside":
        inside = jnp.ones((n,), dtype=bool)
    elif in_out == "outside":
        inside = jnp.zeros((n,), dtype=bool)
    else:
        inside = _points_inside_tetra(obs, tet)
    return jnp.where(inside[:, None], pol, 0.0)
def tetrahedron_mfield(
    observers: ArrayLike,
    vertices: ArrayLike,
    polarizations: ArrayLike,
    in_out: str = "auto",
) -> jnp.ndarray:
    """Magnetization of tetrahedra: the polarization field divided by ``MU0``."""
    polarization = tetrahedron_jfield(observers, vertices, polarizations, in_out=in_out)
    return polarization / MU0
def _moller_trumbore_hits(
    point: jnp.ndarray,
    triangles: jnp.ndarray,
    ray_dir: jnp.ndarray,
    *,
    eps: float = 1e-12,
) -> jnp.ndarray:
    """Vectorized Moller-Trumbore ray/triangle intersection flags for one point.

    The ray starts at ``point`` ``(3,)`` and runs along ``ray_dir`` ``(3,)``;
    ``triangles`` has shape ``(t, 3, 3)``. The result is a ``(t,)`` bool array,
    True where the ray hits the triangle at a distance parameter ``> eps``.
    Barycentric bounds are widened by ``eps``. Near-parallel triangles
    (``|det| <= eps``) never count as hits.
    """
    origin = triangles[:, 0]
    edge1 = triangles[:, 1] - origin
    edge2 = triangles[:, 2] - origin
    direction = jnp.broadcast_to(ray_dir[None, :], edge2.shape)

    pvec = jnp.cross(direction, edge2)
    det = jnp.sum(edge1 * pvec, axis=1)
    non_parallel = jnp.abs(det) > eps
    inv_det = jnp.where(non_parallel, 1.0 / det, 0.0)

    tvec = point[None, :] - origin
    u = inv_det * jnp.sum(tvec * pvec, axis=1)
    qvec = jnp.cross(tvec, edge1)
    v = inv_det * jnp.sum(direction * qvec, axis=1)
    dist = inv_det * jnp.sum(edge2 * qvec, axis=1)

    in_triangle = (u >= -eps) & (v >= -eps) & (u + v <= 1.0 + eps)
    return non_parallel & in_triangle & (dist > eps)
def _point_on_triangles(
    point: jnp.ndarray,
    triangles: jnp.ndarray,
    *,
    eps: float = 1e-7,
) -> jnp.ndarray:
    """Return a scalar bool: True if ``point`` lies on any non-degenerate triangle.

    A point counts as "on" a triangle when it is within ``eps`` of the plane and
    its in-plane barycentric coordinates are inside the triangle (widened by
    ``eps``). Triangles with normal length ``<= 1e-12`` are ignored.
    """
    v0 = triangles[:, 0]
    e1 = triangles[:, 1] - v0
    e2 = triangles[:, 2] - v0
    rel = point[None, :] - v0

    normal = jnp.cross(e1, e2)
    normal_len = _safe_norm(normal, axis=1)
    plane_dist = jnp.abs(jnp.sum(rel * normal, axis=1)) / normal_len

    d00 = jnp.sum(e1 * e1, axis=1)
    d01 = jnp.sum(e1 * e2, axis=1)
    d02 = jnp.sum(e1 * rel, axis=1)
    d11 = jnp.sum(e2 * e2, axis=1)
    d12 = jnp.sum(e2 * rel, axis=1)
    denom = d00 * d11 - d01 * d01
    inv_denom = jnp.where(jnp.abs(denom) > 1e-16, 1.0 / denom, 0.0)
    u = (d11 * d02 - d01 * d12) * inv_denom
    v = (d00 * d12 - d01 * d02) * inv_denom

    in_triangle = (u >= -eps) & (v >= -eps) & (u + v <= 1.0 + eps)
    return jnp.any((plane_dist <= eps) & in_triangle & (normal_len > 1e-12))
def _point_inside_mesh(point: jnp.ndarray, triangles: jnp.ndarray) -> jnp.ndarray:
    """Ray-parity inside test for one point against a triangle mesh ``(t, 3, 3)``.

    Hits are counted along a fixed oblique ray; an odd count means inside.
    Points lying on a triangle also count as inside.
    """
    direction = jnp.array([0.737, 0.511, 0.442], dtype=jnp.float64)
    direction = direction / _safe_norm(direction)
    n_hits = jnp.sum(_moller_trumbore_hits(point, triangles, direction).astype(jnp.int32))
    return (n_hits % 2 == 1) | _point_on_triangles(point, triangles)
def _v_norm2_jax(a: jnp.ndarray) -> jnp.ndarray:
    """Squared Euclidean norm along the last axis."""
    squared = a * a
    return squared.sum(axis=-1)
def _v_norm_proj_jax(a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray:
    """Cosine of the angle between ``a`` and ``b`` along the last axis.

    The result is NaN when either vector has zero length (0/0).
    """
    dot = jnp.sum(a * b, axis=-1)
    norms = jnp.sqrt(_v_norm2_jax(a) * _v_norm2_jax(b))
    return dot / norms
def _v_dot_cross3d_jax(a: jnp.ndarray, b: jnp.ndarray, c: jnp.ndarray) -> jnp.ndarray:
    """Scalar triple product ``(a x b) . c`` along the last axis."""
    crossed = jnp.cross(a, b)
    return (crossed * c).sum(axis=-1)
def _lines_end_in_trimesh_jax(lines: jnp.ndarray, faces: jnp.ndarray) -> jnp.ndarray:
    """Return, per line, whether its end point lies inside a closed triangle mesh.

    ``lines`` has shape ``(n, 2, 3)``: the start point is expected to be outside
    the mesh (callers place it beyond the bounding box) and the end point is
    tested. ``faces`` has shape ``(t, 3, 3)``.

    This is the all-faces-enabled case of ``_lines_end_in_trimesh_jax_masked``.
    With every mask entry True, that kernel keeps the real face geometry and
    every crossing, so it yields the same crossing-parity / surface-touch result.
    """
    all_faces = jnp.ones((faces.shape[0],), dtype=bool)
    return _lines_end_in_trimesh_jax_masked(lines, faces, all_faces)
# Stand-in triangle (unit right triangle in the z = 0 plane) substituted for
# masked-out faces, so their geometry stays finite and non-degenerate. Any
# crossings with it are discarded by the face mask.
_MASK_FACE_SENTINEL = jnp.array(
    [
        [0.0, 0.0, 0.0],
        [1.0, 0.0, 0.0],
        [0.0, 1.0, 0.0],
    ],
    dtype=jnp.float64,
)
def _lines_end_in_trimesh_jax_masked(
    lines: jnp.ndarray,
    faces: jnp.ndarray,
    face_mask: jnp.ndarray,
) -> jnp.ndarray:
    """Per line, whether its end point lies inside the mesh formed by enabled faces.

    Args:
        lines: ``(n, 2, 3)`` segments; index 0 is a start point expected to be
            outside the mesh, and index 1 is the point being tested.
        faces: ``(t, 3, 3)`` triangles.
        face_mask: ``(t,)`` flags; disabled faces are replaced by
            ``_MASK_FACE_SENTINEL`` and their crossings are ignored.

    Returns:
        ``(n,)`` bool: True if the segment crosses enabled faces an odd number
        of times, or if its end point touches an enabled face's plane within a
        face it passes through.
    """
    mask = jnp.asarray(face_mask, dtype=bool)
    faces_safe = jnp.where(mask[:, None, None], faces, _MASK_FACE_SENTINEL)
    normals = jnp.cross(faces_safe[:, 0] - faces_safe[:, 2], faces_safe[:, 1] - faces_safe[:, 2])
    normals = jnp.broadcast_to(normals, (lines.shape[0],) + normals.shape)
    l0 = lines[:, 0][:, None, :]
    l1 = lines[:, 1][:, None, :]
    # In-plane reference point per face is vertex 2. Where the test point
    # coincides with it (squared distance < 1e-16), the projection below would
    # be 0/0, so vertex 1 is used instead.
    ref_pts = jnp.broadcast_to(faces_safe[:, 2], (lines.shape[0], faces_safe.shape[0], 3))
    eps = 1e-16
    coincide = _v_norm2_jax(l1 - ref_pts) < eps
    ref_pts2 = jnp.broadcast_to(faces_safe[:, 1], ref_pts.shape)
    ref_pts = jnp.where(coincide[..., None], ref_pts2, ref_pts)
    # Signed (cosine) offsets of both line ends from each face plane.
    proj0 = _v_norm_proj_jax(l0 - ref_pts, normals)
    proj1 = _v_norm_proj_jax(l1 - ref_pts, normals)
    eps = 1e-7
    # Only the end point is checked for touching, since the start point is
    # expected to be outside.
    plane_touch = jnp.abs(proj1) < eps
    plane_cross = jnp.sign(proj0) != jnp.sign(proj1)
    faces0 = faces_safe[:, 0][None, :, :]
    faces1 = faces_safe[:, 1][None, :, :]
    faces2 = faces_safe[:, 2][None, :, :]
    a = faces0 - l0
    b = faces1 - l0
    c = faces2 - l0
    d = l1 - l0
    # The line passes through the triangle if the three triple products share
    # a sign, or any is ~0 (passing through an edge/vertex).
    area1 = _v_dot_cross3d_jax(a, b, d)
    area2 = _v_dot_cross3d_jax(b, c, d)
    area3 = _v_dot_cross3d_jax(c, a, d)
    eps = 1e-12
    pass_through_boundary = (jnp.abs(area1) < eps) | (jnp.abs(area2) < eps) | (jnp.abs(area3) < eps)
    area1 = jnp.sign(area1)
    area2 = jnp.sign(area2)
    area3 = jnp.sign(area3)
    pass_through_inside = (area1 == area2) & (area2 == area3)
    pass_through = pass_through_boundary | pass_through_inside
    # Discard any contribution from masked-out (sentinel) faces.
    mask_lines = mask[None, :]
    result_cross = pass_through & plane_cross & mask_lines
    result_touch = pass_through & plane_touch & mask_lines
    # Odd number of crossings => inside; touching a face also counts as inside.
    inside1 = (jnp.sum(result_cross, axis=1) % 2) != 0
    inside2 = jnp.any(result_touch, axis=1)
    return inside1 | inside2
def _mask_inside_trimesh_jax(points: jnp.ndarray, faces: jnp.ndarray) -> jnp.ndarray:
    """Inside test for points ``(n, 3)`` against a closed triangle mesh ``(t, 3, 3)``.

    A point is inside when both of these hold:

    * it lies in the mesh's bounding box (padded by 1e-12);
    * the segment from a fixed exterior start point to it crosses the surface
      an odd number of times, or ends on the surface.
    """
    verts = faces.reshape((-1, 3))
    lo = jnp.min(verts, axis=0)
    hi = jnp.max(verts, axis=0)
    pad = 1e-12
    in_box = jnp.all((points < hi + pad) & (points > lo - pad), axis=1)

    # Start point: the min corner shifted by an irregular offset.
    start = lo.astype(jnp.float64) - jnp.array(
        [12.0012345, 5.9923456, 6.9932109], dtype=jnp.float64
    )
    lines = jnp.stack((jnp.broadcast_to(start, points.shape), points), axis=1)
    return in_box & _lines_end_in_trimesh_jax(lines, faces)
def _mask_inside_trimesh_jax_masked(
    points: jnp.ndarray,
    faces: jnp.ndarray,
    face_mask: jnp.ndarray,
) -> jnp.ndarray:
    """Masked variant of the mesh inside test.

    Only faces whose ``face_mask`` entry is True contribute to the bounding box
    and to the crossing count. If no face is enabled, every point is reported
    outside.
    """
    mask = jnp.asarray(face_mask, dtype=bool)

    def _no_faces() -> jnp.ndarray:
        return jnp.zeros((points.shape[0],), dtype=bool)

    def _with_faces() -> jnp.ndarray:
        verts = faces.reshape((-1, 3))
        keep = jnp.repeat(mask, 3)[:, None]
        # Disabled vertices are pushed to +/-1e30 so they never win min/max.
        lo = jnp.min(jnp.where(keep, verts, 1.0e30), axis=0)
        hi = jnp.max(jnp.where(keep, verts, -1.0e30), axis=0)
        pad = 1e-12
        in_box = jnp.all((points < hi + pad) & (points > lo - pad), axis=1)
        start = lo.astype(jnp.float64) - jnp.array(
            [12.0012345, 5.9923456, 6.9932109], dtype=jnp.float64
        )
        lines = jnp.stack((jnp.broadcast_to(start, points.shape), points), axis=1)
        return in_box & _lines_end_in_trimesh_jax_masked(lines, faces, mask)

    return jax.lax.cond(jnp.any(mask), _with_faces, _no_faces)
def _inside_mask_mesh(observers: jnp.ndarray, mesh: jnp.ndarray) -> jnp.ndarray:
    """Inside mask for a shared mesh ``(t, 3, 3)`` or one mesh per observer ``(n, t, 3, 3)``."""
    if mesh.ndim != 3:

        def single(point, point_mesh):
            return _mask_inside_trimesh_jax(point[None, :], point_mesh)[0]

        return jax.vmap(single)(observers, mesh)
    return _mask_inside_trimesh_jax(observers, mesh)
def _inside_mask_mesh_masked(
    observers: jnp.ndarray,
    mesh: jnp.ndarray,
    face_mask: jnp.ndarray,
) -> jnp.ndarray:
    """Masked inside mask for a shared mesh, or one mesh (and face mask) per observer.

    For a shared mesh ``(t, 3, 3)`` the mask applies to all observers. Otherwise
    ``mesh`` and ``face_mask`` are mapped over their leading (observer) axis.
    """
    if mesh.ndim != 3:

        def single(point, point_mesh, point_mask):
            return _mask_inside_trimesh_jax_masked(point[None, :], point_mesh, point_mask)[0]

        return jax.vmap(single)(observers, mesh, face_mask)
    return _mask_inside_trimesh_jax_masked(observers, mesh, face_mask)
def _broadcast_mesh(mesh: jnp.ndarray, n: int) -> jnp.ndarray:
    """Return a per-observer mesh stack of shape ``(n, t, 3, 3)``.

    A shared mesh ``(t, 3, 3)`` is repeated ``n`` times. A stacked mesh
    ``(k, t, 3, 3)`` is broadcast along its leading axis.

    Raises:
        ValueError: for any other rank.
    """
    if mesh.ndim == 3:
        source, target = mesh[None, :, :, :], (n, *mesh.shape)
    elif mesh.ndim == 4:
        source, target = mesh, (n, mesh.shape[1], 3, 3)
    else:
        raise ValueError(f"Expected mesh shape (t,3,3) or (n,t,3,3), got {mesh.shape}.")
    return jnp.broadcast_to(source, target)
[docs]
def magnet_trimesh_bfield(
    observers: ArrayLike,
    mesh: ArrayLike,
    polarizations: ArrayLike,
    in_out: str = "auto",
) -> jnp.ndarray:
    """B-field of uniformly polarized closed triangular meshes.

    Args:
        observers: observer positions, normalised by ``ensure_observers``.
        mesh: a shared mesh ``(t, 3, 3)``, or per-observer meshes
            ``(n, t, 3, 3)`` (a leading axis of 1 is broadcast).
        polarizations: one polarization vector or one per observer.
        in_out: "auto" tests each observer geometrically; "inside"/"outside"
            force the result for all observers.

    Returns:
        ``(n, 3)`` B-field: summed face fields plus the polarization inside.

    Raises:
        ValueError: if ``in_out`` is not an accepted mode, or if ``mesh`` has
            neither shape ``(t, 3, 3)`` nor ``(n, t, 3, 3)``.
    """
    obs = ensure_observers(observers)
    n = obs.shape[0]
    mesh_arr = jnp.asarray(mesh, dtype=jnp.float64)
    pol = _broadcast_vec3(jnp.asarray(polarizations, dtype=jnp.float64), n)
    # Validate up front so that the per-observer path also rejects unknown
    # modes (previously they were silently treated as "auto").
    flag = _in_out_flag(in_out)
    if mesh_arr.ndim == 3:
        return _magnet_trimesh_bfield_const_impl(obs, mesh_arr, pol, in_out_flag=flag)

    # Per-observer meshes. _broadcast_mesh raises ValueError for any rank other
    # than 4 instead of letting swapaxes produce garbage.
    mesh_arr = _broadcast_mesh(mesh_arr, n)
    # Evaluate each face as a batched triangle field and reduce over faces.
    # This avoids flatten+repeat expansions and lowers peak memory pressure.
    mesh_by_face = jnp.swapaxes(mesh_arr, 0, 1)  # (n_faces, n_obs, 3, 3)
    b_faces = jax.vmap(lambda face_vertices: triangle_bfield(obs, face_vertices, pol))(mesh_by_face)
    b = jnp.sum(b_faces, axis=0)
    if flag == _IN_OUT_FLAGS["outside"]:
        inside = jnp.zeros((n,), dtype=bool)
    elif flag == _IN_OUT_FLAGS["inside"]:
        inside = jnp.ones((n,), dtype=bool)
    else:
        inside = _inside_mask_mesh(obs, mesh_arr)
    return b + jnp.where(inside[:, None], pol, 0.0)
def _magnet_trimesh_bfield_const_impl(
    obs: jnp.ndarray,
    mesh_arr: jnp.ndarray,
    pol: jnp.ndarray,
    *,
    in_out_flag: int,
) -> jnp.ndarray:
    """B-field of one shared triangle mesh ``(t, 3, 3)`` at all observers (jittable).

    Meshes with up to 64 faces are evaluated with one vmap over faces. Larger
    meshes accumulate face by face in a ``fori_loop``, avoiding materialising
    all ``(t, n, 3)`` per-face fields at once.
    """
    nvec, edges, lengths, lengths_sq = _triangle_geom_terms(mesh_arr)
    n_obs = obs.shape[0]
    n_faces = mesh_arr.shape[0]
    if n_faces > 64:

        def add_face(i: int, total: jnp.ndarray) -> jnp.ndarray:
            return total + _triangle_bfield_const_precomp(
                obs, mesh_arr[i], pol, nvec[i], edges[i], lengths[i], lengths_sq[i]
            )

        b = jax.lax.fori_loop(0, n_faces, add_face, jnp.zeros((n_obs, 3), dtype=jnp.float64))
    else:
        per_face = jax.vmap(
            _triangle_bfield_const_precomp,
            in_axes=(None, 0, None, 0, 0, 0, 0),
        )(obs, mesh_arr, pol, nvec, edges, lengths, lengths_sq)
        b = jnp.sum(per_face, axis=0)

    if in_out_flag == _IN_OUT_FLAGS["inside"]:
        inside = jnp.ones((n_obs,), dtype=bool)
    elif in_out_flag == _IN_OUT_FLAGS["outside"]:
        inside = jnp.zeros((n_obs,), dtype=bool)
    else:
        inside = _inside_mask_mesh(obs, mesh_arr)
    return b + jnp.where(inside[:, None], pol, 0.0)
[docs]
def precompute_trimesh_geometry(
    mesh: ArrayLike,
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Precompute triangle mesh geometry terms for reuse.

    Returns:
        ``(mesh, nvec, L, l1, l2)``: the mesh as float64 followed by the output
        of ``_triangle_geom_terms``.

    Raises:
        ValueError: if ``mesh`` is not shaped ``(n_faces, 3, 3)``.
    """
    mesh_arr = jnp.asarray(mesh, dtype=jnp.float64)
    well_formed = mesh_arr.ndim == 3 and mesh_arr.shape[1:] == (3, 3)
    if not well_formed:
        raise ValueError("Mesh must have shape (n_faces,3,3).")
    return (mesh_arr, *_triangle_geom_terms(mesh_arr))
def _magnet_trimesh_bfield_precomp_impl(
    obs: jnp.ndarray,
    mesh_arr: jnp.ndarray,
    pol: jnp.ndarray,
    nvec: jnp.ndarray,
    L: jnp.ndarray,
    l1: jnp.ndarray,
    l2: jnp.ndarray,
    *,
    in_out_flag: int,
    n_faces: int,
) -> jnp.ndarray:
    """Mesh B-field from precomputed face geometry (see ``precompute_trimesh_geometry``).

    ``n_faces`` must be a Python int (it is a static argument of the mesh JIT
    wrapper). It selects the reduction: one vmap over faces up to 64 faces, and
    a sequential ``fori_loop`` above that.
    """
    n_obs = obs.shape[0]
    if n_faces <= 64:
        per_face = jax.vmap(
            _triangle_bfield_const_precomp,
            in_axes=(None, 0, None, 0, 0, 0, 0),
        )(obs, mesh_arr, pol, nvec, L, l1, l2)
        b = jnp.sum(per_face, axis=0)
    else:

        def add_face(i: int, total: jnp.ndarray) -> jnp.ndarray:
            return total + _triangle_bfield_const_precomp(
                obs, mesh_arr[i], pol, nvec[i], L[i], l1[i], l2[i]
            )

        b = jax.lax.fori_loop(0, n_faces, add_face, jnp.zeros((n_obs, 3), dtype=jnp.float64))

    if in_out_flag == _IN_OUT_FLAGS["inside"]:
        inside = jnp.ones((n_obs,), dtype=bool)
    elif in_out_flag == _IN_OUT_FLAGS["outside"]:
        inside = jnp.zeros((n_obs,), dtype=bool)
    else:
        inside = _inside_mask_mesh(obs, mesh_arr)
    return b + jnp.where(inside[:, None], pol, 0.0)
[docs]
def magnet_trimesh_bfield_precomp_masked(
    observers: ArrayLike,
    mesh: ArrayLike,
    polarizations: ArrayLike,
    nvec: ArrayLike,
    L: ArrayLike,
    l1: ArrayLike,
    l2: ArrayLike,
    face_mask: ArrayLike,
    in_out_flag: int,
) -> jnp.ndarray:
    """B-field of triangular mesh using precomputed geometry with face masking.

    Faces whose ``face_mask`` entry is False contribute no surface field and are
    ignored by the inside test.

    Args:
        observers: observer positions, normalised by ``ensure_observers``.
        mesh: triangles ``(t, 3, 3)``.
        polarizations: one polarization vector or one per observer.
        nvec: unit normals per face, as from ``precompute_trimesh_geometry``.
        L: edge vectors per face, as from ``precompute_trimesh_geometry``.
        l1: edge lengths per face, as from ``precompute_trimesh_geometry``.
        l2: squared edge lengths per face, as from ``precompute_trimesh_geometry``.
        face_mask: ``(t,)`` flags enabling each face.
        in_out_flag: 0 = auto (geometric test), 1 = inside, 2 = outside. It is
            used as a ``jax.lax.switch`` index, so it may be traced;
            out-of-range values are clamped by ``lax.switch``.
    """
    obs = ensure_observers(observers)
    n_obs = obs.shape[0]
    mesh_arr = jnp.asarray(mesh, dtype=jnp.float64)
    pol = _broadcast_vec3(jnp.asarray(polarizations, dtype=jnp.float64), n_obs)
    nvec_arr, L_arr, l1_arr, l2_arr = (
        jnp.asarray(term, dtype=jnp.float64) for term in (nvec, L, l1, l2)
    )
    mask = jnp.asarray(face_mask, dtype=bool).reshape((-1,))
    n_faces = mesh_arr.shape[0]

    if n_faces > 64:

        def add_face(i: int, total: jnp.ndarray) -> jnp.ndarray:
            contribution = _triangle_bfield_const_precomp(
                obs, mesh_arr[i], pol, nvec_arr[i], L_arr[i], l1_arr[i], l2_arr[i]
            )
            return total + jnp.where(mask[i], contribution, 0.0)

        b = jax.lax.fori_loop(0, n_faces, add_face, jnp.zeros((n_obs, 3), dtype=jnp.float64))
    else:
        per_face = jax.vmap(
            _triangle_bfield_const_precomp,
            in_axes=(None, 0, None, 0, 0, 0, 0),
        )(obs, mesh_arr, pol, nvec_arr, L_arr, l1_arr, l2_arr)
        b = jnp.sum(jnp.where(mask[:, None, None], per_face, 0.0), axis=0)

    # Branch order must match _IN_OUT_FLAGS: auto, inside, outside.
    branches = (
        lambda: _inside_mask_mesh_masked(obs, mesh_arr, mask),
        lambda: jnp.ones((n_obs,), dtype=bool),
        lambda: jnp.zeros((n_obs,), dtype=bool),
    )
    inside = jax.lax.switch(in_out_flag, branches)
    return b + jnp.where(inside[:, None], pol, 0.0)
def _magnet_trimesh_bfield_faces_impl(
    obs: jnp.ndarray,
    mesh_arr: jnp.ndarray,
    pol: jnp.ndarray,
    *,
    in_out_flag: int,
    n_faces: int,
) -> jnp.ndarray:
    """Face-count-specialized entry point for ``_magnet_trimesh_bfield_const_impl``.

    ``n_faces`` is accepted because ``magnet_trimesh_bfield_jit_faces`` passes it
    and also hands it to ``_jit_kernel_mesh``, presumably as part of the compile
    cache key. The computation itself does not use it.
    """
    del n_faces  # unused by the computation; see docstring
    return _magnet_trimesh_bfield_const_impl(
        obs,
        mesh_arr,
        pol,
        in_out_flag=in_out_flag,
    )
[docs]
def magnet_trimesh_bfield_jit(
    observers: ArrayLike,
    mesh: ArrayLike,
    polarizations: ArrayLike,
    in_out: str = "auto",
) -> jnp.ndarray:
    """JIT-specialized triangular mesh B-field for fixed observer counts.

    A 3-D mesh array (per-face vertex triples) is delegated to
    ``magnet_trimesh_bfield_jit_faces``. Any other layout is evaluated with a
    kernel obtained from ``_jit_kernel`` for the observer count and in/out flag.
    """
    obs = ensure_observers(observers)
    n_obs = obs.shape[0]
    mesh_arr = jnp.asarray(mesh, dtype=jnp.float64)
    pol = _broadcast_vec3(jnp.asarray(polarizations, dtype=jnp.float64), n_obs)
    if mesh_arr.ndim != 3:
        flag = _in_out_flag(in_out)
        kernel = _jit_kernel(
            "triangularmesh_bfield",
            _magnet_trimesh_bfield_const_impl,
            n_obs,
            flag,
        )
        return kernel(obs, mesh_arr, pol, in_out_flag=flag)
    return magnet_trimesh_bfield_jit_faces(obs, mesh_arr, pol, in_out=in_out)
[docs]
def magnet_trimesh_bfield_jit_faces(
    observers: ArrayLike,
    mesh: ArrayLike,
    polarizations: ArrayLike,
    in_out: str = "auto",
) -> jnp.ndarray:
    """JIT-specialized triangular mesh B-field for fixed observer + face counts.

    The compiled kernel is obtained from ``_jit_kernel_mesh`` using the observer
    count, the face count and the in/out flag.

    Args:
        observers: observer positions, normalized by ``ensure_observers``.
        mesh: triangle vertices, shape ``(n_faces, 3, 3)``.
        polarizations: polarization vector(s), broadcast to ``(n_obs, 3)``.
        in_out: ``"auto"``, ``"inside"`` or ``"outside"`` (mapped by ``_in_out_flag``).

    Returns:
        B-field of shape ``(n_obs, 3)``.

    Raises:
        ValueError: if ``mesh`` is not of shape ``(n_faces, 3, 3)``.
    """
    obs = ensure_observers(observers)
    mesh_arr = jnp.asarray(mesh, dtype=jnp.float64)
    # Validate the full per-face layout, not just the rank, so that malformed
    # meshes fail here with the documented message instead of inside tracing.
    if mesh_arr.ndim != 3 or mesh_arr.shape[1:] != (3, 3):
        raise ValueError("TriangularMesh JIT expects mesh with shape (n_faces,3,3).")
    pol = _broadcast_vec3(jnp.asarray(polarizations, dtype=jnp.float64), obs.shape[0])
    flag = _in_out_flag(in_out)
    n_faces = int(mesh_arr.shape[0])
    jit_fn = _jit_kernel_mesh(
        "triangularmesh_bfield_faces",
        _magnet_trimesh_bfield_faces_impl,
        obs.shape[0],
        n_faces,
        flag,
    )
    return jit_fn(obs, mesh_arr, pol, in_out_flag=flag, n_faces=n_faces)
[docs]
def magnet_trimesh_bfield_jit_faces_precomp(
    observers: ArrayLike,
    mesh: ArrayLike,
    polarizations: ArrayLike,
    nvec: ArrayLike,
    L: ArrayLike,
    l1: ArrayLike,
    l2: ArrayLike,
    in_out: str = "auto",
) -> jnp.ndarray:
    """JIT-specialized triangular mesh B-field using precomputed geometry.

    Args:
        observers: observer positions, normalized by ``ensure_observers``.
        mesh: triangle vertices, shape ``(n_faces, 3, 3)``.
        polarizations: polarization vector(s), broadcast to ``(n_obs, 3)``.
        nvec, L, l1, l2: precomputed per-face geometry whose first axis is the
            face axis (as returned by ``precompute_trimesh_geometry``).
        in_out: ``"auto"``, ``"inside"`` or ``"outside"`` (mapped by ``_in_out_flag``).

    Returns:
        B-field of shape ``(n_obs, 3)``.

    Raises:
        ValueError: if ``mesh`` is not ``(n_faces, 3, 3)`` or a geometry array
            does not have ``n_faces`` entries along its first axis.
    """
    obs = ensure_observers(observers)
    mesh_arr = jnp.asarray(mesh, dtype=jnp.float64)
    if mesh_arr.ndim != 3 or mesh_arr.shape[1:] != (3, 3):
        raise ValueError("TriangularMesh JIT expects mesh with shape (n_faces,3,3).")
    pol = _broadcast_vec3(jnp.asarray(polarizations, dtype=jnp.float64), obs.shape[0])
    n_faces = int(mesh_arr.shape[0])
    geometry = []
    for name, value in (("nvec", nvec), ("L", L), ("l1", l1), ("l2", l2)):
        arr = jnp.asarray(value, dtype=jnp.float64)
        # The kernel indexes these along axis 0 per face. A mismatch either fails
        # inside vmap (<= 64 faces) or is silently clamped by fori_loop indexing.
        if arr.ndim == 0 or arr.shape[0] != n_faces:
            raise ValueError(
                f"Precomputed {name} must have leading dimension n_faces={n_faces}, "
                f"got shape {arr.shape}."
            )
        geometry.append(arr)
    flag = _in_out_flag(in_out)
    jit_fn = _jit_kernel_mesh(
        "triangularmesh_bfield_precomp",
        _magnet_trimesh_bfield_precomp_impl,
        obs.shape[0],
        n_faces,
        flag,
    )
    return jit_fn(
        obs,
        mesh_arr,
        pol,
        *geometry,
        in_out_flag=flag,
        n_faces=n_faces,
    )
def magnet_trimesh_hfield(
    observers: ArrayLike,
    mesh: ArrayLike,
    polarizations: ArrayLike,
    in_out: str = "auto",
) -> jnp.ndarray:
    """H-field of a triangular-mesh magnet, computed as ``H = (B - J) / MU0``."""
    b_field = magnet_trimesh_bfield(observers, mesh, polarizations, in_out=in_out)
    polarization = magnet_trimesh_jfield(observers, mesh, polarizations, in_out=in_out)
    return (b_field - polarization) / MU0
def magnet_trimesh_jfield(
    observers: ArrayLike,
    mesh: ArrayLike,
    polarizations: ArrayLike,
    in_out: str = "auto",
) -> jnp.ndarray:
    """Polarization field J of a triangular-mesh magnet.

    Returns the polarization at observers classified as inside and zero
    elsewhere. ``in_out="inside"`` / ``"outside"`` force the classification;
    any other value runs the mesh inside test.
    """
    obs = ensure_observers(observers)
    n_obs = obs.shape[0]
    mesh_arr = jnp.asarray(mesh, dtype=jnp.float64)
    if mesh_arr.ndim == 4:
        # 4-D input (presumably one mesh per observer) is broadcast to the observer count.
        mesh_arr = _broadcast_mesh(mesh_arr, n_obs)
    pol = _broadcast_vec3(jnp.asarray(polarizations, dtype=jnp.float64), n_obs)
    if in_out == "inside":
        inside = jnp.ones((n_obs,), dtype=bool)
    elif in_out == "outside":
        inside = jnp.zeros((n_obs,), dtype=bool)
    else:
        inside = _inside_mask_mesh(obs, mesh_arr)
    return jnp.where(inside[:, None], pol, 0.0)
def magnet_trimesh_mfield(
    observers: ArrayLike,
    mesh: ArrayLike,
    polarizations: ArrayLike,
    in_out: str = "auto",
) -> jnp.ndarray:
    """Magnetization of a triangular-mesh magnet, ``M = J / MU0``."""
    polarization = magnet_trimesh_jfield(observers, mesh, polarizations, in_out=in_out)
    return polarization / MU0
def _grid_to_triangles(grid: jnp.ndarray, *, flip: bool = False) -> jnp.ndarray:
    """Split every quad cell of a structured point grid into two triangles.

    For a grid of shape ``(n, m, 3)``, cell ``(i, j)`` has corners
    ``p00 = grid[i, j]``, ``p10 = grid[i+1, j]``, ``p01 = grid[i, j+1]`` and
    ``p11 = grid[i+1, j+1]``.

    * All ``(p00, p10, p01)`` triangles are emitted first (row-major over
      cells), followed by all ``(p10, p11, p01)`` triangles.
    * With ``flip=True`` the last two vertices of each triangle are swapped,
      which reverses the winding.

    Returns:
        Array of shape ``(2 * (n - 1) * (m - 1), 3, 3)``.
    """
    p00 = grid[:-1, :-1]
    p10 = grid[1:, :-1]
    p01 = grid[:-1, 1:]
    p11 = grid[1:, 1:]
    lower = jnp.stack([p00, p10, p01], axis=-2).reshape((-1, 3, 3))
    upper = jnp.stack([p10, p11, p01], axis=-2).reshape((-1, 3, 3))
    triangles = jnp.concatenate([lower, upper], axis=0)
    return triangles[:, [0, 2, 1], :] if flip else triangles
def _build_cylinder_segment_mesh(
    dimension: jnp.ndarray,
    *,
    n_phi: int = 96,
    n_r: int = 1,
    n_z: int = 1,
) -> jnp.ndarray:
    """Triangulate the closed surface of a cylinder segment centered at the origin.

    Args:
        dimension: ``(r1, r2, h, phi1, phi2)``: inner/outer radius, height and
            the azimuth range in degrees. The segment spans ``z`` in ``[-h/2, h/2]``.
        n_phi: number of azimuthal subdivisions.
        n_r: number of radial subdivisions (caps and cut faces).
        n_z: number of axial subdivisions (curved walls and cut faces).

    Returns:
        Triangles of shape ``(n_tri, 3, 3)``, concatenated in this order:
        outer wall, inner wall, top cap, bottom cap, cut face at ``phi1``,
        cut face at ``phi2``.

        The inner wall, bottom cap and ``phi2`` cut use the reversed winding
        of ``_grid_to_triangles``. With a right-handed ``(b-a) x (c-a)``
        normal, every face normal then points outward when ``phi1 < phi2``.

        If ``r1 == 0`` the inner wall collapses onto the z axis, producing
        zero-area triangles.
    """
    r_in, r_out, height, phi_start_deg, phi_end_deg = dimension
    z_bottom = -height / 2.0
    z_top = height / 2.0
    angles = jnp.linspace(
        jnp.deg2rad(phi_start_deg), jnp.deg2rad(phi_end_deg), n_phi + 1, dtype=jnp.float64
    )
    radii = jnp.linspace(r_in, r_out, n_r + 1, dtype=jnp.float64)
    heights = jnp.linspace(z_bottom, z_top, n_z + 1, dtype=jnp.float64)
    cos_a = jnp.cos(angles)
    sin_a = jnp.sin(angles)

    def wall(radius: jnp.ndarray) -> jnp.ndarray:
        # (n_phi+1, n_z+1, 3) grid on the curved surface at the given radius.
        shape = (n_phi + 1, n_z + 1)
        column = angles[:, None]
        return jnp.stack(
            (
                jnp.broadcast_to(radius * jnp.cos(column), shape),
                jnp.broadcast_to(radius * jnp.sin(column), shape),
                jnp.broadcast_to(heights[None, :], shape),
            ),
            axis=-1,
        )

    def cap(z_level: jnp.ndarray) -> jnp.ndarray:
        # (n_r+1, n_phi+1, 3) grid on a flat annular cap at height z_level.
        rad = radii[:, None]
        ang = angles[None, :]
        return jnp.stack(
            (
                rad * jnp.cos(ang),
                rad * jnp.sin(ang),
                jnp.broadcast_to(jnp.asarray(z_level), (n_r + 1, n_phi + 1)),
            ),
            axis=-1,
        )

    def cut(cos_edge: jnp.ndarray, sin_edge: jnp.ndarray) -> jnp.ndarray:
        # (n_r+1, n_z+1, 3) grid on a planar cut face at a fixed azimuth.
        shape = (n_r + 1, n_z + 1)
        rad = radii[:, None]
        return jnp.stack(
            (
                jnp.broadcast_to(rad * cos_edge, shape),
                jnp.broadcast_to(rad * sin_edge, shape),
                jnp.broadcast_to(heights[None, :], shape),
            ),
            axis=-1,
        )

    pieces = (
        (wall(r_out), False),
        (wall(r_in), True),
        (cap(z_top), False),
        (cap(z_bottom), True),
        (cut(cos_a[0], sin_a[0]), False),
        (cut(cos_a[-1], sin_a[-1]), True),
    )
    return jnp.concatenate(
        [_grid_to_triangles(grid, flip=flip) for grid, flip in pieces], axis=0
    )
[docs]
def precompute_cylinder_segment_geometry(
    dimension: ArrayLike,
    *,
    n_phi: int = 96,
    n_r: int = 1,
    n_z: int = 1,
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Precompute cylinder segment mesh + geometry terms.

    Triangulates the segment surface with ``_build_cylinder_segment_mesh`` at
    the requested resolution and passes it to ``precompute_trimesh_geometry``.

    Returns:
        ``(mesh, nvec, L, l1, l2)`` as produced by ``precompute_trimesh_geometry``.
    """
    surface = _build_cylinder_segment_mesh(
        jnp.asarray(dimension, dtype=jnp.float64),
        n_phi=n_phi,
        n_r=n_r,
        n_z=n_z,
    )
    faces, normals, edge_len, edge_l1, edge_l2 = precompute_trimesh_geometry(surface)
    return faces, normals, edge_len, edge_l1, edge_l2
def _ensure_dim5(dimensions: ArrayLike, n: int) -> jnp.ndarray:
    """Normalize CylinderSegment dimensions to a single ``(5,)`` vector.

    Accepted inputs:

    * shape ``(5,)``: returned as is;
    * shape ``(1, 5)``: the single row is returned;
    * shape ``(n, 5)``: returned as one row, provided all rows agree.

    Rows agree within a relative and absolute tolerance of ``1e-14``. The
    relative part matters for large values such as angles near 360 degrees,
    whose float spacing already exceeds ``1e-14``.

    The comparison converts to a Python bool, so it requires concrete (non-traced)
    input.

    Raises:
        ValueError: for any other shape or for genuinely varying rows.
    """
    dim = jnp.asarray(dimensions, dtype=jnp.float64)
    if dim.ndim == 1:
        if dim.shape[0] != 5:
            raise ValueError(f"CylinderSegment dimension must have shape (5,), got {dim.shape}.")
        return dim
    if dim.ndim == 2 and dim.shape[1] == 5:
        if dim.shape[0] == 1:
            return dim[0]
        if dim.shape[0] == n:
            first = dim[0]
            same = jnp.allclose(dim, first[None, :], rtol=1e-14, atol=1e-14)
            if bool(same):
                return first
            raise ValueError("Per-observer varying CylinderSegment dimensions are not supported.")
    raise ValueError(f"CylinderSegment dimension must have shape (5,) or (n,5), got {dim.shape}.")
def magnet_cylinder_segment_bfield(
    observers: ArrayLike,
    dimensions: ArrayLike,
    polarizations: ArrayLike,
    in_out: str = "auto",
) -> jnp.ndarray:
    """B-field of a cylinder segment magnet.

    The segment surface is triangulated with ``_build_cylinder_segment_mesh``
    at its default resolution and evaluated with ``magnet_trimesh_bfield``.
    """
    obs = ensure_observers(observers)
    surface = _build_cylinder_segment_mesh(_ensure_dim5(dimensions, obs.shape[0]))
    return magnet_trimesh_bfield(obs, surface, polarizations, in_out=in_out)
[docs]
def magnet_cylinder_segment_bfield_jit(
    observers: ArrayLike,
    dimensions: ArrayLike,
    polarizations: ArrayLike,
    in_out: str = "auto",
) -> jnp.ndarray:
    """JIT-specialized cylinder-segment B-field for fixed observer counts.

    Alias of ``magnet_cylinder_segment_bfield_jit_faces``.
    """
    result = magnet_cylinder_segment_bfield_jit_faces(
        observers,
        dimensions,
        polarizations,
        in_out=in_out,
    )
    return result
[docs]
def magnet_cylinder_segment_bfield_jit_faces(
    observers: ArrayLike,
    dimensions: ArrayLike,
    polarizations: ArrayLike,
    in_out: str = "auto",
) -> jnp.ndarray:
    """JIT-specialized cylinder-segment B-field for fixed observer + face counts.

    The mesh and its geometry terms are built eagerly on every call
    (``precompute_cylinder_segment_geometry``). The field is then evaluated by
    ``magnet_trimesh_bfield_jit_faces_precomp``.
    """
    obs = ensure_observers(observers)
    geometry = precompute_cylinder_segment_geometry(_ensure_dim5(dimensions, obs.shape[0]))
    faces, normals, edge_len, edge_l1, edge_l2 = geometry
    return magnet_trimesh_bfield_jit_faces_precomp(
        obs, faces, polarizations, normals, edge_len, edge_l1, edge_l2, in_out=in_out
    )
def magnet_cylinder_segment_hfield(
    observers: ArrayLike,
    dimensions: ArrayLike,
    polarizations: ArrayLike,
    in_out: str = "auto",
) -> jnp.ndarray:
    """H-field of a cylinder segment magnet, computed as ``H = (B - J) / MU0``."""
    b_field = magnet_cylinder_segment_bfield(observers, dimensions, polarizations, in_out=in_out)
    j_field = magnet_cylinder_segment_jfield(observers, dimensions, polarizations, in_out=in_out)
    difference = b_field - j_field
    return difference / MU0
def magnet_cylinder_segment_jfield(
    observers: ArrayLike,
    dimensions: ArrayLike,
    polarizations: ArrayLike,
    in_out: str = "auto",
) -> jnp.ndarray:
    """Polarization field J of a cylinder segment magnet.

    Returns the polarization at observers inside the segment and zero
    elsewhere. Inside means ``r1 <= r <= r2``, ``|z| <= h/2`` and azimuth
    within ``[phi1, phi2]`` (degrees). ``in_out="inside"`` / ``"outside"``
    force the classification; any other value uses this geometric test.

    The azimuth is measured counter-clockwise from ``phi1`` modulo 360 degrees,
    and spans of 360 degrees or more count as a full turn. This correctly
    handles:

    * sections starting at negative angles;
    * sections extending past 360 degrees (e.g. ``phi1=200, phi2=400``);
    * full turns such as ``phi1=-30, phi2=330``.
    """
    obs = ensure_observers(observers)
    dim = _ensure_dim5(dimensions, obs.shape[0])
    pol = _broadcast_vec3(jnp.asarray(polarizations, dtype=jnp.float64), obs.shape[0])
    r1, r2, h, phi1_deg, phi2_deg = dim
    x, y, z = obs.T
    r = jnp.sqrt(x * x + y * y)
    # Angular extent in degrees, so integer-degree full turns are detected exactly.
    span_deg = phi2_deg - phi1_deg
    full_turn = span_deg >= 360.0
    span = jnp.deg2rad(jnp.mod(span_deg, 360.0))
    # Observer azimuth relative to phi1, folded into [0, 2*pi).
    rel_phi = jnp.mod(jnp.arctan2(y, x) - jnp.deg2rad(phi1_deg), 2.0 * jnp.pi)
    in_phi = full_turn | (rel_phi <= span)
    inside_geom = (r >= r1) & (r <= r2) & (jnp.abs(z) <= h / 2.0) & in_phi
    if in_out == "inside":
        inside = jnp.ones_like(inside_geom)
    elif in_out == "outside":
        inside = jnp.zeros_like(inside_geom)
    else:
        inside = inside_geom
    return jnp.where(inside[:, None], pol, 0.0)
def magnet_cylinder_segment_mfield(
    observers: ArrayLike,
    dimensions: ArrayLike,
    polarizations: ArrayLike,
    in_out: str = "auto",
) -> jnp.ndarray:
    """Magnetization of a cylinder segment magnet, ``M = J / MU0``."""
    j_field = magnet_cylinder_segment_jfield(observers, dimensions, polarizations, in_out=in_out)
    return j_field / MU0
# 7-point, degree-5 quadrature rule on a triangle (Radon; tabulated by Dunavant).
# Weights are normalized to sum to 1. Rows of _TRI_Q_L are the barycentric
# coordinates of the nodes: the centroid, then the orbit of (a1, b1, b1), then
# the orbit of (a2, b2, b2), each orbit listing the three cyclic placements of
# the distinct value.
_TRI_Q_W = jnp.asarray(
    [0.2250000000000000] + [0.1323941527885062] * 3 + [0.1259391805448272] * 3,
    dtype=jnp.float64,
)
_TRI_Q_L = jnp.asarray(
    [
        [1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0],
        *(
            [distinct if col == row else repeated for col in range(3)]
            for distinct, repeated in (
                (0.059715871789770, 0.470142064105115),
                (0.797426985353087, 0.101286507323456),
            )
            for row in range(3)
        ),
    ],
    dtype=jnp.float64,
)
def _triangle_barycentric_mask(
    points: jnp.ndarray,
    tri: jnp.ndarray,
    normal: jnp.ndarray,
) -> jnp.ndarray:
    """Flag points lying on a triangle, using an absolute tolerance of 1e-10.

    A point qualifies when both hold:

    * its offset from the first vertex, projected onto ``normal``, is below
      the tolerance;
    * all three barycentric coordinates of its in-plane projection are
      ``>= -tolerance``.

    ``normal`` is used as given, so the plane test is a true Euclidean distance
    only for a unit normal.

    Returns:
        Boolean array with one entry per point.
    """
    tol = 1e-10
    origin, corner_b, corner_c = tri
    edge_b = corner_b - origin
    edge_c = corner_c - origin
    offset = points - origin[None, :]
    bb = jnp.dot(edge_b, edge_b)
    bc = jnp.dot(edge_b, edge_c)
    cc = jnp.dot(edge_c, edge_c)
    ob = jnp.sum(offset * edge_b[None, :], axis=1)
    oc = jnp.sum(offset * edge_c[None, :], axis=1)
    # Gram determinant, floored to avoid division by zero for degenerate triangles.
    gram = jnp.maximum(bb * cc - bc * bc, 1e-30)
    v = (cc * ob - bc * oc) / gram
    w = (bb * oc - bc * ob) / gram
    u = 1.0 - v - w
    height = jnp.abs(jnp.sum(offset * normal[None, :], axis=1))
    barycentric = jnp.stack((u, v, w), axis=1)
    return (height < tol) & jnp.all(barycentric >= -tol, axis=1)
def _rot_x(theta: jnp.ndarray) -> jnp.ndarray:
    """3x3 rotation matrix for an angle ``theta`` (radians) about the x axis."""
    cos_t = jnp.cos(theta)
    sin_t = jnp.sin(theta)
    row_x = [1.0, 0.0, 0.0]
    row_y = [0.0, cos_t, -sin_t]
    row_z = [0.0, sin_t, cos_t]
    return jnp.asarray([row_x, row_y, row_z], dtype=jnp.float64)
def _rot_z(alpha: jnp.ndarray) -> jnp.ndarray:
    """3x3 rotation matrix for an angle ``alpha`` (radians) about the z axis."""
    cos_a = jnp.cos(alpha)
    sin_a = jnp.sin(alpha)
    row_x = [cos_a, -sin_a, 0.0]
    row_y = [sin_a, cos_a, 0.0]
    row_z = [0.0, 0.0, 1.0]
    return jnp.asarray([row_x, row_y, row_z], dtype=jnp.float64)
def _triangle_coordinate_transform(
    tri: jnp.ndarray,
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Map a triangle into the local frame of the elementary current-sheet kernel.

    In the local frame:

    * the first vertex is at the origin;
    * the second vertex is on the positive x axis;
    * the third vertex is in the x-y plane with non-negative y.

    Returns:
        ``(coords, translation, rotation)``:

        * ``coords = (u1, u2, v2)``: the local x coordinate of the second
          vertex and the local (x, y) of the third;
        * ``translation``: the first vertex;
        * ``rotation``: the 3x3 matrix taking global offsets
          ``p - translation`` into the local frame.
    """
    origin, second, third = tri
    edge_b = second - origin
    edge_c = third - origin

    # Step 1: rotate about x so the second vertex has zero z.
    rot_1 = _rot_x(-jnp.arctan2(edge_b[2], edge_b[1]))
    edge_b = rot_1 @ edge_b
    edge_c = rot_1 @ edge_c

    # Step 2: rotate about z so the second vertex lies on the x axis.
    rot_2 = _rot_z(-jnp.arctan2(edge_b[1], edge_b[0]))
    edge_b = rot_2 @ edge_b
    edge_c = rot_2 @ edge_c

    # Step 3: rotate about x (leaves the second vertex fixed) so the third vertex has zero z.
    rot_3 = _rot_x(-jnp.arctan2(edge_c[2], edge_c[1]))
    edge_c = rot_3 @ edge_c

    rotation = rot_3 @ rot_2 @ rot_1
    coords = jnp.asarray([edge_b[0], edge_c[0], edge_c[1]], dtype=jnp.float64)
    return coords, origin, rotation
def _safe_sqrt(x: jnp.ndarray) -> jnp.ndarray:
    """Square root with negative arguments (e.g. round-off below zero) clamped to 0."""
    nonnegative = jnp.maximum(x, 0.0)
    return jnp.sqrt(nonnegative)
def _safe_atanh(x: jnp.ndarray) -> jnp.ndarray:
    """``arctanh`` with its argument clipped to ``[-1 + 1e-15, 1 - 1e-15]``.

    In float64 this keeps the result finite for arguments at or beyond +-1.
    NOTE: in float32 the 1e-15 margin is below resolution, so +-1 still maps to +-inf.
    """
    margin = 1e-15
    lower = -1.0 + margin
    upper = 1.0 - margin
    return jnp.arctanh(jnp.clip(x, lower, upper))
def _safe_logabs(x: jnp.ndarray) -> jnp.ndarray:
    """``log(|x|)`` with ``|x|`` floored at 1e-30, so zero maps to a finite value."""
    magnitude = jnp.abs(x)
    floored = jnp.maximum(magnitude, 1e-30)
    return jnp.log(floored)
def _elementar_current_sheet_hfield(
    observers: jnp.ndarray,
    coordinates: jnp.ndarray,
    current_densities: jnp.ndarray,
) -> jnp.ndarray:
    """H-field of an elementary triangular current sheet in its local coordinates.

    The sheet is the triangle with vertices ``(0, 0, 0)``, ``(u1, 0, 0)`` and
    ``(u2, v2, 0)`` in the local frame (the edge tests below are the lines
    through these points). It carries the in-plane surface current density
    ``(ju, jv)``.

    Args:
        observers: observer positions in the local frame, unpacked as
            ``x, y, z = observers.T`` (so shape ``(n, 3)``).
        coordinates: triangle coordinates ``(u1, u2, v2)``.
        current_densities: in-plane current density components ``(ju, jv)``.

    Returns:
        Array of shape ``(n, 3)`` with the field in the local frame. In-plane
        observers inside or on the triangle (within ``num_tol``) get zero.

    Every branch is evaluated for every observer and then selected with
    ``jnp.where``. Non-finite values in unselected branches (e.g. the division
    by ``z`` in ``hx_general`` for in-plane observers) are discarded.
    NOTE(review): gradients through ``jnp.where`` can still pick up NaNs from
    the discarded branches.
    """
    # Absolute tolerance for the "in the sheet plane" and "on an edge line" tests.
    num_tol = 1e-10
    x, y, z = observers.T
    u1, u2, v2 = coordinates
    ju, jv = current_densities
    in_plane = jnp.abs(z) < num_tol
    # Coefficients of (x, y) in the basis (u1, 0), (u2, v2). The projection lies
    # in the triangle iff both are >= 0 and their sum is <= 1.
    critical_value01 = (x * v2 - y * u2) / (u1 * v2)
    critical_value02 = y / v2
    # Distances (up to a constant factor) to the three lines carrying the edges:
    # 1: y = 0; 2: through (0, 0) and (u2, v2); 3: through (u1, 0) and (u2, v2).
    critical_value1 = jnp.abs(y)
    critical_value2 = jnp.abs(u2 * y - v2 * x)
    critical_value3 = jnp.abs(v2 * (x - u1) + y * (u1 - u2))
    # mask0: in plane and inside/on the triangle -> field is left at zero.
    mask0 = (
        in_plane
        & (critical_value01 + critical_value02 <= 1.0 + num_tol)
        & (critical_value01 >= -num_tol)
        & (critical_value02 >= -num_tol)
    )
    # mask1..mask3: in plane, outside the triangle, on the extension of edge 1/2/3.
    mask1 = in_plane & (critical_value1 < num_tol) & (~mask0)
    mask2 = in_plane & (critical_value2 < num_tol) & (~mask0)
    mask3 = in_plane & (critical_value3 < num_tol) & (~mask0)
    # mask_plane: in plane, off the triangle and off all three edge lines.
    mask_plane = ~(mask0 | mask1 | mask2 | mask3) & in_plane
    # mask_general: observers away from the sheet plane.
    mask_general = ~in_plane
    # Distances from the observer to the three vertices (sqrt1..sqrt3), and the
    # lengths of the edges (u1,0)-(u2,v2) (sqrt4) and (0,0)-(u2,v2) (sqrt5).
    sqrt1 = _safe_sqrt(x**2 + y**2 + z**2)
    sqrt2 = _safe_sqrt(u1**2 - 2 * u1 * x + x**2 + y**2 + z**2)
    sqrt3 = _safe_sqrt(u2**2 - 2 * u2 * x + v2**2 - 2 * v2 * y + x**2 + y**2 + z**2)
    sqrt4 = _safe_sqrt(u1**2 - 2 * u1 * u2 + u2**2 + v2**2)
    sqrt5 = _safe_sqrt(u2**2 + v2**2)
    # Off-plane solution: shared factor of the x and y components (see assembly below).
    hx_general = (
        jnp.arctan((-u2 * (y**2 + z**2) + v2 * x * y) / (v2 * z * sqrt1))
        + jnp.arctan((v2 * y * (u1 - x) - (u1 - u2) * (y**2 + z**2)) / (v2 * z * sqrt2))
        - jnp.arctan((-u2 * (y**2 + z**2) - v2**2 * x + v2 * y * (u2 + x)) / (v2 * z * sqrt3))
        - jnp.arctan(
            (
                -u1 * (v2**2 - 2 * v2 * y + y**2 + z**2)
                + u2 * (y**2 + z**2)
                + v2**2 * x
                - v2 * y * (u2 + x)
            )
            / (v2 * z * sqrt3)
        )
    ) / (u1 * v2 * z)
    # Off-plane solution: z component.
    hz_general = -(
        ju * _safe_atanh(x / sqrt1)
        + ju * _safe_atanh((u1 - x) / sqrt2)
        - (ju * (u1 - u2) - jv * v2)
        * _safe_atanh((u1**2 - u1 * (u2 + x) + u2 * x + v2 * y) / (sqrt4 * sqrt2))
        / sqrt4
        + (ju * (u1 - u2) - jv * v2)
        * _safe_atanh((u1 * (u2 - x) - u2**2 + u2 * x + v2 * (-v2 + y)) / (sqrt4 * sqrt3))
        / sqrt4
        + (ju * u2 + jv * v2) * _safe_atanh((-u2 * x - v2 * y) / (sqrt5 * sqrt1)) / sqrt5
        - (ju * u2 + jv * v2)
        * _safe_atanh((u2**2 - u2 * x + v2 * (v2 - y)) / (sqrt5 * sqrt3))
        / sqrt5
    ) / (u1 * v2)
    # In-plane counterparts of sqrt1..sqrt5 (same expressions with z = 0).
    sqrt_xy = _safe_sqrt(x**2 + y**2)
    sqrt_u1 = _safe_sqrt(u1**2 - 2 * u1 * x + x**2 + y**2)
    sqrt_u2 = _safe_sqrt(u2**2 - 2 * u2 * x + v2**2 - 2 * v2 * y + x**2 + y**2)
    sqrt_u12 = _safe_sqrt(u1**2 - 2 * u1 * u2 + u2**2 + v2**2)
    sqrt_u2v2 = _safe_sqrt(u2**2 + v2**2)
    # z component for in-plane observers away from the triangle and its edge lines.
    hz_plane = -(
        ju * _safe_atanh(x / sqrt_xy)
        + ju * _safe_atanh((u1 - x) / sqrt_u1)
        - (ju * (u1 - u2) - jv * v2)
        * _safe_atanh((u1**2 - u1 * (u2 + x) + u2 * x + v2 * y) / (sqrt_u12 * sqrt_u1))
        / sqrt_u12
        + (ju * (u1 - u2) - jv * v2)
        * _safe_atanh((u1 * (u2 - x) - u2**2 + u2 * x + v2 * (-v2 + y)) / (sqrt_u12 * sqrt_u2))
        / sqrt_u12
        + (ju * u2 + jv * v2) * _safe_atanh((-u2 * x - v2 * y) / (sqrt_u2v2 * sqrt_xy)) / sqrt_u2v2
        - (ju * u2 + jv * v2)
        * _safe_atanh((u2**2 - u2 * x + v2 * (v2 - y)) / (sqrt_u2v2 * sqrt_u2))
        / sqrt_u2v2
    ) / (u1 * v2)
    # z component for in-plane observers on the extension of edge 1 (y == 0).
    hz_edge1 = (
        -ju * x * _safe_logabs(x) / _safe_sqrt(x**2)
        - ju * (u1 - x) * _safe_logabs(-u1 + x) / _safe_sqrt((u1 - x) ** 2)
        + (ju * (u1 - u2) - jv * v2)
        * _safe_atanh(
            (u1 * (-u2 + x) + u2**2 - u2 * x + v2**2)
            / (sqrt_u12 * _safe_sqrt(u2**2 - 2 * u2 * x + v2**2 + x**2))
        )
        / sqrt_u12
        + (ju * (u1 - u2) - jv * v2)
        * _safe_atanh((u1 - u2) * (u1 - x) / (sqrt_u12 * _safe_sqrt((u1 - x) ** 2)))
        / sqrt_u12
        + (ju * u2 + jv * v2)
        * _safe_atanh(
            (u2**2 - u2 * x + v2**2) / (sqrt_u2v2 * _safe_sqrt(u2**2 - 2 * u2 * x + v2**2 + x**2))
        )
        / sqrt_u2v2
        - (ju * u2 + jv * v2)
        * _safe_atanh(u2 * (u1 - x) / (sqrt_u2v2 * _safe_sqrt((u1 - x) ** 2)))
        / sqrt_u2v2
    ) / (u1 * v2)
    # z component for in-plane observers on the extension of edge 2 (through (0,0) and (u2,v2)).
    hz_edge2 = (
        -ju
        * _safe_atanh(
            (u1 * v2 - u2 * y)
            / (v2 * _safe_sqrt(u1**2 - 2 * u1 * u2 * y / v2 + y**2 * (u2**2 / v2**2 + 1)))
        )
        + ju
        * _safe_atanh(u2 * (v2 - y) / (v2 * _safe_sqrt((u2**2 + v2**2) * (v2 - y) ** 2 / v2**2)))
        + (ju * (u1 - u2) - jv * v2)
        * _safe_atanh(
            (u1**2 * v2 - u1 * u2 * (v2 + y) + y * (u2**2 + v2**2))
            / (
                v2
                * _safe_sqrt(u1**2 - 2 * u1 * u2 * y / v2 + y**2 * (u2**2 / v2**2 + 1))
                * sqrt_u12
            )
        )
        / sqrt_u12
        + (ju * (u1 - u2) - jv * v2)
        * _safe_atanh(
            (v2 - y)
            * (-u1 * u2 + u2**2 + v2**2)
            / (v2 * _safe_sqrt((u2**2 + v2**2) * (v2 - y) ** 2 / v2**2) * sqrt_u12)
        )
        / sqrt_u12
        + y
        * (ju * u2 + jv * v2)
        * _safe_logabs(y * (-(u2**2) - v2**2))
        / (v2 * _safe_sqrt(y**2 * (u2**2 + v2**2) / v2**2))
        + (v2 - y)
        * (ju * u2 + jv * v2)
        * _safe_logabs((u2**2 + v2**2) * (v2 - y))
        / (v2 * _safe_sqrt((u2**2 + v2**2) * (v2 - y) ** 2 / v2**2))
    ) / (u1 * v2)
    # z component for in-plane observers on the extension of edge 3
    # (through (u1,0) and (u2,v2)). Note the prefactor is 1/(u1*v2**2) here.
    hz_edge3 = (
        ju
        * v2
        * _safe_atanh(
            (u1 * (-v2 + y) - u2 * y)
            / (
                v2
                * _safe_sqrt(
                    (u1**2 * (v2 - y) ** 2 + 2 * u1 * u2 * y * (v2 - y) + y**2 * (u2**2 + v2**2))
                    / v2**2
                )
            )
        )
        + ju
        * v2
        * _safe_atanh(
            (u1 - u2)
            * (v2 - y)
            / (v2 * _safe_sqrt((v2 - y) ** 2 * (u1**2 - 2 * u1 * u2 + u2**2 + v2**2) / v2**2))
        )
        - v2
        * (ju * u2 + jv * v2)
        * _safe_atanh(
            (u1 * u2 * (-v2 + y) + y * (-(u2**2) - v2**2))
            / (
                v2
                * _safe_sqrt(
                    (u1**2 * (v2 - y) ** 2 + 2 * u1 * u2 * y * (v2 - y) + y**2 * (u2**2 + v2**2))
                    / v2**2
                )
                * sqrt_u2v2
            )
        )
        / sqrt_u2v2
        + v2
        * (ju * u2 + jv * v2)
        * _safe_atanh(
            (v2 - y)
            * (-u1 * u2 + u2**2 + v2**2)
            / (
                v2
                * _safe_sqrt((v2 - y) ** 2 * (u1**2 - 2 * u1 * u2 + u2**2 + v2**2) / v2**2)
                * sqrt_u2v2
            )
        )
        / sqrt_u2v2
        - y
        * (ju * (-u1 + u2) + jv * v2)
        * _safe_logabs(y * (-(u1**2) + 2 * u1 * u2 - u2**2 - v2**2))
        / _safe_sqrt(y**2 * (u1**2 - 2 * u1 * u2 + u2**2 + v2**2) / v2**2)
        - (v2 - y)
        * (ju * (-u1 + u2) + jv * v2)
        * _safe_logabs((v2 - y) * (u1**2 - 2 * u1 * u2 + u2**2 + v2**2))
        / _safe_sqrt((v2 - y) ** 2 * (u1**2 - 2 * u1 * u2 + u2**2 + v2**2) / v2**2)
    ) / (u1 * v2**2)
    # Assemble per observer. Later `where`s take precedence; mask0 observers
    # match none of them and keep hx = hz = 0.
    hx = jnp.where(mask_general, hx_general, 0.0)
    hz = jnp.where(mask_general, hz_general, 0.0)
    hz = jnp.where(mask_plane, hz_plane, hz)
    hz = jnp.where(mask1, hz_edge1, hz)
    hz = jnp.where(mask2, hz_edge2, hz)
    hz = jnp.where(mask3, hz_edge3, hz)
    # Common scale factor u1*v2 / (4*pi) applied to all components.
    scale = (u1 * v2) / _FOUR_PI
    # hx is shared: H_x uses jv * z * hx and H_y uses -ju * z * hx (both zero in plane).
    hx_scaled = hx * jv * z * scale
    hy_scaled = hx * (-ju) * z * scale
    hz_scaled = hz * scale
    return jnp.stack((hx_scaled, hy_scaled, hz_scaled), axis=1)
def _current_triangle_sheet_hfield_obs(
    obs: jnp.ndarray,
    tri: jnp.ndarray,
    cd: jnp.ndarray,
) -> jnp.ndarray:
    """H-field of one triangular current sheet at all observers, in the global frame.

    Observers and the current density are moved into the triangle's local
    frame, where the elementary sheet kernel is evaluated; the result is
    rotated back.

    A triangle contributes zero if any local coordinate is NaN, or if
    ``|u1|`` or ``|v2|`` is below 1e-15 (zero-length first edge or zero
    height).
    """
    coords, translation, rotation = _triangle_coordinate_transform(tri)
    u1, u2, v2 = coords
    is_degenerate = (
        jnp.isnan(u1)
        | jnp.isnan(u2)
        | jnp.isnan(v2)
        | (jnp.abs(u1) < 1e-15)
        | (jnp.abs(v2) < 1e-15)
    )
    local_obs = (obs - translation[None, :]) @ rotation.T
    # Only the in-plane (local x, y) components of the current density are used.
    local_cd = (rotation @ cd)[:2]
    h_local = jnp.where(
        is_degenerate,
        0.0,
        _elementar_current_sheet_hfield(local_obs, coords, local_cd),
    )
    # `rotation` is a product of rotation matrices; right-multiplying applies
    # its transpose, mapping local vectors back to the global frame.
    return h_local @ rotation
def current_triangle_sheet_hfield(
    observers: ArrayLike,
    vertices: ArrayLike,
    current_densities: ArrayLike,
) -> jnp.ndarray:
    """H-field of a single triangular current sheet.

    Args:
        observers: observer positions, normalized by ``ensure_observers``.
        vertices: the triangle's three vertices, shape ``(3, 3)``.
        current_densities: surface current density vector, shape ``(3,)``.

    Returns:
        H-field at the observers, shape ``(n_obs, 3)``.

    Raises:
        ValueError: if ``vertices`` or ``current_densities`` has the wrong shape.
    """
    obs = ensure_observers(observers)
    triangle = jnp.asarray(vertices, dtype=jnp.float64)
    if triangle.shape != (3, 3):
        raise ValueError(f"Triangle sheet vertices must have shape (3,3), got {triangle.shape}.")
    density = jnp.asarray(current_densities, dtype=jnp.float64)
    if density.shape != (3,):
        raise ValueError(
            f"Triangle sheet current density must have shape (3,), got {density.shape}."
        )
    return _current_triangle_sheet_hfield_obs(obs, triangle, density)
def current_trisheet_hfield(
    observers: ArrayLike,
    vertices: ArrayLike,
    faces: ArrayLike,
    current_densities: ArrayLike,
) -> jnp.ndarray:
    """H-field of a triangle-sheet surface given as vertices + face indices.

    Args:
        observers: observer positions, normalized by ``ensure_observers``.
        vertices: vertex coordinates, indexed by ``faces``.
        faces: integer vertex indices; ``vertices[faces]`` must have shape ``(n, 3, 3)``.
        current_densities: per-face current density vectors, shape ``(n, 3)``.

    Returns:
        Sum of the per-face H-fields, shape ``(n_obs, 3)``.

    Raises:
        ValueError: on inconsistent shapes.

    NOTE(review): JAX clamps out-of-range indices when gathering, so invalid
    face indices are not detected here.
    """
    obs = ensure_observers(observers)
    verts = jnp.asarray(vertices, dtype=jnp.float64)
    face_idx = jnp.asarray(faces, dtype=jnp.int32)
    densities = jnp.asarray(current_densities, dtype=jnp.float64)
    triangles = verts[face_idx]
    if not (triangles.ndim == 3 and triangles.shape[1:] == (3, 3)):
        raise ValueError(
            "TriangleSheet requires faces indexing into vertices yielding shape (n,3,3)."
        )
    if not (densities.ndim == 2 and densities.shape[1] == 3):
        raise ValueError("TriangleSheet current_densities must have shape (n,3).")
    if densities.shape[0] != triangles.shape[0]:
        raise ValueError("TriangleSheet current_densities and faces length mismatch.")
    per_face = jax.vmap(_current_triangle_sheet_hfield_obs, in_axes=(None, 0, 0))(
        obs, triangles, densities
    )
    return jnp.sum(per_face, axis=0)
def current_trisheet_bfield(
    observers: ArrayLike,
    vertices: ArrayLike,
    faces: ArrayLike,
    current_densities: ArrayLike,
) -> jnp.ndarray:
    """B-field of a triangle-sheet surface, ``B = MU0 * H``."""
    h_field = current_trisheet_hfield(observers, vertices, faces, current_densities)
    return MU0 * h_field
[docs]
def current_trisheet_bfield_masked(
    observers: ArrayLike,
    triangles: ArrayLike,
    current_densities: ArrayLike,
    face_mask: ArrayLike,
) -> jnp.ndarray:
    """B-field of triangle sheet with face masking.

    Args:
        observers: observer positions, normalized by ``ensure_observers``.
        triangles: per-face vertex triples, presumably shape ``(n_faces, 3, 3)``.
        current_densities: per-face current density vectors.
        face_mask: per-face values. Zero entries drop the face entirely;
            nonzero entries scale its contribution (``1``/``True`` keeps it
            unchanged).

    Returns:
        ``MU0 * sum_f mask_f * H_f`` with shape ``(n_obs, 3)``.
    """
    obs = ensure_observers(observers)
    tris = jnp.asarray(triangles, dtype=jnp.float64)
    cds = jnp.asarray(current_densities, dtype=jnp.float64)
    mask = jnp.asarray(face_mask, dtype=jnp.float64).reshape((-1,))
    h_faces = jax.vmap(lambda tri, cd: _current_triangle_sheet_hfield_obs(obs, tri, cd))(tris, cds)
    weight = mask[:, None, None]
    # Select rather than only multiply: 0 * NaN/inf is NaN, so a masked-out
    # (e.g. padding) face producing non-finite values would otherwise poison
    # the whole sum. Nonzero weights still scale the contribution as before.
    h_faces = jnp.where(weight != 0.0, h_faces * weight, 0.0)
    return MU0 * jnp.sum(h_faces, axis=0)
[docs]
def current_trisheet_bfield_jit(
    observers: ArrayLike,
    vertices: ArrayLike,
    faces: ArrayLike,
    current_densities: ArrayLike,
) -> jnp.ndarray:
    """JIT-specialized triangle sheet B-field for fixed observer counts.

    Runs ``current_trisheet_bfield`` through a kernel obtained from
    ``_jit_kernel_simple`` for the observer count.
    """
    obs = ensure_observers(observers)
    operands = (
        jnp.asarray(vertices, dtype=jnp.float64),
        jnp.asarray(faces, dtype=jnp.int32),
        jnp.asarray(current_densities, dtype=jnp.float64),
    )
    kernel = _jit_kernel_simple("trianglesheet_bfield", current_trisheet_bfield, obs.shape[0])
    return kernel(obs, *operands)
def _strip_triangles(vertices: jnp.ndarray) -> jnp.ndarray:
    """Triangles ``(v[i], v[i+1], v[i+2])`` of a triangle strip, stacked on axis 1.

    For vertices of shape ``(n, 3)`` the result has shape ``(n - 2, 3, 3)``.
    """
    first = vertices[:-2]
    middle = vertices[1:-1]
    last = vertices[2:]
    return jnp.stack([first, middle, last], axis=1)
def _strip_current_densities(vertices: jnp.ndarray, current: jnp.ndarray) -> jnp.ndarray:
    """Per-triangle surface current density for a triangle strip carrying ``current``.

    For triangle ``(v[i], v[i+1], v[i+2])``:

    * the density points along ``v[i+2] - v[i]``;
    * its magnitude is ``current / width``, where ``width`` is the distance of
      ``v[i+1]`` from the line through ``v[i]`` and ``v[i+2]``.

    Triangles with a squared edge length or width not above 1e-15 get a zero
    vector.

    Returns:
        Array of shape ``(n - 2, 3)``.
    """
    base = vertices[:-2]
    side = vertices[1:-1] - base  # edge to the middle vertex
    axis = vertices[2:] - base  # edge along the strip direction
    side_sq = jnp.sum(side * side, axis=1)
    axis_sq = jnp.sum(axis * axis, axis=1)
    cross_dot = jnp.sum(side * axis, axis=1)
    # Perpendicular distance of the middle vertex from the strip-direction edge.
    width = jnp.sqrt(
        jnp.maximum(side_sq - (cross_dot * cross_dot) / jnp.maximum(axis_sq, 1e-30), 0.0)
    )
    ok = (axis_sq > 1e-15) & (side_sq > 1e-15) & (width > 1e-15)
    magnitude = jnp.where(ok, current / (jnp.sqrt(jnp.maximum(axis_sq, 1e-30)) * width), 0.0)
    return jnp.where(ok[:, None], axis * magnitude[:, None], 0.0)
def current_tristrip_hfield(
    observers: ArrayLike,
    vertices: ArrayLike,
    current: ArrayLike,
) -> jnp.ndarray:
    """H-field of a triangle strip carrying a scalar ``current``.

    The strip's triangles are ``(v[i], v[i+1], v[i+2])``. Their current
    densities come from ``_strip_current_densities``.

    Returns:
        Sum of the per-triangle H-fields, shape ``(n_obs, 3)``.

    Raises:
        ValueError: if ``vertices`` is not of shape ``(n, 3)`` with ``n >= 3``.
    """
    obs = ensure_observers(observers)
    verts = jnp.asarray(vertices, dtype=jnp.float64)
    shape_ok = verts.ndim == 2 and verts.shape[1] == 3 and verts.shape[0] >= 3
    if not shape_ok:
        raise ValueError("TriangleStrip vertices must have shape (n>=3,3).")
    scalar_current = jnp.asarray(current, dtype=jnp.float64).reshape(())
    triangles = _strip_triangles(verts)
    densities = _strip_current_densities(verts, scalar_current)
    per_face = jax.vmap(_current_triangle_sheet_hfield_obs, in_axes=(None, 0, 0))(
        obs, triangles, densities
    )
    return jnp.sum(per_face, axis=0)
def current_tristrip_bfield(
    observers: ArrayLike,
    vertices: ArrayLike,
    current: ArrayLike,
) -> jnp.ndarray:
    """B-field of a triangle strip carrying a scalar current, ``B = MU0 * H``."""
    h_field = current_tristrip_hfield(observers, vertices, current)
    return MU0 * h_field
[docs]
def current_tristrip_bfield_jit(
    observers: ArrayLike,
    vertices: ArrayLike,
    current: ArrayLike,
) -> jnp.ndarray:
    """JIT-specialized triangle strip B-field for fixed observer counts.

    Runs ``current_tristrip_bfield`` through a kernel obtained from
    ``_jit_kernel_simple`` for the observer count.
    """
    obs = ensure_observers(observers)
    operands = (
        jnp.asarray(vertices, dtype=jnp.float64),
        jnp.asarray(current, dtype=jnp.float64),
    )
    kernel = _jit_kernel_simple("trianglestrip_bfield", current_tristrip_bfield, obs.shape[0])
    return kernel(obs, *operands)