rl_prompt_tuning #13

Open · wants to merge 1 commit into base: main

140 changes: 135 additions & 5 deletions nomadic/experiment/prompt_tuning.py
@@ -1,14 +1,10 @@
# prompt_tuning.py

from typing import List, Optional, Dict, Union, Any
from itertools import product
import time
import re
import json
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
import dspy

# OpenAI client used to generate and refine prompt variants
from openai import OpenAI

@@ -44,7 +40,7 @@ def __init__(
self.optimizer_type = optimizer_type

# Initialize the optimizer if DSPy optimization is enabled
if self.use_iterative_optimization:
if self.use_iterative_optimization == "dspy":
if self.optimizer_type == "BootstrapFewShotWithRandomSearch":
from dspy.teleprompt import BootstrapFewShotWithRandomSearch
# Configure the BootstrapFewShotWithRandomSearch optimizer
@@ -73,6 +69,8 @@ def generate_prompt_variants(self, client, user_prompt_request, max_retries: int
return self.generate_prompt_variants_optimized(client, user_prompt_request, max_retries, retry_delay)
elif self.use_iterative_optimization == "iterative":
return self.generate_prompt_variants_iterative(client, user_prompt_request, max_retries, retry_delay)
elif self.use_iterative_optimization == "rl_agent":
return self.generate_prompt_variants_rl_agent(client, user_prompt_request, max_retries, retry_delay)
else:
return self.generate_prompt_variants_static(client, user_prompt_request, max_retries, retry_delay)
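
For context, a minimal usage sketch of the new dispatch path. The `PromptTuner` class name and constructor keywords below are assumptions, since the full class definition is collapsed out of this diff:

from openai import OpenAI

client = OpenAI()  # expects OPENAI_API_KEY in the environment

tuner = PromptTuner(                        # assumed class name
    use_iterative_optimization="rl_agent",  # routes to the new RL-agent path
    enable_logging=True,                    # assumed flag, used by the new method
)
variants = tuner.generate_prompt_variants(
    client,
    ["Summarize the following text:\n{text}"],
)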

@@ -207,6 +205,67 @@ def objective_function(**params):

return prompt_variants

def generate_prompt_variants_rl_agent(self, client, user_prompt_request, max_retries: int = 3, retry_delay: int = 5):
"""
Generate prompt variants using an iterative, RL-style optimization loop.

Each prompt is treated as the agent's state, an LLM-driven rewrite is the action, and the
prompt's evaluation score is the reward. The loop greedily keeps the best-scoring prompt
and stops once the reward is within epsilon of the maximum score (1.0) or the iteration
budget is exhausted.

Args:
client: The client object for API calls.
user_prompt_request: List of prompt templates to be optimized.
max_retries: Maximum number of retries for API calls.
retry_delay: Delay between retries.

Returns:
List of optimized prompt variants, one for each input template.
"""
prompt_variants = []
max_iterations = 50 # Maximum number of iterations for the RL agent
epsilon = 0.01 # Threshold for convergence

for template in user_prompt_request:
current_prompt = template
best_prompt = current_prompt
best_score = float('-inf')
iteration = 0

while iteration < max_iterations:
if self.enable_logging:
print(f"\nIteration {iteration + 1}/{max_iterations}")
print(f"Current Prompt:\n{current_prompt}\n")

# State: Current prompt
state = current_prompt

# Evaluate the current prompt to get the reward
reward = self.evaluate_variant(state)
if self.enable_logging:
print(f"Evaluation Score: {reward}")

# Update the best prompt if the reward is better
if reward > best_score:
best_score = reward
best_prompt = state

# Check for convergence
if reward >= 1.0 - epsilon:
if self.enable_logging:
print("Optimal prompt achieved.")
break

# Action: Modify the prompt based on feedback
feedback = self.get_feedback(state)
action = self.modify_prompt_with_feedback(client, state, feedback, max_retries, retry_delay)

# Update the prompt (next state)
current_prompt = action
iteration += 1

prompt_variants.append(best_prompt)

return prompt_variants
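
The loop above amounts to greedy hill-climbing over prompt space: evaluate, keep the best, rewrite, repeat. A standalone sketch of the same control flow (with stand-in `evaluate` and `modify` callables, both hypothetical) makes the convergence rule easy to unit-test in isolation:

def optimize_prompt(prompt, evaluate, modify, max_iterations=50, epsilon=0.01):
    """Greedy refinement loop mirroring generate_prompt_variants_rl_agent."""
    best_prompt, best_score = prompt, float("-inf")
    for _ in range(max_iterations):
        reward = evaluate(prompt)        # reward = evaluation score in [0, 1]
        if reward > best_score:
            best_prompt, best_score = prompt, reward
        if reward >= 1.0 - epsilon:      # converged: within epsilon of the maximum
            break
        prompt = modify(prompt)          # next state: feedback-driven rewrite
    return best_prompt

Note that nothing is learned across iterations here; if a trained policy is intended later, the single `modify` call is the seam where action selection would plug in.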

def generate_prompt_variant(self, client, user_prompt_request, max_retries, retry_delay):
@@ -510,6 +569,77 @@ def evaluate_variant(self, variant):
# If no evaluation dataset, return a default score
return 0.0

def get_feedback(self, prompt):
"""
Generate feedback for the given prompt.

Args:
prompt: The prompt to evaluate.

Returns:
Feedback string based on evaluation.
"""
# Use the evaluator's explanation string as the feedback signal
evaluation_metrics = [
{"metric": "Accuracy", "weight": 1.0}
]
openai_api_key = None  # Supply an API key here if custom_evaluate requires one
evaluation_result = custom_evaluate(
response=prompt,
evaluation_metrics=evaluation_metrics,
openai_api_key=openai_api_key,
)
feedback = evaluation_result.get('explanation', '')
return feedback
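
Only the 'explanation' key of the evaluation result is consumed here, so a test double can stay minimal. The stub below is hypothetical; custom_evaluate's real return shape is not shown in this diff:

def fake_custom_evaluate(response, evaluation_metrics, openai_api_key=None):
    # Hypothetical stand-in: get_feedback only reads the 'explanation' key.
    return {"explanation": "The prompt does not specify the expected output format."}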

def modify_prompt_with_feedback(self, client, prompt, feedback, max_retries, retry_delay):
"""
Modify the prompt based on the feedback.

Args:
client: The client object for API calls.
prompt: The current prompt.
feedback: Feedback string to guide the modification.
max_retries: Maximum number of retries.
retry_delay: Delay between retries.

Returns:
Modified prompt.
"""
system_message = f"""
You are an expert prompt engineer. Improve the following prompt based on the feedback provided.

Original Prompt:
{prompt}

Feedback:
{feedback}

Provide the improved prompt, ensuring it addresses the issues mentioned. Return only the
text of the improved prompt, with no preamble or commentary.
"""

for attempt in range(max_retries):
try:
response = client.chat.completions.create(
model=DEFAULT_OPENAI_MODEL,
messages=[
{"role": "system", "content": system_message},
{"role": "user", "content": "Provide the improved prompt."},
],
temperature=0.7,
)
new_prompt = response.choices[0].message.content.strip()
return new_prompt
except Exception as e:
if self.enable_logging:
print(f"Error modifying prompt: {e}")
print(f"Retrying... ({attempt + 1}/{max_retries})")
time.sleep(retry_delay)

if self.enable_logging:
print(f"Failed to modify prompt after {max_retries} attempts.")
return prompt # Return the original prompt if unable to modify
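
A hypothetical call site for a single refinement step (the prompt and feedback strings are invented for illustration):

improved = tuner.modify_prompt_with_feedback(
    client,
    prompt="Summarize the following text:\n{text}",
    feedback="The prompt does not specify the desired summary length.",
    max_retries=3,
    retry_delay=5,
)

Returning the unmodified prompt on repeated failure is a sensible fallback: the RL loop keeps running and simply re-evaluates the same state instead of crashing.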

def update_params(self, params: Dict[str, Any]):
if "prompt_tuning_approach" in params:
self.prompt_tuning_approaches = [params["prompt_tuning_approach"]]