diff --git a/.gitignore b/.gitignore
index 038815d..b9c4bca 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,2 +1,5 @@
 *.db
-.venv/
\ No newline at end of file
+*.csv
+.mypy_cache/
+.venv/
+.python-version
\ No newline at end of file
diff --git a/README.md b/README.md
index 58251df..134956f 100644
--- a/README.md
+++ b/README.md
@@ -59,6 +59,8 @@ This will display detailed usage information.
 - `-t`, `--api_timeout`: Timeout (in seconds) for API requests (default: 10).
 - `-i`, `--poll_interval`: Interval (in seconds) between API status polls (default: 5).
 - `-p`, `--parallel_call_count`: Number of parallel API calls (default: 10).
+- `--csv_report`: Path to export the detailed report as a CSV file.
+- `--db_path`: Path where the SQLite DB file is stored (default: './file_processing.db').
 - `--retry_failed`: Retry processing of failed files.
 - `--retry_pending`: Retry processing of pending files by making new requests.
 - `--skip_pending`: Skip processing of pending files.
@@ -67,7 +69,6 @@ This will display detailed usage information.
 - `--print_report`: Print a detailed report of all processed files at the end.
 - `--exclude_metadata`: Exclude metadata on tokens consumed and the context passed to LLMs for prompt studio exported tools in the result for each file.
 - `--no_verify`: Disable SSL certificate verification. (By default, SSL verification is enabled.)
-- `--csv_report`: Path to export the detailed report as a CSV file.
 
 ## Usage Examples
 
diff --git a/main.py b/main.py
index ce10472..de7f11a 100644
--- a/main.py
+++ b/main.py
@@ -16,8 +16,6 @@
 from tqdm import tqdm
 
 from unstract.api_deployments.client import APIDeploymentsClient
 
-DB_NAME = "file_processing.db"
-global_arguments = None
 logger = logging.getLogger(__name__)
 
@@ -29,6 +27,7 @@ class Arguments:
     api_timeout: int = 10
     poll_interval: int = 5
     input_folder_path: str = ""
+    db_path: str = ""
     parallel_call_count: int = 5
     retry_failed: bool = False
     retry_pending: bool = False
@@ -42,8 +41,8 @@ class Arguments:
 
 
 # Initialize SQLite DB
-def init_db():
-    conn = sqlite3.connect(DB_NAME)
+def init_db(args: Arguments):
+    conn = sqlite3.connect(args.db_path)
     c = conn.cursor()
 
     # Create the table if it doesn't exist
@@ -89,7 +88,7 @@
 # Check if the file is already processed
 def skip_file_processing(file_name, args: Arguments):
-    conn = sqlite3.connect(DB_NAME)
+    conn = sqlite3.connect(args.db_path)
     c = conn.cursor()
     c.execute(
         "SELECT execution_status FROM file_status WHERE file_name = ?", (file_name,)
     )
@@ -124,6 +123,7 @@
     time_taken,
     status_code,
     status_api_endpoint,
+    args: Arguments,
 ):
 
     total_embedding_cost = None
@@ -138,7 +138,7 @@
     if execution_status == "ERROR":
         error_message = extract_error_message(result)
 
-    conn = sqlite3.connect(DB_NAME)
+    conn = sqlite3.connect(args.db_path)
     conn.set_trace_callback(
         lambda x: (
             logger.debug(f"[{file_name}] Executing statement: {x}")
@@ -232,8 +232,8 @@ def extract_error_message(result):
     return result.get("error", "No error message found")
 
 # Print final summary with count of each status and average time using a single SQL query
-def print_summary():
-    conn = sqlite3.connect(DB_NAME)
+def print_summary(args: Arguments):
+    conn = sqlite3.connect(args.db_path)
     c = conn.cursor()
 
     # Fetch count and average time for each status
@@ -255,8 +255,8 @@
         print(f"Status '{status}': {count}")
 
 
-def print_report():
-    conn = sqlite3.connect(DB_NAME)
+def print_report(args: Arguments):
+    conn = sqlite3.connect(args.db_path)
     c = conn.cursor()
 
     # Fetch required fields, including total_cost and total_tokens
@@ -318,13 +318,13 @@ def print_report():
     print("\nNote: For more detailed error messages, use the CSV report argument.")
 
 
-def export_report_to_csv(output_path):
-    conn = sqlite3.connect(DB_NAME)
+def export_report_to_csv(args: Arguments):
+    conn = sqlite3.connect(args.db_path)
     c = conn.cursor()
 
     c.execute(
         """
-        SELECT file_name, execution_status, time_taken, total_embedding_cost, total_embedding_tokens, total_llm_cost, total_llm_tokens, error_message
+        SELECT file_name, execution_status, result, time_taken, total_embedding_cost, total_embedding_tokens, total_llm_cost, total_llm_tokens, error_message
         FROM file_status
         """
     )
@@ -332,22 +332,22 @@ def export_report_to_csv(output_path):
     conn.close()
 
     if not report_data:
-        print("No data available to export.")
+        print("No data available to export as CSV.")
         return
 
     # Define the headers
     headers = [
-        "File Name", "Execution Status", "Time Elapsed (seconds)",
+        "File Name", "Execution Status", "Result", "Time Elapsed (seconds)",
         "Total Embedding Cost", "Total Embedding Tokens",
         "Total LLM Cost", "Total LLM Tokens", "Error Message"
     ]
 
     try:
-        with open(output_path, 'w', newline='') as csvfile:
+        with open(args.csv_report, 'w', newline='') as csvfile:
             writer = csv.writer(csvfile)
             writer.writerow(headers)  # Write headers
             writer.writerows(report_data)  # Write data rows
 
-        print(f"CSV successfully exported to {output_path}")
+        print(f"CSV successfully exported to '{args.csv_report}'")
     except Exception as e:
         print(f"Error exporting to CSV: {e}")
@@ -357,7 +357,7 @@ def get_status_endpoint(file_path, client, args: Arguments):
     status_endpoint = None
 
     # If retry_pending is True, check if the status API endpoint is available
-    conn = sqlite3.connect(DB_NAME)
+    conn = sqlite3.connect(args.db_path)
     c = conn.cursor()
     c.execute(
         "SELECT status_api_endpoint FROM file_status WHERE file_name = ? AND execution_status NOT IN ('COMPLETED', 'ERROR')",
@@ -382,7 +382,7 @@ def get_status_endpoint(file_path, client, args: Arguments):
 
     # Fresh API call to process the file
     execution_status = "STARTING"
-    update_db(file_path, execution_status, None, None, None, None)
+    update_db(file_path, execution_status, None, None, None, None, args=args)
     response = client.structure_file(file_paths=[file_path])
     logger.debug(f"[{file_path}] Response of initial API call: {response}")
     status_endpoint = response.get(
@@ -397,6 +397,7 @@
         None,
         status_code,
         status_endpoint,
+        args=args,
     )
     return status_endpoint, execution_status, response
 
@@ -436,7 +437,7 @@ def process_file(
         execution_status = response.get("execution_status")
         status_code = response.get("status_code")  # Default to 200 if not provided
         update_db(
-            file_path, execution_status, None, None, status_code, status_endpoint
+            file_path, execution_status, None, None, status_code, status_endpoint, args=args
         )
         result = response
 
@@ -456,7 +457,7 @@ def process_file(
     end_time = time.time()
     time_taken = round(end_time - start_time, 2)
     update_db(
-        file_path, execution_status, result, time_taken, status_code, status_endpoint
+        file_path, execution_status, result, time_taken, status_code, status_endpoint, args=args
     )
     logger.info(f"[{file_path}]: Processing completed: {execution_status}")
 
@@ -501,14 +502,14 @@ def load_folder(args: Arguments):
 
 
 def main():
-    parser = argparse.ArgumentParser(description="Process files using the API.")
+    parser = argparse.ArgumentParser(description="Process files using Unstract's API deployment")
     parser.add_argument(
         "-e",
         "--api_endpoint",
         dest="api_endpoint",
         type=str,
         required=True,
-        help="API Endpoint to use for processing the files.",
+        help="API endpoint to use for processing the files",
    )
     parser.add_argument(
         "-k",
@@ -524,7 +525,7 @@
         dest="api_timeout",
         type=int,
         default=10,
-        help="Time in seconds to wait before switching to async mode.",
+        help="Time in seconds to wait before switching to async mode (default: 10)",
     )
     parser.add_argument(
         "-i",
@@ -532,7 +533,7 @@
         dest="poll_interval",
         type=int,
         default=5,
-        help="Time in seconds the process will sleep between polls in async mode.",
+        help="Time in seconds the process will sleep between polls in async mode (default: 5)",
     )
     parser.add_argument(
         "-f",
@@ -540,7 +541,7 @@
         dest="input_folder_path",
         type=str,
         required=True,
-        help="Path where the files to process are present.",
+        help="Path where the files to process are present",
     )
     parser.add_argument(
         "-p",
@@ -548,31 +549,44 @@
         dest="parallel_call_count",
         type=int,
         default=5,
-        help="Number of calls to be made in parallel.",
+        help="Number of calls to be made in parallel (default: 5)",
+    )
+    parser.add_argument(
+        "--db_path",
+        dest="db_path",
+        type=str,
+        default="file_processing.db",
+        help="Path where the SQLite DB file is stored (default: './file_processing.db')",
+    )
+    parser.add_argument(
+        "--csv_report",
+        dest="csv_report",
+        type=str,
+        help="Path to export the detailed report as a CSV file",
     )
     parser.add_argument(
         "--retry_failed",
         dest="retry_failed",
         action="store_true",
-        help="Retry processing of failed files.",
+        help="Retry processing of failed files (default: False)",
     )
     parser.add_argument(
         "--retry_pending",
         dest="retry_pending",
         action="store_true",
-        help="Retry processing of pending files as new request (Without this it will try to fetch the results using status API).",
+        help="Retry processing of pending files as a new request (without this, results are fetched using the status API) (default: False)",
     )
     parser.add_argument(
         "--skip_pending",
         dest="skip_pending",
         action="store_true",
-        help="Skip processing of pending files (Over rides --retry-pending).",
+        help="Skip processing of pending files (overrides --retry_pending) (default: False)",
     )
     parser.add_argument(
         "--skip_unprocessed",
         dest="skip_unprocessed",
         action="store_true",
-        help="Skip unprocessed files while retry processing of failed files.",
+        help="Skip unprocessed files while retrying failed files (default: False)",
     )
     parser.add_argument(
         "--log_level",
@@ -586,52 +600,47 @@
         "--print_report",
         dest="print_report",
         action="store_true",
-        help="Print a detailed report of all file processed.",
+        help="Print a detailed report of all files processed (default: False)",
     )
-
     parser.add_argument(
         "--exclude_metadata",
         dest="include_metadata",
         action="store_false",
-        help="Exclude metadata on tokens consumed and the context passed to LLMs for prompt studio exported tools in the result for each file.",
+        help="Exclude metadata on tokens consumed and the context passed to LLMs for prompt studio exported tools in the result for each file (metadata is included by default)",
     )
-
     parser.add_argument(
         "--no_verify",
         dest="verify",
         action="store_false",
-        help="Disable SSL certificate verification.",
-    )
-
-    parser.add_argument(
-        '--csv_report',
-        dest="csv_report",
-        type=str,
-        help='Path to export the detailed report as a CSV file',
+        help="Disable SSL certificate verification (SSL verification is enabled by default)",
     )
 
     args = Arguments(**vars(parser.parse_args()))
 
     ch = logging.StreamHandler(sys.stdout)
     ch.setLevel(args.log_level)
+    formatter = logging.Formatter(
+        "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+    )
+    ch.setFormatter(formatter)
     logging.basicConfig(level=args.log_level, handlers=[ch])
 
     logger.warning(f"Running with params: {args}")
 
-    init_db()  # Initialize DB
+    init_db(args=args)  # Initialize DB
 
     load_folder(args=args)
 
-    print_summary()  # Print summary at the end
+    print_summary(args=args)  # Print summary at the end
 
     if args.print_report:
-        print_report()
+        print_report(args=args)
         logger.warning(
             "Elapsed time calculation of a file which was resumed"
             " from pending state will not be correct"
         )
     if args.csv_report:
-        export_report_to_csv(args.csv_report)
+        export_report_to_csv(args=args)
 
 
 if __name__ == "__main__":
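For reference, a hypothetical invocation exercising the options this change touches (`--db_path`, `--csv_report`) alongside the existing required flags; the endpoint, API key, and folder below are illustrative placeholders, not values taken from this diff:

    python main.py -e <api_endpoint> -k <api_key> -f ./input_files \
        --db_path ./file_processing.db --csv_report ./report.csv --print_report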