Source code for magpylib_jax.magnet.triangular_mesh

"""Differentiable triangular-mesh magnet source."""

from __future__ import annotations

import warnings

import jax.numpy as jnp
import numpy as np

from magpylib_jax._types import ArrayLike
from magpylib_jax.constants import MU0
from magpylib_jax.core.base import BaseSource, MagpylibMissingInput
from magpylib_jax.functional import getB, getH, getJ, getM


[docs] class TriangularMesh(BaseSource): """Uniformly polarized magnet defined by mesh vertices and triangular faces.""" _source_type = "triangularmesh" def __init__( self, vertices: ArrayLike | None = None, faces: ArrayLike | None = None, polarization: ArrayLike | None = None, magnetization: ArrayLike | None = None, position: ArrayLike = (0.0, 0.0, 0.0), orientation: ArrayLike | None = None, reorient_faces: bool = True, check_open: bool | str = "warn", in_out: str = "auto", style=None, style_label: str | None = None, **kwargs, ) -> None: self.vertices = vertices self.faces = faces self.polarization = polarization self.magnetization = magnetization self.reorient_faces = reorient_faces self.check_open = check_open self.in_out = self._validate_in_out(in_out) self.meshing = None self.status_open: bool | None = None self.status_open_data: list[tuple[int, int]] | None = None self.status_disconnected: bool | None = None self.status_disconnected_data: list[object] | None = None self.status_reoriented: bool | None = bool(reorient_faces) self.status_selfintersecting: bool | None = None self.status_selfintersecting_data: object | None = None if self.vertices is not None and self.faces is not None: verts = jnp.asarray(self.vertices, dtype=jnp.float64) facs = jnp.asarray(self.faces, dtype=jnp.int32) if verts.ndim != 2 or verts.shape[1] != 3: raise ValueError("TriangularMesh `vertices` must have shape (n,3).") if facs.ndim != 2 or facs.shape[1] != 3: raise ValueError("TriangularMesh `faces` must have shape (m,3).") if jnp.any(facs < 0) or jnp.any(facs >= verts.shape[0]): raise ValueError("TriangularMesh `faces` contain indices outside `vertices`.") mode = self._validate_mode_arg(self.check_open, arg_name="check_open") if mode != "skip": open_edges = self._get_open_edges(np.asarray(self.vertices), np.asarray(self.faces)) self.status_open = len(open_edges) > 0 self.status_open_data = [tuple(e) for e in open_edges.tolist()] self.status_disconnected = False self.status_disconnected_data = [None] if self.status_open: msg = ( "Open mesh detected in TriangularMesh. Inside-outside checks and " "reorient_faces may yield unexpected results for open meshes. " "This check can be disabled at initialization with check_open='skip'." ) if mode == "warn": warnings.warn(msg, stacklevel=2) elif mode == "raise": raise ValueError(msg) super().__init__( position=position, orientation=orientation, style=style, style_label=style_label, **kwargs, ) self._mesh_cache_key: tuple[object, ...] | None = None self._faces_oriented_cache = jnp.zeros((0, 3), dtype=jnp.int32) self._mesh_cache = jnp.zeros((0, 3, 3), dtype=jnp.float64) def _geometry_cache_key(self) -> tuple[object, ...]: return ( id(self.vertices), id(self.faces), bool(self.reorient_faces), ) def _geometry_cache(self) -> tuple[jnp.ndarray, jnp.ndarray]: if self.vertices is None or self.faces is None: empty_faces = jnp.zeros((0, 3), dtype=jnp.int32) empty_mesh = jnp.zeros((0, 3, 3), dtype=jnp.float64) self._mesh_cache_key = None self._faces_oriented_cache = empty_faces self._mesh_cache = empty_mesh return empty_faces, empty_mesh cache_key = self._geometry_cache_key() if self._mesh_cache_key == cache_key: return self._faces_oriented_cache, self._mesh_cache faces = jnp.asarray(self.faces, dtype=jnp.int32) if self.reorient_faces: verts = jnp.asarray(self.vertices, dtype=jnp.float64) tri = verts[faces] center = jnp.mean(verts, axis=0) ctri = jnp.mean(tri, axis=1) nvec = jnp.cross(tri[:, 1] - tri[:, 0], tri[:, 2] - tri[:, 0]) inward = jnp.sum(nvec * (ctri - center), axis=1) < 0 faces = jnp.where(inward[:, None], faces[:, (0, 2, 1)], faces) verts = jnp.asarray(self.vertices, dtype=jnp.float64) mesh = verts[faces] self._mesh_cache_key = cache_key self._faces_oriented_cache = faces self._mesh_cache = mesh return faces, mesh @staticmethod def _validate_mode_arg(arg: bool | str, *, arg_name: str = "mode") -> str: accepted = (True, False, "warn", "raise", "ignore", "skip") if arg not in accepted: msg = ( "Input " f"{arg_name} must be one of {{'warn', 'raise', 'ignore', 'skip', True, False}}; " f"instead received {arg!r}." ) raise ValueError(msg) return "warn" if arg is True else "skip" if arg is False else arg @staticmethod def _validate_in_out(value: str) -> str: if value not in ("auto", "inside", "outside"): raise ValueError( "TriangularMesh `in_out` must be one of {'auto', 'inside', 'outside'}." ) return value @staticmethod def _get_open_edges(vertices: np.ndarray, faces: np.ndarray) -> np.ndarray: edges = np.concatenate([faces[:, [0, 1]], faces[:, [1, 2]], faces[:, [2, 0]]], axis=0) edges = np.sort(edges, axis=1) uniq, counts = np.unique(edges, axis=0, return_counts=True) return uniq[counts == 1] @classmethod def from_ConvexHull(cls, points: ArrayLike, **kwargs) -> TriangularMesh: from scipy.spatial import ConvexHull opts = dict(kwargs) opts.pop("check_selfintersecting", None) opts.pop("check_disconnected", None) pts = np.asarray(points, dtype=float) hull = ConvexHull(pts) return cls(vertices=hull.points, faces=hull.simplices, **opts) @property def _polarization(self) -> jnp.ndarray: if self.polarization is not None: return jnp.asarray(self.polarization, dtype=jnp.float64) if self.magnetization is not None: return MU0 * jnp.asarray(self.magnetization, dtype=jnp.float64) raise MagpylibMissingInput("Input polarization of TriangularMesh must be set.") @property def _faces_oriented(self) -> jnp.ndarray: faces, _ = self._geometry_cache() return faces @property def mesh(self) -> jnp.ndarray: _, mesh = self._geometry_cache() return mesh @property def barycenter(self) -> jnp.ndarray: if self.vertices is None or self.faces is None: return jnp.zeros((3,), dtype=jnp.float64) tri = self.mesh ctri = jnp.mean(tri, axis=1) area = 0.5 * jnp.linalg.norm( jnp.cross(tri[:, 1] - tri[:, 0], tri[:, 2] - tri[:, 0]), axis=1, ) total = jnp.maximum(jnp.sum(area), 1e-30) return jnp.sum(ctri * area[:, None], axis=0) / total @property def centroid(self) -> jnp.ndarray: return self.barycenter + jnp.asarray(self.position, dtype=jnp.float64) @property def volume(self) -> float: if self.vertices is None or self.faces is None: return 0.0 tri = self.mesh signed = jnp.sum(jnp.cross(tri[:, 0], tri[:, 1]) * tri[:, 2], axis=1) / 6.0 return float(jnp.abs(jnp.sum(signed))) def _require_inputs(self) -> None: if self.vertices is None: raise MagpylibMissingInput("Input vertices of TriangularMesh must be set.") if self.faces is None: raise MagpylibMissingInput("Input faces of TriangularMesh must be set.") if self.polarization is None and self.magnetization is None: raise MagpylibMissingInput("Input polarization of TriangularMesh must be set.") def getB( self, *observers: ArrayLike, in_out: str | None = None, squeeze: bool = True, sumup: bool = False, output: str = "ndarray", pixel_agg: str | None = None, ) -> jnp.ndarray: self._require_inputs() if in_out is None: in_out = self.in_out in_out = self._validate_in_out(in_out) obs = observers[0] if len(observers) == 1 else list(observers) return getB( "triangularmesh", obs, mesh=self.mesh, polarization=self._polarization, position=self.position, orientation=self.orientation, in_out=in_out, squeeze=squeeze, sumup=sumup, output=output, pixel_agg=pixel_agg, ) def getH( self, *observers: ArrayLike, in_out: str | None = None, squeeze: bool = True, sumup: bool = False, output: str = "ndarray", pixel_agg: str | None = None, ) -> jnp.ndarray: self._require_inputs() if in_out is None: in_out = self.in_out in_out = self._validate_in_out(in_out) obs = observers[0] if len(observers) == 1 else list(observers) return getH( "triangularmesh", obs, mesh=self.mesh, polarization=self._polarization, position=self.position, orientation=self.orientation, in_out=in_out, squeeze=squeeze, sumup=sumup, output=output, pixel_agg=pixel_agg, ) def getJ( self, *observers: ArrayLike, in_out: str | None = None, squeeze: bool = True, sumup: bool = False, ) -> jnp.ndarray: self._require_inputs() if in_out is None: in_out = self.in_out in_out = self._validate_in_out(in_out) obs = observers[0] if len(observers) == 1 else list(observers) return getJ( "triangularmesh", obs, mesh=self.mesh, polarization=self._polarization, position=self.position, orientation=self.orientation, in_out=in_out, squeeze=squeeze, sumup=sumup, ) def getM( self, *observers: ArrayLike, in_out: str | None = None, squeeze: bool = True, sumup: bool = False, ) -> jnp.ndarray: self._require_inputs() if in_out is None: in_out = self.in_out in_out = self._validate_in_out(in_out) obs = observers[0] if len(observers) == 1 else list(observers) return getM( "triangularmesh", obs, mesh=self.mesh, polarization=self._polarization, position=self.position, orientation=self.orientation, in_out=in_out, squeeze=squeeze, sumup=sumup, ) def copy(self, **kwargs) -> TriangularMesh: return super().copy(**kwargs)