diff --git a/src/beanmachine/ppl/compiler/fix_requirements.py b/src/beanmachine/ppl/compiler/fix_requirements.py index 1e97136f5c..788524d291 100644 --- a/src/beanmachine/ppl/compiler/fix_requirements.py +++ b/src/beanmachine/ppl/compiler/fix_requirements.py @@ -103,13 +103,11 @@ def _node_meets_requirement(self, node: bn.BMGNode, r: bt.Requirement) -> bool: ) return self._type_meets_requirement(lattice_type, r) - def _meet_constant_requirement( + def _try_to_meet_constant_requirement( self, node: bn.ConstantNode, requirement: bt.Requirement, - consumer: bn.BMGNode, - edge: str, - ) -> bn.BMGNode: + ) -> Optional[bn.BMGNode]: # We have a constant node that either (1) is untyped, and therefore # needs to be replaced by an equivalent typed node, or (2) is typed # but is of the wrong type, and needs to be replaced by an equivalent @@ -162,13 +160,26 @@ def _meet_constant_requirement( result = self.bmg.add_constant_of_type(node.value, required_type) assert self._node_meets_requirement(result, requirement) return result + return None + + def _meet_constant_requirement( + self, + node: bn.ConstantNode, + requirement: bt.Requirement, + consumer: bn.BMGNode, + edge: str, + ) -> bn.BMGNode: + + result = self._try_to_meet_constant_requirement(node, requirement) + if result is not None: + return result # We cannot convert this node to any type that meets the requirement. # Add an error. self.errors.add_error( Violation( node, - it, + self._typer[node], requirement, consumer, edge, @@ -415,13 +426,13 @@ def _try_to_force_to_neg_real(self, node, requirement) -> Optional[bn.BMGNode]: return self.bmg.add_to_negative_real(node) - def _meet_operator_requirement( + def _try_to_meet_operator_requirement( self, node: bn.OperatorNode, requirement: bt.Requirement, consumer: bn.BMGNode, edge: str, - ) -> bn.BMGNode: + ) -> Optional[bn.BMGNode]: # We should not have called this function if the input node already meets # the requirement on the edge. @@ -470,14 +481,30 @@ def _meet_operator_requirement( if result is not None: return result - node_type = self._typer[node] - result = self._try_to_force_to_neg_real(node, requirement) if result is not None: return result - # Those are the only techniques we have to make an operator meet a requirement. - # We have no way to make the conversion we need, so add an error. + # We couldn't meet the requirement. + + return None + + def _meet_operator_requirement( + self, + node: bn.OperatorNode, + requirement: bt.Requirement, + consumer: bn.BMGNode, + edge: str, + ) -> bn.BMGNode: + assert not self._node_meets_requirement(node, requirement) + result = self._try_to_meet_operator_requirement( + node, requirement, consumer, edge + ) + if result is not None: + return result + + # We were unable to meet a requirement; add an error. + node_type = self._typer[node] self.errors.add_error( Violation( node,