Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions montepy/cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -927,6 +927,13 @@ def num(obj):
for key in keys:
attr = getattr(self, key)
setattr(result, key, copy.deepcopy(attr, memo))
# Clear weakrefs so the cloned cell isn't linked to original collection/problem
# This prevents number conflict checks against the original collection
result._collection_ref = None
result._problem_ref = None
# Update the geometry's _cell references to point to result, not the deepcopied intermediate
if result._geometry is not None:
result._geometry._set_cell(result)
# copy geometry
for special in special_keys:
new_objs = []
Expand Down
11 changes: 10 additions & 1 deletion montepy/mcnp_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ def __init__(
except AttributeError:
self._BLOCK_TYPE = montepy.input_parser.block_type.BlockType.DATA
self._problem_ref = None
self._collection_ref = None
self._parameters = ParametersNode()
self._input = None
if input:
Expand Down Expand Up @@ -225,6 +226,13 @@ def _update_values(self):
"""
pass

@property
def _collection(self):
"""Returns the parent collection this object belongs to, if any."""
if self._collection_ref is not None:
return self._collection_ref()
return None

def format_for_mcnp_input(self, mcnp_version: tuple[int]) -> list[str]:
"""Creates a list of strings representing this MCNP_Object that can be
written to file.
Expand Down Expand Up @@ -490,14 +498,15 @@ def _grab_beginning_comment(self, padding: list[PaddingNode], last_obj=None):

def __getstate__(self):
state = self.__dict__.copy()
bad_keys = {"_problem_ref", "_parser"}
bad_keys = {"_problem_ref", "_collection_ref", "_parser"}
for key in bad_keys:
if key in state:
del state[key]
return state

def __setstate__(self, crunchy_data):
crunchy_data["_problem_ref"] = None
crunchy_data["_collection_ref"] = None
self.__dict__.update(crunchy_data)

def clone(self) -> montepy.mcnp_object.MCNP_Object:
Expand Down
42 changes: 26 additions & 16 deletions montepy/numbered_mcnp_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import itertools
from typing import Union
from numbers import Integral
import weakref

from montepy.mcnp_object import MCNP_Object, InitInput
import montepy
Expand All @@ -14,22 +15,31 @@
def _number_validator(self, number):
if number < 0:
raise ValueError("number must be >= 0")
if self._problem:
obj_map = montepy.MCNP_Problem._NUMBERED_OBJ_MAP
try:
collection_type = obj_map[type(self)]
except KeyError as e:
found = False
for obj_class in obj_map:
if isinstance(self, obj_class):
collection_type = obj_map[obj_class]
found = True
break
if not found:
raise e
collection = getattr(self._problem, collection_type.__name__.lower())
collection.check_number(number)
collection._update_number(self.number, number, self)

# Only validate against collection if linked to a problem
if self._problem is not None:
if self._collection is not None:
collection = self._collection
else:
# Find collection via _problem
obj_map = montepy.MCNP_Problem._NUMBERED_OBJ_MAP
collection_type = obj_map.get(type(self))

if collection_type is None:
# Finding via inheritance
for obj_class in obj_map:
if isinstance(self, obj_class):
collection_type = obj_map[obj_class]
break

if collection_type is not None:
collection = getattr(self._problem, collection_type.__name__.lower())
else:
collection = None

if collection is not None:
collection.check_number(number)
collection._update_number(self.number, number, self)


class Numbered_MCNP_Object(MCNP_Object):
Expand Down
26 changes: 22 additions & 4 deletions montepy/numbered_object_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ def __init__(
)
)
self.__num_cache[obj.number] = obj
obj._collection_ref = weakref.ref(self)
self._objects = objects

def link_to_problem(self, problem):
Expand All @@ -196,8 +197,18 @@ def link_to_problem(self, problem):
self._problem_ref = None
else:
self._problem_ref = weakref.ref(problem)
# Rebuild the number cache to ensure it reflects current object numbers
# after deepcopy/unpickling when the cache might be stale
self.__num_cache.clear()
for obj in self._objects:
self.__num_cache[obj.number] = obj
for obj in self:
obj.link_to_problem(problem)
# the _collection_ref that points to the main cells collection.
if problem is not None:
existing_coll = obj._collection
if existing_coll is None or existing_coll._problem is not problem:
obj._collection_ref = weakref.ref(self)

@property
def _problem(self):
Expand Down Expand Up @@ -340,11 +351,15 @@ def extend(self, other_list):
"The object in the list {obj} is not of type: {self._obj_class}"
)
if obj.number in nums:
raise NumberConflictError(
(
f"When adding to {type(self).__name__} there was a number collision due to "
f"adding {obj} which conflicts with {self[obj.number]}"
try:
conflicting_obj = self[obj.number]
conflict_msg = (
f"adding {obj} which conflicts with {conflicting_obj}"
)
except KeyError:
conflict_msg = f"adding {obj} which conflicts with existing object number {obj.number}"
raise NumberConflictError(
f"When adding to {type(self).__name__} there was a number collision due to {conflict_msg}"
)
nums.add(obj.number)
for obj in other_list:
Expand Down Expand Up @@ -496,6 +511,7 @@ def __internal_append(self, obj, **kwargs):
)
self.__num_cache[obj.number] = obj
self._objects.append(obj)
obj._collection_ref = weakref.ref(self)
self._append_hook(obj, **kwargs)
if self._problem:
obj.link_to_problem(self._problem)
Expand All @@ -507,6 +523,7 @@ def __internal_delete(self, obj, **kwargs):
"""
self.__num_cache.pop(obj.number, None)
self._objects.remove(obj)
obj._collection_ref = None
self._delete_hook(obj, **kwargs)

def add(self, obj: Numbered_MCNP_Object):
Expand Down Expand Up @@ -613,6 +630,7 @@ def append_renumber(self, obj, step=1):
number = obj.number if obj.number > 0 else 1
if self._problem:
obj.link_to_problem(self._problem)
obj._collection_ref = None
try:
self.append(obj)
except (NumberConflictError, ValueError) as e:
Expand Down
29 changes: 29 additions & 0 deletions montepy/surfaces/half_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,22 @@ def update_pointers(self, cells, surfaces, cell):
if self.right is not None:
self.right.update_pointers(cells, surfaces, cell)

def _set_cell(self, cell):
"""Sets the _cell reference for this HalfSpace and all children.

This is used during cloning to update the _cell references to point
to the cloned cell rather than the original (or a deepcopied intermediate).

Parameters
----------
cell : Cell
the cell this HalfSpace should be tied to.
"""
self._cell = cell
self.left._set_cell(cell)
if self.right is not None:
self.right._set_cell(cell)

def _add_new_children_to_cell(self, other):
"""Adds the cells and surfaces from a new tree to this parent cell.

Expand Down Expand Up @@ -676,6 +692,19 @@ def update_pointers(self, cells, surfaces, cell):
"Cell", self._cell.number, "Surface", self._divider
)

def _set_cell(self, cell):
"""Sets the _cell reference for this UnitHalfSpace.

This is used during cloning to update the _cell references to point
to the cloned cell rather than the original (or a deepcopied intermediate).

Parameters
----------
cell : Cell
the cell this UnitHalfSpace should be tied to.
"""
self._cell = cell

def _ensure_has_nodes(self):
if self.node is None:
if isinstance(self.divider, Integral):
Expand Down
2 changes: 2 additions & 0 deletions tests/test_numbered_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ def test_extend(self, cp_simple_problem):
extender = copy.deepcopy(extender)
for surf in extender:
surf._problem = None
surf._collection_ref = None
surfaces[1000].number = 1
extender[0].number = 1000
extender[1].number = 70
Expand Down Expand Up @@ -161,6 +162,7 @@ def test_append_renumber(self, cp_simple_problem):
cells.append_renumber(cell, "hi")
cell = copy.deepcopy(cell)
cell._problem = None
cell._collection_ref = None
cell.number = 1
cells.append_renumber(cell)
assert cell.number == 4
Expand Down
Loading