Source code for bsr.geometry.composite.stack

__all__ = ["BaseStack", "RodStack", "create_rod_collection"]

from typing import TYPE_CHECKING, Any, Protocol, Type, overload

import sys

# Check python version
if sys.version_info < (3, 11):
    from typing_extensions import Self
else:
    from typing import Self

from collections.abc import Sequence

import bpy
import numpy as np
from numpy.typing import NDArray

from bsr.geometry.composite.rod import Rod
from bsr.geometry.protocol import BlenderMeshInterfaceProtocol, StackProtocol
from bsr.tools.keyframe_mixin import KeyFrameControlMixin


[docs] class BaseStack(Sequence, KeyFrameControlMixin): """ This class provides a mesh interface for a BaseStack of objects. BaseStacks are created using given positions and radii. Parameters ---------- positions: NDArray Positions of each object in the stack. Expected dimension is (n_dim, n_nodes) n_dim = 3 radii: NDArray Radii of each object in the stack. Expected dimension is (n_nodes-1,) """ DefaultType: Type def __init__(self) -> None: """ Stack class constructor """ self._objs: list[BlenderMeshInterfaceProtocol] = [] self._mats: list[BlenderMeshInterfaceProtocol] = [] @overload def __getitem__(self, index: int, /) -> BlenderMeshInterfaceProtocol: ... @overload def __getitem__( self, index: slice, / ) -> list[BlenderMeshInterfaceProtocol]: ... def __getitem__( self, index: int | slice ) -> BlenderMeshInterfaceProtocol | list[BlenderMeshInterfaceProtocol]: return self._objs[index] def __len__(self) -> int: return len(self._objs) @property def material(self) -> list[BlenderMeshInterfaceProtocol]: """ Returns the list of materials in the stack. """ return self._mats @property def object(self) -> list[BlenderMeshInterfaceProtocol]: """ Returns the list of objects in the stack. """ return self._objs
[docs] def update_keyframe(self, keyframe: int) -> None: """ Sets a keyframe at the given frame. """ for obj in self._objs: obj.update_keyframe(keyframe)
[docs] @classmethod def create( cls, states: dict[str, NDArray], ) -> Self: """ Basic factory method to create a new BaseStack of objects. States must have the following keys: positions(n_dim, n_nodes), radii(n_nodes-1,) Parameters ---------- states: dict[str, NDArray] A dictionary where keys are state names and values are NDarrays. Returns ------- Self An instance of the BaseStack with objects containing the states Raises ------ AssertionError If the states have differing lengths """ self = cls() keys = states.keys() lengths = [i.shape[0] for i in states.values()] assert len(set(lengths)) <= 1, "All states must have the same length" num_objects = lengths[0] for oidx in range(num_objects): state = {k: v[oidx] for k, v in states.items()} obj = self.DefaultType.create(state) self._objs.append(obj) self._mats.append(obj.material) return self
[docs] def update_states(self, *variables: NDArray) -> None: """ Updates the states of the BaseStack objects. Parameters ---------- *variables: NDArray An array including all the state updates of the object in the stack. Expected dimension is (n_nodes - 1,) """ if not all([v.shape[0] == len(self) for v in variables]): raise IndexError( "All variables must have the same length as the stack" ) for idx in range(len(self)): self[idx].update_states(*[v[idx] for v in variables])
[docs] def update_material(self, **kwargs: dict[str, NDArray]) -> None: """ Updates the material of the BaseStack objects Parameters ---------- kwargs : dict Keyword arguments for the material update """ for material_key, material_values in kwargs.items(): assert isinstance( material_values, np.ndarray ), "Values of kwargs must be a numpy array" if material_values.shape[0] != len(self): raise IndexError( "All values must have the same length as the stack" ) for idx in range(len(self)): self[idx].update_material({material_key: material_values[idx]})
[docs] class RodStack(BaseStack): """ This class provides a mesh interface for a RodStack of objects (only contains Rod objects). RodStacks are created using given positions and radii. Parameters ---------- positions: NDArray Positions of each Rod in the stack. Expected dimension is (n_dim, n_nodes). n_dim = 3 radii: NDArray Radii of each Rod in the stack. Expected dimension is (n_nodes-1,). """ input_states = {"positions", "radii"} DefaultType: Type = Rod
# Alias for factory functions create_rod_collection = RodStack.create if TYPE_CHECKING: data: dict[str, NDArray] = { "positions": np.array([[[0, 0, 0], [1, 1, 1]]]), "radii": np.array([[1.0, 1.0]]), } _: StackProtocol = RodStack.create(data)