Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add flow_id, self.name for instance names of Flows and Nodes and easy isolated storage for nested flows #5

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 78 additions & 26 deletions pocketflow/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,29 @@
import asyncio, warnings, copy, time
import asyncio, warnings, copy, time, sys

class BaseNode:
def __init__(self): self.params,self.successors={},{}
def __init__(self):
self.params, self.successors = {}, {}
self.name = self.get_instance_name() or f"node_{hash(self)}"
self.flow = None # Will be set by Flow._propagate_flow
self.parent = None # Will be set by Flow._propagate_flow

def get_instance_name(self):
"""Find the variable name this instance is assigned to, if any"""
try:
frame = sys._getframe(1)
while frame:
for scope in (frame.f_locals, frame.f_globals):
for key, value in scope.items():
if value is self and not key.startswith('_') and key != 'self':
return key
frame = frame.f_back
except (AttributeError, ValueError):
pass
return None

def _get_name(self):
"""Return the instance name, either from name attribute or lookup"""
return self.name or self.get_instance_name() or f"node_{hash(self)}"
def set_params(self,params): self.params=params
def add_successor(self,node,action="default"):
if action in self.successors: warnings.warn(f"Overwriting successor for action '{action}'")
Expand All @@ -11,8 +33,8 @@ def exec(self,prep_res): pass
def post(self,shared,prep_res,exec_res): pass
def _exec(self,prep_res): return self.exec(prep_res)
def _run(self,shared): p=self.prep(shared);e=self._exec(p);return self.post(shared,p,e)
def run(self,shared):
if self.successors: warnings.warn("Node won't run successors. Use Flow.")
def run(self,shared):
if self.successors: warnings.warn("Node won't run successors. Use Flow.")
return self._run(shared)
def __rshift__(self,other): return self.add_successor(other)
def __sub__(self,action):
Expand All @@ -37,22 +59,45 @@ class BatchNode(Node):
def _exec(self,items): return [super(BatchNode,self)._exec(i) for i in items]

class Flow(BaseNode):
    """A node that runs a graph of nodes, beginning at ``start``.

    On construction every node reachable from ``start`` gets its ``flow``
    attribute pointed at this flow.  A nested ``Flow`` found in the graph
    additionally gets this flow as its ``parent``; any other node receives
    this flow's own ``parent``.
    """

    def __init__(self, start, name=None):
        """Store ``start``, pick a name, and tag the reachable nodes.

        Args:
            start: first node to execute.
            name: explicit flow name; when falsy, a variable name found on
                the call stack or ``flow_<hash>`` is used instead.
        """
        super().__init__()
        self.start = start
        self.name = name or self.get_instance_name() or f"flow_{hash(self)}"
        self._propagate_flow(self.start)

    def _propagate_flow(self, node, visited=None):
        """Set ``flow`` and ``parent`` on every node reachable from ``node``.

        Args:
            node: node to start from; ``None`` is ignored.
            visited: optional set of ``id()`` values already handled, so
                cycles in the graph terminate.
        """
        if visited is None:
            visited = set()

        # Explicit stack instead of recursion: long chains of successors
        # must not run into the interpreter's recursion limit.
        pending = [node]
        while pending:
            current = pending.pop()
            if current is None or id(current) in visited:
                continue
            visited.add(id(current))
            current.flow = self
            if isinstance(current, Flow):
                # A sub-flow's parent is the flow that contains it.
                current.parent = self
            else:
                current.parent = getattr(self, 'parent', None)
            pending.extend(current.successors.values())

    def get_next_node(self, curr, action):
        """Return the successor of ``curr`` for ``action`` (falsy -> "default").

        Warns when ``curr`` has successors but none matches ``action``.
        """
        nxt = curr.successors.get(action or "default")
        if not nxt and curr.successors:
            warnings.warn(f"Flow ends: '{action}' not found in {list(curr.successors)}")
        return nxt

    def _orch(self, shared, params=None):
        """Run nodes one after another until no successor is returned.

        Each step works on a shallow copy of the node, which receives
        ``params`` (or a copy of this flow's params when ``params`` is falsy).
        """
        curr, p = copy.copy(self.start), (params or {**self.params})
        while curr:
            curr.set_params(p)
            c = curr._run(shared)
            curr = copy.copy(self.get_next_node(curr, c))

    def _run(self, shared):
        """Call ``prep``, orchestrate the graph, then return ``post``'s result."""
        pr = self.prep(shared)
        self._orch(shared)
        return self.post(shared, pr, None)

    def exec(self, prep_res):
        """Flows have no exec step of their own; always raises RuntimeError."""
        raise RuntimeError("Flow can't exec.")

class BatchFlow(Flow):
    """Flow that runs its graph once per parameter dict returned by ``prep``."""

    def _run(self, shared):
        """Orchestrate once per batch entry, then return ``post``'s result.

        Each entry is merged over the flow's own params; entry keys win.
        A falsy ``prep`` result is treated as an empty batch.
        """
        batches = self.prep(shared)
        if not batches:
            batches = []
        for overrides in batches:
            self._orch(shared, {**self.params, **overrides})
        return self.post(shared, batches, None)

class AsyncNode(Node):
def prep(self,shared): raise RuntimeError("Use prep_async.")
Expand All @@ -64,14 +109,14 @@ async def prep_async(self,shared): pass
async def exec_async(self,prep_res): pass
async def exec_fallback_async(self,prep_res,exc): raise exc
async def post_async(self,shared,prep_res,exec_res): pass
async def _exec(self,prep_res):
async def _exec(self,prep_res):
for i in range(self.max_retries):
try: return await self.exec_async(prep_res)
except Exception as e:
if i==self.max_retries-1: return await self.exec_fallback_async(prep_res,e)
if self.wait>0: await asyncio.sleep(self.wait)
async def run_async(self,shared):
if self.successors: warnings.warn("Node won't run successors. Use AsyncFlow.")
async def run_async(self,shared):
if self.successors: warnings.warn("Node won't run successors. Use AsyncFlow.")
return await self._run_async(shared)
async def _run_async(self,shared): p=await self.prep_async(shared);e=await self._exec(p);return await self.post_async(shared,p,e)

Expand All @@ -82,19 +127,26 @@ class AsyncParallelBatchNode(AsyncNode,BatchNode):
async def _exec(self,items): return await asyncio.gather(*(super(AsyncParallelBatchNode,self)._exec(i) for i in items))

class AsyncFlow(Flow, AsyncNode):
    """Flow whose graph may mix ``AsyncNode`` instances with other nodes."""

    async def _orch_async(self, shared, params=None):
        """Walk the graph, awaiting async nodes and calling the rest directly.

        Every step works on a shallow copy of the node, which receives
        ``params`` (or a copy of this flow's params when ``params`` is falsy).
        """
        node = copy.copy(self.start)
        node_params = params or {**self.params}
        while node:
            node.set_params(node_params)
            if isinstance(node, AsyncNode):
                action = await node._run_async(shared)
            else:
                action = node._run(shared)
            node = copy.copy(self.get_next_node(node, action))

    async def _run_async(self, shared):
        """Await ``prep_async``, run the graph, then return ``post_async``'s result."""
        prepared = await self.prep_async(shared)
        await self._orch_async(shared)
        return await self.post_async(shared, prepared, None)

class AsyncBatchFlow(AsyncFlow, BatchFlow):
    """Async flow that runs its graph once per batch entry, one at a time."""

    async def _run_async(self, shared):
        """Await one orchestration per entry from ``prep_async``, in order.

        Each entry is merged over the flow's own params; entry keys win.
        A falsy ``prep_async`` result is treated as an empty batch.
        """
        batches = await self.prep_async(shared)
        if not batches:
            batches = []
        for overrides in batches:
            await self._orch_async(shared, {**self.params, **overrides})
        return await self.post_async(shared, batches, None)

class AsyncParallelBatchFlow(AsyncFlow, BatchFlow):
    """Async flow that runs all batch entries through ``asyncio.gather``."""

    async def _run_async(self, shared):
        """Gather one orchestration per entry from ``prep_async``.

        Each entry is merged over the flow's own params; entry keys win.
        A falsy ``prep_async`` result is treated as an empty batch.
        """
        batches = await self.prep_async(shared)
        if not batches:
            batches = []
        runs = [
            self._orch_async(shared, {**self.params, **overrides})
            for overrides in batches
        ]
        await asyncio.gather(*runs)
        return await self.post_async(shared, batches, None)
1 change: 1 addition & 0 deletions pocketflow/example.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
fasdasdasd
237 changes: 237 additions & 0 deletions pocketflow/rework.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,237 @@

# from pocketflow import *
from __init__ import *
import os

def call_llm(prompt):
    """Placeholder for a real LLM call: hands ``prompt`` back unchanged."""
    # Swap this body for the actual API request.
    response = prompt
    return response

class LoadFile(Node):
    """Node that reads the file named by the ``filename`` param into ``shared``."""

    def __init__(self, name=None):
        super().__init__()
        # An explicit name replaces the one chosen by the base initialiser.
        if name:
            self.name = name

    def prep(self, shared):
        """Read and return the full text of ``self.params["filename"]``."""
        print(f" In : {type(self).__name__}")
        path = self.params["filename"]
        with open(path, "r") as handle:
            contents = handle.read()
        return contents

    def exec(self, prep_res):
        """Pass the file content through untouched."""
        return prep_res

    def post(self, shared, prep_res, exec_res):
        """Publish the content as ``shared["file_content"]``; action "default"."""
        shared["file_content"] = exec_res
        return "default"


class GetOpinion(Node):
    """Node that asks ``call_llm`` for an opinion on the current text."""

    def __init__(self, name=None):
        super().__init__()
        # An explicit name replaces the one chosen by the base initialiser.
        if name:
            self.name = name

    def prep(self, shared):
        """Print identity info and return the text to be reviewed.

        Returns the original file content alone on the first pass, or the
        original followed by the revised version once a rework exists.
        """
        cls_name = type(self).__name__
        print(f" In : {cls_name}")
        print(f"My name is: {self.name} (instance of {cls_name})")
        owner = self.flow
        if owner:
            print(f"Flow name: {owner.name}")
            if owner.parent:
                print(f"Parent flow: {owner.parent.name}")
        if shared.get("reworked_file_content"):
            return "Original text :\n" + shared["file_content"] + "Revised version:\n" + shared["reworked_file_content"]
        return shared["file_content"]

    def exec(self, prep_res):
        """Build the opinion prompt around ``prep_res`` and send it."""
        question = f"What's your opinion on this text: {prep_res}. Provide opinion on how to make it better."
        return call_llm(question)

    def post(self, shared, prep_res, exec_res):
        """Publish the reply as ``shared["opinion"]``; action "default"."""
        shared["opinion"] = exec_res
        return "default"

class GetValidation(Node):
    """Node that asks ``call_llm`` whether the reworked text is acceptable."""

    def __init__(self, name=None):
        super().__init__()
        # Use provided name or fall back to the one set by the base class.
        self.name = name or self.name

    def prep(self, shared):
        """Assemble the discussion text, store it and return it.

        The discussion is the original content, the opinion and the revised
        text concatenated; it is saved as ``shared["discussion"]`` and
        returned so that ``exec`` can embed it in the prompt.

        Raises:
            KeyError: if ``file_content``, ``opinion`` or
                ``reworked_file_content`` is missing from ``shared``.
        """
        print(f" In : {self.__class__.__name__}")
        discussion = (
            shared["file_content"]
            + shared["opinion"]
            + "Final revised text : "
            + shared["reworked_file_content"]
        )
        shared['discussion'] = discussion
        # Returning the text matters: a bare return made the prompt below
        # ask the model to validate the literal string "None".
        return discussion

    def exec(self, prep_res):
        """Ask the LLM to reply `IS VALID` or `NOT VALID` for the discussion."""
        prompt = f"Validate that the final revised text is valid and reflects the changes proposed in opinion : {prep_res}. \nReply `IS VALID` if it is or `NOT VALID` if it needs some more work."
        return call_llm(prompt)

    def post(self, shared, prep_res, exec_res):
        """Return "default" when the reply contains "IS VALID", else "invalid"."""
        # A missing/empty reply counts as not validated instead of raising.
        if exec_res and "IS VALID" in exec_res:
            return "default"
        return "invalid"


class ReworkFile(Node):
    """Node that asks ``call_llm`` to rewrite the text according to the opinion."""

    def __init__(self, name=None):
        super().__init__()
        # Use provided name or fall back to the one set by the base class.
        self.name = name or self.name

    def prep(self, shared):
        """Return ``(file_content, opinion)`` read from ``shared``."""
        print(f" In : {self.__class__.__name__}")
        return shared["file_content"], shared["opinion"]

    def exec(self, prep_res):
        """Build the rework prompt from the opinion and original text; send it."""
        file_content, opinion = prep_res
        prompt = f"Rework this text based on the opinion: {opinion}\n\nOriginal text: {file_content}"
        return call_llm(prompt)

    def post(self, shared, prep_res, exec_res):
        """Store the reworked text and decide whether another pass is needed.

        Without the ``rework2_flow_min_count`` param the text is stored and
        ``None`` is returned.  With it, ``shared["reworked_file_content_count"]``
        is incremented; the action is "rework" while the count is below the
        minimum and "default" afterwards.
        """
        shared["reworked_file_content"] = exec_res
        if "rework2_flow_min_count" not in self.params:
            return None

        rework_count = self.params["rework2_flow_min_count"]
        # A missing or zero counter starts at 1; otherwise add one pass.
        passes_done = (shared.get("reworked_file_content_count") or 0) + 1
        shared["reworked_file_content_count"] = passes_done

        if passes_done < rework_count:
            # Plain names inside the braces: reusing double quotes inside a
            # double-quoted f-string is a SyntaxError before Python 3.12.
            print(f"Less than {rework_count} rework for rework2_flow, so going for pass #{passes_done}.")
            return "rework"
        return "default"


class SaveFile(Node):
    """Node that writes the reworked text next to the input file as ``*_v2``."""

    def __init__(self, name=None):
        super().__init__()
        # Use provided name or fall back to the one set by the base class.
        self.name = name or self.name

    @staticmethod
    def _versioned_name(filename):
        """Return ``filename`` with ``_v2`` inserted before its extension."""
        # splitext only splits the final extension, so dotted directories,
        # multi-dot names and extension-less names are all handled.
        root, ext = os.path.splitext(filename)
        return f"{root}_v2{ext}"

    def prep(self, shared):
        """Return ``(reworked_file_content, filename)``.

        Raises:
            KeyError: if ``shared`` has no ``reworked_file_content`` or the
                ``filename`` param is missing.
        """
        print(f" In : {self.__class__.__name__}")
        filename = self.params["filename"]
        if "reworked_file_content" not in shared:
            # Fail here with a clear message; returning None only deferred
            # the failure to an opaque unpacking TypeError in exec().
            raise KeyError("reworked_file_content is missing from shared; nothing to save")
        return shared["reworked_file_content"], filename

    def exec(self, prep_res):
        """Write the reworked text to the ``_v2`` file and return the text."""
        reworked_file_content, filename = prep_res
        new_filename = self._versioned_name(filename)
        with open(new_filename, "w") as file:
            file.write(reworked_file_content)
        return reworked_file_content

    def post(self, shared, prep_res, exec_res):
        """Print a confirmation naming the file that was actually written."""
        new_filename = self._versioned_name(self.params["filename"])
        print(f"Saved to {new_filename} the content : \n{exec_res}")


# # # Comment this from here
# # First flow
# Create nodes
# One node per pipeline stage, each given an explicit name.
load_Node = LoadFile("load_Node")
opinion_Node = GetOpinion("opinion_Node")
rework_Node = ReworkFile("rework_Node")
save_Node = SaveFile("save_Node")

# Linear chain: load -> opinion -> rework -> save.
load_Node >> opinion_Node >> rework_Node >> save_Node

# The flow starts at the loader.
rework_Flow = Flow(load_Node, name="rework_Flow")

# Every node in the flow reads the file name from these params.
rework_Flow.set_params(dict(filename="example.txt"))

# Execute with a fresh shared store.
shared = {}
rework_Flow.run(shared)
# # # To here for second workflow to work

# # Second flow
# Create nodes
# One node per stage; this pipeline adds a validation step.
load2_Node = LoadFile("load2_Node")
opinion2_Node = GetOpinion("opinion2_Node")
rework2_Node = ReworkFile("rework2_Node")
valid2_Node = GetValidation("valid2_Node")
save2_Node = SaveFile("save2_Node")

print(f" NAME is : {opinion2_Node.name}")

# Unconditional edges: load -> opinion -> rework.
load2_Node >> opinion2_Node
opinion2_Node >> rework2_Node

# After a rework: validate on "default", or ask for another opinion on
# "rework" (returned while the minimum rework count has not been reached).
(rework2_Node - "default") >> valid2_Node
(rework2_Node - "rework") >> opinion2_Node

# After validation: save on "default", or go back for another opinion
# when the revised text was judged invalid.
(valid2_Node - "invalid") >> opinion2_Node
(valid2_Node - "default") >> save2_Node

# Build the flow with an explicit name.
rework2_Flow = Flow(load2_Node, name="rework2_Flow")

# Flow params: the input file and the minimum number of rework passes.
# Open question: are these params applied if the flow already received
# other params earlier?
rework2_Flow.set_params({"filename": "example.txt", "rework2_flow_min_count": 3})

# Execute with a fresh shared store.
shared2 = {}
rework2_Flow.run(shared2)

def build_mermaid(start):
    """Return Mermaid ``graph LR`` text for the graph reachable from ``start``.

    Plain nodes become labelled boxes, flows become subgraphs, and each
    successor relation becomes an arrow.
    """
    seen = set()
    out = ["graph LR"]

    # The locals that hold nodes (n, a, b, node, parent, nxt) keep their
    # names on purpose: _get_name() may fall back to get_instance_name(),
    # which reads variable names out of the calling frames.

    def get_name(n):
        """Identifier used for ``n`` in the diagram."""
        label = n._get_name()
        if isinstance(n, Flow):
            return label
        # Mermaid needs no spaces in node names
        return label.replace(' ', '_')

    def link(a, b):
        """Append an arrow from ``a`` to ``b``."""
        out.append(f" {get_name(a)} --> {get_name(b)}")

    def walk(node, parent=None):
        """Emit ``node`` and everything after it; revisits only add an arrow."""
        if node in seen:
            if parent:
                link(parent, node)
            return None
        seen.add(node)

        if not isinstance(node, Flow):
            # Label shows both the instance name and the class name.
            node_label = f"{node._get_name()} ({type(node).__name__})"
            out.append(f" {get_name(node)}[\"{node_label}\"]")
            if parent:
                link(parent, node)
            for nxt in node.successors.values():
                walk(nxt, node)
            return None

        if node.start and parent:
            link(parent, node)
        # Subgraph label shows both the flow name and the class name.
        flow_label = f"{node._get_name()} ({type(node).__name__})"
        out.append(f"\n subgraph {get_name(node)}[\"{flow_label}\"]")
        if node.start:
            walk(node.start)
        for nxt in node.successors.values():
            # Same three steps, in the same order, as the original
            # and/or chain (walk and link always yield a falsy value).
            if node.start:
                walk(nxt, node.start)
            if parent:
                link(parent, nxt)
            walk(nxt)
        out.append(" end\n")
        return None

    walk(start)
    return "\n".join(out)

# Print both diagrams, one after the other.
print(
    build_mermaid(start=rework_Flow),
    build_mermaid(start=rework2_Flow),
    sep="\n",
)