"""Collection compatibility layer for mixed source/sensor containers."""
from __future__ import annotations
from collections.abc import Iterable
from typing import Any
import jax
import jax.numpy as jnp
from magpylib_jax._types import ArrayLike
from magpylib_jax.core.base import BaseGeo, BaseSource, MagpylibBadUserInput
from magpylib_jax.functional import getB, getH, getJ, getM
from magpylib_jax.sensor import Sensor
def _format_star_args(args: tuple[Any, ...]) -> Any:
    """Collapse a ``*args`` tuple into a single value or a list.

    A tuple holding exactly one element yields that element itself; any
    other length (including zero) yields a new list with the same items.
    """
    return args[0] if len(args) == 1 else list(args)
# [docs]  (Sphinx source-link marker left over from HTML extraction; not part of the module)
class Collection(BaseGeo):
    """Container for source and sensor objects with Magpylib-like behavior.

    Direct members are stored in ``children``; a member may itself be a
    ``Collection``, which makes the structure a tree. Flattened views of the
    tree (``sources``, ``sensors``) are cached and are invalidated through
    ``_mark_structure_dirty`` whenever ``add`` or ``remove`` runs.
    """

    # Class-level flag that is True on every Collection. Presumably lets other
    # modules recognise collections by attribute -- TODO confirm against callers.
    _is_collection = True
def __init__(
    self,
    *children: object,
    position: ArrayLike = (0.0, 0.0, 0.0),
    orientation: ArrayLike | None = None,
    style_label: str | None = None,
    **_kwargs,
) -> None:
    """Create a collection and optionally populate it with ``children``.

    ``position``, ``orientation`` and ``style_label`` are forwarded to the
    base-class initializer. Any further keyword arguments are accepted and
    discarded.
    """
    base_options = {
        "position": position,
        "orientation": orientation,
        "style_label": style_label,
    }
    super().__init__(**base_options)

    # Direct members only; nested members are reached via _flatten_children().
    self.children = []

    # Lazily built views of the tree; None means "rebuild on next access".
    self._flat_children_cache: list[object] | None = None
    self._sources_cache: list[BaseSource] | None = None
    self._sensors_cache: list[Sensor] | None = None

    if not children:
        return
    self.add(*children)
def _mark_structure_dirty(self) -> None:
    """Invalidate the cached membership views here and in every ancestor.

    Resets the flattened-children, sources and sensors caches to ``None``,
    calls ``_bump_cache_version()``, and then repeats the same on ``_parent``
    when that parent is itself a ``Collection``.
    """
    self._flat_children_cache = self._sources_cache = self._sensors_cache = None
    self._bump_cache_version()

    owner = getattr(self, "_parent", None)
    if not isinstance(owner, Collection):
        return
    owner._mark_structure_dirty()
def add(self, *objects: object, override_parent: bool = False) -> Collection:
    """Append objects to this collection.

    Parameters
    ----------
    *objects
        ``BaseGeo`` instances or (possibly nested) iterables of them. A
        single iterable argument is unpacked; strings and bytes are not
        treated as iterables. ``None`` entries are skipped.
    override_parent
        When ``True``, an object that already has a parent is first removed
        from that parent (if the parent exposes ``remove``) and then adopted
        by this collection. When ``False`` such an object is rejected.

    Returns
    -------
    Collection
        ``self``, so calls can be chained.

    Raises
    ------
    MagpylibBadUserInput
        If an entry is not a ``BaseGeo``, if adding a collection would make
        this collection reference itself, or if an entry already has a
        parent and ``override_parent`` is ``False``.

    Notes
    -----
    Entries are processed in order, so objects preceding a rejected entry
    remain added. The cached views are invalidated even when an error is
    raised, so they never disagree with ``children``.
    """
    if (
        len(objects) == 1
        and isinstance(objects[0], Iterable)
        and not isinstance(objects[0], (BaseGeo, str, bytes))
    ):
        objects = tuple(objects[0])

    try:
        for obj in objects:
            # Nested iterables are flattened by recursing into add().
            if isinstance(obj, Iterable) and not isinstance(obj, (BaseGeo, str, bytes)):
                self.add(*obj, override_parent=override_parent)
                continue
            if obj is None:
                continue
            if not isinstance(obj, BaseGeo):
                raise MagpylibBadUserInput(
                    f"Cannot add object of type {type(obj).__name__!r} to Collection."
                )
            # Reject cycles: neither self nor any collection containing self.
            if isinstance(obj, Collection) and (obj is self or self in obj._flatten_children()):
                msg = f"Cannot add {obj!r} because a Collection must not reference itself."
                raise MagpylibBadUserInput(msg)

            current_parent = getattr(obj, "_parent", None)
            if current_parent is not None:
                if not override_parent:
                    msg = (
                        f"Cannot add {obj!r} to {self!r} because it already has a parent. "
                        "Consider using override_parent=True."
                    )
                    raise MagpylibBadUserInput(msg)
                if hasattr(current_parent, "remove"):
                    current_parent.remove(obj, errors="ignore")
            obj._parent = self
            self.children.append(obj)
    finally:
        # Earlier entries may already be appended when a later one is
        # rejected; invalidate on every exit path to avoid stale caches.
        if objects:
            self._mark_structure_dirty()
    return self
def remove(
    self,
    *objects: object,
    recursive: bool = True,
    errors: str = "raise",
) -> Collection:
    """Remove objects from this collection.

    Parameters
    ----------
    *objects
        Objects to remove. A single iterable argument is unpacked; strings
        and bytes are not treated as iterables. ``None`` entries are skipped.
    recursive
        When ``True`` (default), objects that are not direct children are
        also searched for in nested collections.
    errors
        ``"raise"`` raises when an object cannot be found, ``"ignore"``
        skips it silently. Any other value raises, but (as before) only at
        the moment an object cannot be found.

    Returns
    -------
    Collection
        ``self``, so calls can be chained.

    Raises
    ------
    MagpylibBadUserInput
        If an object is not found and ``errors`` is ``"raise"``, or if an
        object is not found and ``errors`` is not a supported value.

    Notes
    -----
    The container an object is actually taken from (this collection or a
    nested one) has its caches invalidated and, if it was the object's
    ``_parent``, that link is cleared. Invalidation also happens when an
    error interrupts the loop after earlier removals.
    """
    if (
        len(objects) == 1
        and isinstance(objects[0], Iterable)
        and not isinstance(objects[0], (BaseGeo, str, bytes))
    ):
        objects = tuple(objects[0])

    def _detach(node: Collection, target: BaseGeo) -> bool:
        """Remove ``target`` from ``node`` or, if recursive, from a nested collection."""
        if target in node.children:
            node.children.remove(target)
            # Clear the back-reference on the real container, which may be
            # a nested collection rather than ``self``.
            if getattr(target, "_parent", None) is node:
                target._parent = None
            # The nested container caches its own flattened view; without
            # this the removed object would still be reported by ancestors.
            node._mark_structure_dirty()
            return True
        if recursive:
            for child in node.children:
                if isinstance(child, Collection) and _detach(child, target):
                    return True
        return False

    def _report_missing(obj: object) -> None:
        """Apply the ``errors`` policy for an object that could not be removed."""
        if errors == "raise":
            raise MagpylibBadUserInput(f"Cannot find and remove {obj!r} from {self!r}.")
        if errors != "ignore":
            raise MagpylibBadUserInput(
                "Input errors must be one of {'raise', 'ignore'}; "
                f"instead received {errors!r}."
            )

    try:
        for obj in objects:
            if obj is None:
                continue
            if not isinstance(obj, BaseGeo) or not _detach(self, obj):
                _report_missing(obj)
    finally:
        if objects:
            self._mark_structure_dirty()
    return self
def _flatten_children(self) -> list[object]:
    """Return every member of the tree in depth-first, parent-before-child order.

    The result is cached on the instance and the cached list object itself
    is returned, so callers must not mutate it.
    """
    flat = self._flat_children_cache
    if flat is None:
        flat = []
        for member in self.children:
            flat.append(member)
            if isinstance(member, Collection):
                flat += member._flatten_children()
        self._flat_children_cache = flat
    return flat
@property
def sources(self) -> list[BaseSource]:
    """All ``BaseSource`` members of the tree, as a fresh list on each access."""
    cached = self._sources_cache
    if cached is None:
        cached = [
            member for member in self._flatten_children() if isinstance(member, BaseSource)
        ]
        self._sources_cache = cached
    # Hand out a copy so callers cannot corrupt the cache.
    return list(cached)
@property
def sensors(self) -> list[Sensor]:
    """All ``Sensor`` members of the tree, as a fresh list on each access."""
    cached = self._sensors_cache
    if cached is None:
        cached = [
            member for member in self._flatten_children() if isinstance(member, Sensor)
        ]
        self._sensors_cache = cached
    # Hand out a copy so callers cannot corrupt the cache.
    return list(cached)
@property
def volume(self) -> float:
    """Sum of the ``volume`` of all sources; sources without one count as 0."""
    total = 0
    for src in self.sources:
        total = total + getattr(src, "volume", 0.0)
    return float(total)
@property
def centroid(self) -> jnp.ndarray:
    """Volume-weighted mean of the source centroids.

    Only sources with a strictly positive ``volume`` contribute; a source
    without a ``centroid`` attribute contributes its ``position``. When no
    source qualifies, the collection's own ``position`` is returned.
    """
    weights: list[float] = []
    points = []
    for src in self.sources:
        weight = float(getattr(src, "volume", 0.0))
        if not weight > 0:
            continue
        weights.append(weight)
        points.append(jnp.asarray(getattr(src, "centroid", src.position), dtype=jnp.float64))

    if len(weights) == 0:
        return jnp.asarray(self.position, dtype=jnp.float64)

    weight_col = jnp.asarray(weights, dtype=jnp.float64)
    stacked = jnp.stack(points, axis=0)
    weighted_sum = jnp.sum(stacked * weight_col[:, None], axis=0)
    return weighted_sum / jnp.sum(weight_col)
def reset_path(self) -> Collection:
    """Reset the path of every direct child, then of the collection itself.

    Children without a ``reset_path`` attribute are skipped. Returns ``self``.
    """
    resettable = (member for member in self.children if hasattr(member, "reset_path"))
    for member in resettable:
        member.reset_path()
    super().reset_path()
    return self
def _validate_getBH_inputs(self, *inputs: object):
    """Resolve the ``(sources, sensors)`` pair for a field computation.

    The roles depend on what the collection contains:

    * sources and sensors: the collection fills both roles and extra
      ``inputs`` are rejected;
    * no sources: the ``inputs`` tuple is used as sources and the
      collection is used as the sensors argument;
    * sources only: the collection is the source; a single input is passed
      through as-is, any other number of inputs as the tuple itself.

    Raises
    ------
    MagpylibBadUserInput
        If the collection holds both sources and sensors and ``inputs`` is
        not empty.
    """
    has_sources = bool(self.sources)
    has_sensors = bool(self.sensors)

    if has_sensors and has_sources:
        if inputs:
            # The two sentences were previously glued together without a space.
            msg = (
                "Collections with sensors and sources do not allow collection.getB() inputs. "
                "Consider using magpy.getB() instead."
            )
            raise MagpylibBadUserInput(msg)
        return self, self

    if not has_sources:
        return inputs, self

    return self, (inputs[0] if len(inputs) == 1 else inputs)
def getB(
    self,
    *inputs: object,
    in_out: str = "auto",
    squeeze: bool = True,
    sumup: bool = False,
    pixel_agg: str | None = None,
    output: str = "ndarray",
) -> jnp.ndarray:
    """Evaluate the functional ``getB`` with this collection as source and/or observer.

    Positional ``inputs`` are resolved into a ``(sources, sensors)`` pair by
    ``_validate_getBH_inputs``; all keyword options are forwarded unchanged.
    """
    field_sources, observers = self._validate_getBH_inputs(*inputs)
    options = {
        "in_out": in_out,
        "squeeze": squeeze,
        "sumup": sumup,
        "pixel_agg": pixel_agg,
        "output": output,
    }
    return getB(field_sources, observers, **options)
def getH(
    self,
    *inputs: object,
    in_out: str = "auto",
    squeeze: bool = True,
    sumup: bool = False,
    pixel_agg: str | None = None,
    output: str = "ndarray",
) -> jnp.ndarray:
    """Evaluate the functional ``getH`` with this collection as source and/or observer.

    Positional ``inputs`` are resolved into a ``(sources, sensors)`` pair by
    ``_validate_getBH_inputs``; all keyword options are forwarded unchanged.
    """
    field_sources, observers = self._validate_getBH_inputs(*inputs)
    options = {
        "in_out": in_out,
        "squeeze": squeeze,
        "sumup": sumup,
        "pixel_agg": pixel_agg,
        "output": output,
    }
    return getH(field_sources, observers, **options)
def getJ(
    self,
    *inputs: object,
    in_out: str = "auto",
    squeeze: bool = True,
    sumup: bool = False,
    pixel_agg: str | None = None,
    output: str = "ndarray",
) -> jnp.ndarray:
    """Evaluate the functional ``getJ`` with this collection as source and/or observer.

    Positional ``inputs`` are resolved into a ``(sources, sensors)`` pair by
    ``_validate_getBH_inputs``; all keyword options are forwarded unchanged.
    """
    field_sources, observers = self._validate_getBH_inputs(*inputs)
    options = {
        "in_out": in_out,
        "squeeze": squeeze,
        "sumup": sumup,
        "pixel_agg": pixel_agg,
        "output": output,
    }
    return getJ(field_sources, observers, **options)
def getM(
    self,
    *inputs: object,
    in_out: str = "auto",
    squeeze: bool = True,
    sumup: bool = False,
    pixel_agg: str | None = None,
    output: str = "ndarray",
) -> jnp.ndarray:
    """Evaluate the functional ``getM`` with this collection as source and/or observer.

    Positional ``inputs`` are resolved into a ``(sources, sensors)`` pair by
    ``_validate_getBH_inputs``; all keyword options are forwarded unchanged.
    """
    field_sources, observers = self._validate_getBH_inputs(*inputs)
    options = {
        "in_out": in_out,
        "squeeze": squeeze,
        "sumup": sumup,
        "pixel_agg": pixel_agg,
        "output": output,
    }
    return getM(field_sources, observers, **options)
def set_children_styles(self, **kwargs) -> None:
    """Apply style settings to every member of the tree that supports them.

    Only ``magnetization_show`` is accepted; it is coerced with ``bool`` and
    written to ``style.magnetization.show`` of each member that has both a
    ``style`` and a ``style.magnetization`` attribute.

    Raises
    ------
    ValueError
        If any keyword other than ``magnetization_show`` is passed.
    """
    supported = {"magnetization_show"}
    unknown = [name for name in kwargs if name not in supported]
    if unknown:
        raise ValueError("The following style properties are invalid: " + ", ".join(unknown))

    if "magnetization_show" not in kwargs:
        return

    show = kwargs["magnetization_show"]
    for member in self._flatten_children():
        if not hasattr(member, "style"):
            continue
        if not hasattr(member.style, "magnetization"):
            continue
        member.style.magnetization.show = bool(show)
def _describe_label(self, obj: object, parts: list[str]) -> str:
    """Build the heading text for ``obj`` in ``describe`` output.

    The label is taken from ``obj.style_label`` or, failing that, from
    ``obj.style.label``. A style object without a ``label`` attribute is
    treated as unlabeled instead of raising ``AttributeError``.

    Parameters
    ----------
    obj
        Object to describe.
    parts
        Requested components; only ``"label"`` and ``"type"`` are looked at.

    Returns
    -------
    str
        ``"<Type> <label>"`` (``"nolabel"`` if unlabeled) when both parts are
        requested, the label (or the type name if unlabeled) when only
        ``"label"`` is requested, and the type name otherwise.
    """
    label = getattr(obj, "style_label", None)
    if not label:
        style = getattr(obj, "style", None)
        label = getattr(style, "label", None) if style else None
    label = label or None

    type_name = obj.__class__.__name__
    if "label" not in parts:
        return type_name
    if "type" in parts:
        return f"{type_name} {label or 'nolabel'}"
    return f"{label or type_name}"
def _describe_properties(self, obj: object) -> list[str]:
    """Build the property lines that ``describe`` lists under ``obj``.

    Every object gets position, orientation (rotation vector in degrees),
    centroid, dipole moment, meshing and volume. Objects exposing a
    ``polarization`` or ``magnetization`` attribute additionally get
    dimension, magnetization and polarization, placed right after the
    orientation line. Missing attributes fall back to zeros / ``None``.
    """
    from magpylib_jax.constants import MU0

    def _show(value) -> str:
        return str(jax.device_get(jnp.asarray(value, dtype=jnp.float64)))

    def _show_or_none(value):
        return value if value is None else _show(value)

    orientation = getattr(obj, "orientation", None)
    if hasattr(orientation, "as_rotvec"):
        rotvec = jnp.asarray(orientation.as_rotvec(), dtype=jnp.float64)
    else:
        rotvec = jnp.zeros(3, dtype=jnp.float64)

    # Magnetization is derived from polarization (divided by MU0) when only
    # the latter is set.
    polarization = getattr(obj, "polarization", None)
    magnetization = getattr(obj, "magnetization", None)
    if magnetization is None and polarization is not None:
        magnetization = jnp.asarray(polarization, dtype=jnp.float64) / MU0

    if hasattr(obj, "dipole_moment"):
        dipole = obj.dipole_moment
    elif magnetization is not None:
        dipole = jnp.asarray(magnetization, dtype=jnp.float64) * float(
            getattr(obj, "volume", 0.0)
        )
    else:
        dipole = None
    # NOTE(review): source indentation was lost; this zero fallback is read as
    # applying to both branches above -- confirm against the upstream file.
    if dipole is None:
        dipole = jnp.zeros(3, dtype=jnp.float64)

    lines = [
        f"position: {_show(getattr(obj, 'position', (0, 0, 0)))} m",
        f"orientation: {_show(jnp.rad2deg(rotvec))} deg",
    ]
    if hasattr(obj, "polarization") or hasattr(obj, "magnetization"):
        dimension = getattr(obj, "dimension", None)
        lines += [
            f"dimension: {_show_or_none(dimension)} m",
            f"magnetization: {_show_or_none(magnetization)} A/m",
            f"polarization: {_show_or_none(polarization)} T",
        ]

    anchor = getattr(obj, "centroid", getattr(obj, "position", (0, 0, 0)))
    lines.append(f"centroid: {_show(anchor)}")
    lines.append(f"dipole_moment: {_show(dipole)}")
    # NOTE(review): meshing/volume are read as emitted for every object (not
    # only magnets); the flattened source does not settle this -- confirm.
    lines.append(f"meshing: {getattr(obj, 'meshing', None)}")
    lines.append(f"volume: {float(getattr(obj, 'volume', 0.0))}")
    return lines
def describe(self, format: str = "label,type,id", return_string: bool = False):
    """Render a text overview of the collection tree.

    Parameters
    ----------
    format
        Spaces are ignored. ``"type+label"`` produces a summary that counts
        members per class name. Otherwise a comma-separated list of parts:
        ``"label"`` and ``"type"`` control the headings, ``"id"`` appends
        ``(id=...)``, and ``"properties"`` adds property lines under each
        object. With no heading part given, ``"type"`` is used.
    return_string
        Return the text instead of printing it.

    Returns
    -------
    str or None
        The rendered text when ``return_string`` is true, otherwise ``None``
        (the text is printed).
    """

    def _deliver(text: str):
        if return_string:
            return text
        print(text)
        return None

    spec = format.replace(" ", "")

    if spec == "type+label":
        tally = {}
        for member in self._flatten_children():
            kind = member.__class__.__name__
            tally[kind] = tally.get(kind, 0) + 1
        rows = [self._describe_label(self, ["type", "label"])]
        last_pos = len(tally) - 1
        for pos, (kind, amount) in enumerate(tally.items()):
            connector = "└──" if pos == last_pos else "├──"
            plural = "" if amount == 1 else "s"
            rows.append(f"{connector} {amount}x {kind}{plural}")
        return _deliver("\n".join(rows))

    tokens = [tok for tok in spec.split(",") if tok]
    with_props = "properties" in tokens
    tokens = [tok for tok in tokens if tok != "properties"] or ["type"]
    show_id = "id" in tokens

    def _headline(obj: object) -> str:
        text = self._describe_label(obj, tokens)
        if show_id:
            text += f" (id={id(obj)})"
        return text

    rows: list[str] = [_headline(self)]
    # NOTE(review): the continuation prefixes below are narrower than the
    # 4-character branch markers; possibly whitespace was collapsed when this
    # source was extracted -- confirm before changing.
    if with_props:
        lead = "│ " if self.children else " "
        rows.extend(f"{lead}• {prop}" for prop in self._describe_properties(self))

    def _descend(node: Collection, indent: str = "") -> None:
        final = len(node.children) - 1
        for pos, child in enumerate(node.children):
            closing = pos == final
            marker = "└── " if closing else "├── "
            rows.append(f"{indent}{marker}{_headline(child)}")
            deeper = indent + (" " if closing else "│ ")
            if with_props:
                rows.extend(f"{deeper} • {prop}" for prop in self._describe_properties(child))
            if isinstance(child, Collection):
                _descend(child, deeper)

    _descend(self, "")
    return _deliver("\n".join(rows))
def _repr_html_(self) -> str:
    """Return the ``describe`` tree wrapped in ``<pre>`` for rich HTML display.

    The text is HTML-escaped first, so labels containing ``<``, ``>`` or
    ``&`` are shown literally instead of being interpreted as markup; line
    breaks are then converted to ``<br>``.
    """
    from html import escape

    tree = self.describe(format="label,type,id", return_string=True)
    markup = escape(tree, quote=False).replace(chr(10), "<br>")
    return f"<pre>{markup}</pre>"
def __iter__(self):
    """Iterate over the direct children; nested members are not expanded."""
    members = self.children
    return iter(members)
def __len__(self) -> int:
    """Number of direct children; nested members are not counted."""
    members = self.children
    return len(members)
def __getitem__(self, idx: int) -> object:
    """Return the direct child at ``idx`` by indexing the ``children`` list."""
    members = self.children
    return members[idx]
def __add__(self, other: object) -> Collection:
    """Return a new ``Collection`` constructed with ``self`` and ``other`` as children."""
    combined = Collection(self, other)
    return combined
def __repr__(self) -> str:
    """Return the base-class representation unchanged."""
    text = super().__repr__()
    return text