"""Geometry helpers for frame transforms and coordinate conversions."""
from __future__ import annotations
import jax.numpy as jnp
from magpylib_jax._types import ArrayLike
def _as_array3(x: ArrayLike) -> jnp.ndarray:
    """Convert ``x`` to an array whose trailing axis has length 3.

    The array is created with ``dtype=jnp.float64`` requested; the shape of
    the input is otherwise left unchanged.

    Raises
    ------
    ValueError
        If the input is 0-dimensional or its last axis is not of length 3.
    """
    arr = jnp.asarray(x, dtype=jnp.float64)
    # A 0-d array has an empty shape tuple, so ``arr.shape[-1]`` would raise
    # IndexError; reject scalars with the same ValueError as any other
    # malformed input.
    if arr.ndim == 0 or arr.shape[-1] != 3:
        raise ValueError(f"Expected trailing dimension 3, got shape {arr.shape}.")
    return arr
[docs]
def ensure_observers(observers: ArrayLike) -> jnp.ndarray:
    """Return observers as a rank-2 array of shape (n, 3).

    A single point of shape (3,) gains a leading axis; any other input is
    flattened so that only the trailing axis of length 3 is kept.
    """
    points = _as_array3(observers)
    return points[None, :] if points.ndim == 1 else points.reshape((-1, 3))
[docs]
def normalize_orientation(orientation: ArrayLike | None) -> jnp.ndarray:
    """Return a single 3x3 rotation matrix.

    Parameters
    ----------
    orientation
        ``None`` (identity), an object exposing ``as_matrix()``, or an
        array-like of shape (3, 3) or (1, 3, 3).

    Returns
    -------
    jnp.ndarray
        Array of shape (3, 3), created with ``dtype=jnp.float64`` requested.

    Raises
    ------
    ValueError
        If more than one orientation is supplied or the matrix is not 3x3.
    """
    if orientation is None:
        return jnp.eye(3, dtype=jnp.float64)

    from_rotation = hasattr(orientation, "as_matrix")
    raw = orientation.as_matrix() if from_rotation else orientation
    mat = jnp.asarray(raw, dtype=jnp.float64)

    if mat.ndim == 3:
        if from_rotation and mat.shape[0] != 1:
            raise ValueError(
                "Expected single orientation for this context, "
                f"got {mat.shape[0]} orientations."
            )
        # Validate the trailing (3, 3) shape for both input kinds; without
        # this, a (1, a, b) stack from ``as_matrix()`` was returned unchecked.
        if mat.shape[0] != 1 or mat.shape[1:] != (3, 3):
            raise ValueError(
                f"Expected single orientation matrix with shape (3, 3), got {mat.shape}."
            )
        return mat[0]

    if mat.shape != (3, 3):
        raise ValueError(f"Expected orientation matrix with shape (3, 3), got {mat.shape}.")
    return mat
[docs]
def normalize_positions(position: ArrayLike = (0.0, 0.0, 0.0)) -> jnp.ndarray:
    """Return positions as a rank-2 array of shape (p, 3).

    A single position of shape (3,) becomes a path of length one; a rank-2
    input is passed through. Any higher rank raises ``ValueError``.
    """
    path = _as_array3(position)
    rank = path.ndim
    if rank == 2:
        return path
    if rank == 1:
        return path[None, :]
    raise ValueError(f"Expected position shape (3,) or (p,3), got {path.shape}.")
[docs]
def normalize_orientations(orientation: ArrayLike | None = None) -> jnp.ndarray:
    """Return orientations as a stack of shape (p, 3, 3).

    ``None`` yields a single identity matrix. Objects exposing
    ``as_matrix()`` are converted through that method; anything else is
    treated as array-like. A lone (3, 3) matrix gains a leading axis.
    Any other shape raises ``ValueError``.
    """
    if orientation is None:
        return jnp.eye(3, dtype=jnp.float64)[None, :, :]

    source = orientation.as_matrix() if hasattr(orientation, "as_matrix") else orientation
    stack = jnp.asarray(source, dtype=jnp.float64)

    # Already a well-formed stack of matrices: nothing to do.
    if stack.ndim == 3 and stack.shape[1:] == (3, 3):
        return stack
    if stack.ndim != 2:
        raise ValueError(
            f"Expected orientation shape (3,3), (p,3,3), or scipy Rotation; got {stack.shape}."
        )
    if stack.shape != (3, 3):
        raise ValueError(f"Expected orientation matrix with shape (3, 3), got {stack.shape}.")
    return stack[None, :, :]
[docs]
def broadcast_pose(
    *,
    position: ArrayLike = (0.0, 0.0, 0.0),
    orientation: ArrayLike | None = None,
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Bring position and orientation paths to a common length.

    Both inputs are first normalized to shapes (p, 3) and (p, 3, 3). A path
    of length one is expanded to match the longer one; two paths of
    different lengths greater than one cannot be combined and raise
    ``ValueError``.
    """
    positions = normalize_positions(position)
    rotations = normalize_orientations(orientation)

    len_pos = positions.shape[0]
    len_rot = rotations.shape[0]
    target = max(len_pos, len_rot)

    if len_pos != 1 and len_pos != target:
        raise ValueError(
            f"Incompatible position path length {len_pos} for broadcast length {target}."
        )
    if len_rot != 1 and len_rot != target:
        raise ValueError(
            f"Incompatible orientation path length {len_rot} for broadcast length {target}."
        )

    # Only singleton paths need expanding, and only when the target is longer.
    if target > 1:
        if len_pos == 1:
            positions = jnp.broadcast_to(positions, (target, 3))
        if len_rot == 1:
            rotations = jnp.broadcast_to(rotations, (target, 3, 3))
    return positions, rotations
[docs]
def to_local_coordinates(
    observers: ArrayLike,
    *,
    position: ArrayLike = (0.0, 0.0, 0.0),
    orientation: ArrayLike | None = None,
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Express global observer points in the source-local frame.

    The observers are normalized to shape (n, 3), shifted by ``position``
    and right-multiplied by the 3x3 rotation matrix. Returns the local
    observer coordinates together with the rotation matrix that was used.

    Raises ``ValueError`` when ``position`` is not a single (3,) vector.
    """
    points = ensure_observers(observers)
    origin = _as_array3(position)
    if origin.ndim != 1:
        raise ValueError(f"Expected position shape (3,), got {origin.shape}.")
    rotation = normalize_orientation(orientation)
    shifted = points - origin
    return shifted @ rotation, rotation
[docs]
def to_global_field(field_local: jnp.ndarray, rotation_matrix: jnp.ndarray) -> jnp.ndarray:
    """Rotate local-frame vectors back into global coordinates.

    The vectors in ``field_local`` are right-multiplied by the transpose of
    ``rotation_matrix``.
    """
    rotation_transposed = rotation_matrix.T
    field_global = field_local @ rotation_transposed
    return field_global
[docs]
def cart_to_cyl(observers: jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Convert Cartesian coordinates to cylindrical coordinates.

    Parameters
    ----------
    observers
        Array of shape (..., 3) holding ``(x, y, z)`` on the trailing axis.

    Returns
    -------
    tuple of jnp.ndarray
        ``(r, phi, z)`` with ``r = sqrt(x**2 + y**2)`` and
        ``phi = arctan2(y, x)``; each has the leading shape of the input.

    Raises
    ------
    ValueError
        If the input is 0-dimensional or its trailing axis is not of length 3.
    """
    obs = jnp.asarray(observers)
    if obs.ndim == 0 or obs.shape[-1] != 3:
        raise ValueError(f"Expected trailing dimension 3, got shape {obs.shape}.")
    # Split on the trailing axis rather than unpacking ``.T``: ``.T`` reverses
    # every axis, which transposes the leading dimensions of inputs with
    # rank > 2.
    x, y, z = obs[..., 0], obs[..., 1], obs[..., 2]
    r = jnp.sqrt(x * x + y * y)
    phi = jnp.arctan2(y, x)
    return r, phi, z
[docs]
def cyl_field_to_cart(
    phi: jnp.ndarray,
    hr: jnp.ndarray,
    hphi_or_hz: jnp.ndarray,
    hz: jnp.ndarray | None = None,
) -> jnp.ndarray:
    """Convert cylindrical field components to Cartesian field vectors.

    Two call forms are supported for backward compatibility:

    - ``cyl_field_to_cart(phi, hr, hz)`` treats the azimuthal component as
      zero.
    - ``cyl_field_to_cart(phi, hr, hphi, hz)`` uses the full cylindrical
      vector.

    The Cartesian components are stacked along a new trailing axis in the
    order ``(hx, hy, hz)``.
    """
    if hz is None:
        # Three-argument form: the third positional argument is the axial part.
        azimuthal, axial = jnp.zeros_like(hr), hphi_or_hz
    else:
        azimuthal, axial = hphi_or_hz, hz

    c, s = jnp.cos(phi), jnp.sin(phi)
    components = (
        hr * c - azimuthal * s,
        hr * s + azimuthal * c,
        axial,
    )
    return jnp.stack(components, axis=-1)