Skip to content

feat: switch to postgres for database tools #533

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 0 additions & 21 deletions database/go.mod
Original file line number Diff line number Diff line change
@@ -1,24 +1,3 @@
module obot-platform/database

go 1.23.3

require (
github.com/gptscript-ai/go-gptscript v0.9.6-0.20250222170845-eee4337500a6
github.com/ncruces/go-sqlite3 v0.20.3
)

require (
github.com/getkin/kin-openapi v0.129.0 // indirect
github.com/go-openapi/jsonpointer v0.21.0 // indirect
github.com/go-openapi/swag v0.23.0 // indirect
github.com/josharian/intern v1.0.0 // indirect
github.com/mailru/easyjson v0.9.0 // indirect
github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 // indirect
github.com/ncruces/julianday v1.0.0 // indirect
github.com/oasdiff/yaml v0.0.0-20241210131133-6b86fb107d80 // indirect
github.com/oasdiff/yaml3 v0.0.0-20241210130736-a94c01f36349 // indirect
github.com/perimeterx/marshmallow v1.1.5 // indirect
github.com/tetratelabs/wazero v1.8.2 // indirect
golang.org/x/sys v0.27.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
51 changes: 0 additions & 51 deletions database/go.sum
Original file line number Diff line number Diff line change
@@ -1,51 +0,0 @@
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/getkin/kin-openapi v0.129.0 h1:QGYTNcmyP5X0AtFQ2Dkou9DGBJsUETeLH9rFrJXZh30=
github.com/getkin/kin-openapi v0.129.0/go.mod h1:gmWI+b/J45xqpyK5wJmRRZse5wefA5H0RDMK46kLUtI=
github.com/go-openapi/jsonpointer v0.21.0 h1:YgdVicSA9vH5RiHs9TZW5oyafXZFc6+2Vc1rr/O9oNQ=
github.com/go-openapi/jsonpointer v0.21.0/go.mod h1:IUyH9l/+uyhIYQ/PXVA41Rexl+kOkAPDdXEYns6fzUY=
github.com/go-openapi/swag v0.23.0 h1:vsEVJDUo2hPJ2tu0/Xc+4noaxyEffXNIs3cOULZ+GrE=
github.com/go-openapi/swag v0.23.0/go.mod h1:esZ8ITTYEsH1V2trKHjAN8Ai7xHb8RV+YSZ577vPjgQ=
github.com/go-test/deep v1.0.8 h1:TDsG77qcSprGbC6vTN8OuXp5g+J+b5Pcguhf7Zt61VM=
github.com/go-test/deep v1.0.8/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE=
github.com/gptscript-ai/go-gptscript v0.9.6-0.20250222170845-eee4337500a6 h1:vsZ09cWfNWUXT6AOVQc1GpfEdIxcLusUs6Hgo9IgAKs=
github.com/gptscript-ai/go-gptscript v0.9.6-0.20250222170845-eee4337500a6/go.mod h1:QvGPZoRuAiA8P5EzPI05kTrs+LZ0ipHywUGsKruSknw=
github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY=
github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/mailru/easyjson v0.9.0 h1:PrnmzHw7262yW8sTBwxi1PdJA3Iw/EKBa8psRf7d9a4=
github.com/mailru/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU=
github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 h1:RWengNIwukTxcDr9M+97sNutRR1RKhG96O6jWumTTnw=
github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826/go.mod h1:TaXosZuwdSHYgviHp1DAtfrULt5eUgsSMsZf+YrPgl8=
github.com/ncruces/go-sqlite3 v0.20.3 h1:+4G4uEqOeusF0yRuQVUl9fuoEebUolwQSnBUjYBLYIw=
github.com/ncruces/go-sqlite3 v0.20.3/go.mod h1:ojLIAB243gtz68Eo283Ps+k9PyR3dvzS+9/RgId4+AA=
github.com/ncruces/julianday v1.0.0 h1:fH0OKwa7NWvniGQtxdJRxAgkBMolni2BjDHaWTxqt7M=
github.com/ncruces/julianday v1.0.0/go.mod h1:Dusn2KvZrrovOMJuOt0TNXL6tB7U2E8kvza5fFc9G7g=
github.com/oasdiff/yaml v0.0.0-20241210131133-6b86fb107d80 h1:nZspmSkneBbtxU9TopEAE0CY+SBJLxO8LPUlw2vG4pU=
github.com/oasdiff/yaml v0.0.0-20241210131133-6b86fb107d80/go.mod h1:7tFDb+Y51LcDpn26GccuUgQXUk6t0CXZsivKjyimYX8=
github.com/oasdiff/yaml3 v0.0.0-20241210130736-a94c01f36349 h1:t05Ww3DxZutOqbMN+7OIuqDwXbhl32HiZGpLy26BAPc=
github.com/oasdiff/yaml3 v0.0.0-20241210130736-a94c01f36349/go.mod h1:y5+oSEHCPT/DGrS++Wc/479ERge0zTFxaF8PbGKcg2o=
github.com/perimeterx/marshmallow v1.1.5 h1:a2LALqQ1BlHM8PZblsDdidgv1mWi1DgC2UmX50IvK2s=
github.com/perimeterx/marshmallow v1.1.5/go.mod h1:dsXbUu8CRzfYP5a87xpp0xq9S3u0Vchtcl8we9tYaXw=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8=
github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/tetratelabs/wazero v1.8.2 h1:yIgLR/b2bN31bjxwXHD8a3d+BogigR952csSDdLYEv4=
github.com/tetratelabs/wazero v1.8.2/go.mod h1:yAI0XTsMBhREkM/YDAK/zNou3GoiAce1P6+rp/wQhjs=
github.com/ugorji/go/codec v1.2.7 h1:YPXUKf7fYbp/y8xloBqZOw2qaVggbfwMlI8WM3wZUJ0=
github.com/ugorji/go/codec v1.2.7/go.mod h1:WGN1fab3R1fzQlVQTkfxVtIBhWDRqOviHU95kRgeqEY=
golang.org/x/sys v0.27.0 h1:wBqf8DvsY9Y/2P8gAfPDEYNuS30J4lPHJxXSb/nJZ+s=
golang.org/x/sys v0.27.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.20.0 h1:gK/Kv2otX8gz+wn7Rmb3vT96ZwuoxnQlY+HlJVj7Qug=
golang.org/x/text v0.20.0/go.mod h1:D4IsuqiFMhST5bX19pQ9ikHC2GsaKyk/oF+pn3ducp4=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
150 changes: 29 additions & 121 deletions database/main.go
Original file line number Diff line number Diff line change
@@ -2,89 +2,60 @@ package main

import (
"context"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"os"
"slices"

"obot-platform/database/pkg/cmd"

"github.com/gptscript-ai/go-gptscript"
_ "github.com/ncruces/go-sqlite3/driver"
_ "github.com/ncruces/go-sqlite3/embed"
"os"
)

var workspaceID = os.Getenv("DATABASE_WORKSPACE_ID")

func main() {
if len(os.Args) != 2 {
fmt.Println("Usage: gptscript-go-tool <command>")
os.Exit(1)
}
command := os.Args[1]
ctx := context.Background()

g, err := gptscript.NewGPTScript()
if err != nil {
fmt.Printf("Error creating GPTScript: %v\n", err)
os.Exit(1)
workspaceID := os.Getenv("DATABASE_WORKSPACE_ID")
if workspaceID == "" {
// TODO(njhale): Figure out why DATABASE_WORKSPACE_ID is not set here for the UI tools.
workspaceID = os.Getenv("GPTSCRIPT_WORKSPACE_ID")
}
defer g.Close()

var (
ctx = context.Background()
dbFileName = "obot.db"
dbWorkspacePath = "/databases/" + dbFileName
revisionID string = "-1"
initialDBData []byte
)

workspaceDB, err := g.ReadFileWithRevisionInWorkspace(ctx, dbWorkspacePath, gptscript.ReadFileInWorkspaceOptions{
WorkspaceID: workspaceID,
})

var notFoundErr *gptscript.NotFoundInWorkspaceError
if err != nil && !errors.As(err, &notFoundErr) {
fmt.Printf("Error reading DB file: %v\n", err)
os.Exit(1)
}
// Get admin DSN from environment variable
adminDSN := os.Getenv("POSTGRES_DSN")
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this will need to get plumbed down as a credential from obot when we decide how postgres for this tool is going to be stood up/configured.


// Create a temporary file for the SQLite database
dbFile, err := os.CreateTemp("", dbFileName)
// Setup database and user with admin credentials
dsn, err := cmd.EnsureTenantSchema(ctx, adminDSN, workspaceID)
if err != nil {
fmt.Printf("Error creating temp file: %v\n", err)
fmt.Printf("Error setting up database: %v\n", err)
os.Exit(1)
}
defer dbFile.Close()
defer os.Remove(dbFile.Name())

// Write the data to the temporary file
if workspaceDB != nil && workspaceDB.Content != nil {
initialDBData = workspaceDB.Content
if err := os.WriteFile(dbFile.Name(), initialDBData, 0644); err != nil {
fmt.Printf("Error writing to temp file: %v\n", err)
os.Exit(1)
}
if workspaceDB.RevisionID != "" {
revisionID = workspaceDB.RevisionID
}
}

// Run the requested command
// Run the requested command using the user credentials
var result string
switch command {
case "listDatabaseTables":
result, err = cmd.ListDatabaseTables(ctx, dbFile)
result, err = cmd.ListDatabaseTables(ctx, dsn)

case "listDatabaseTableRows":
result, err = cmd.ListDatabaseTableRows(ctx, dbFile, os.Getenv("TABLE"))
table := os.Getenv("TABLE")
if table == "" {
err = fmt.Errorf("TABLE environment variable is required")
break
}
result, err = cmd.ListDatabaseTableRows(ctx, dsn, table)

case "runDatabaseSQL":
result, err = cmd.RunDatabaseCommand(ctx, dbFile, os.Getenv("SQL"), "-header")
if err == nil {
err = saveWorkspaceDB(ctx, g, dbWorkspacePath, revisionID, dbFile, initialDBData)
sql := os.Getenv("SQL")
if sql == "" {
err = fmt.Errorf("SQL environment variable is required")
break
}
result, err = cmd.RunDatabaseCommand(ctx, dsn, sql)

case "databaseContext":
result, err = cmd.DatabaseContext(ctx, dbFile)
result, err = cmd.DatabaseContext(ctx, dsn)

default:
err = fmt.Errorf("unknown command: %s", command)
}
@@ -96,66 +67,3 @@ func main() {

fmt.Print(result)
}

// saveWorkspaceDB saves the updated database file to the workspace if the content of the database has changed.
func saveWorkspaceDB(
ctx context.Context,
g *gptscript.GPTScript,
dbWorkspacePath string,
revisionID string,
dbFile *os.File,
initialDBData []byte,
) error {
updatedDBData, err := os.ReadFile(dbFile.Name())
if err != nil {
return fmt.Errorf("Error reading updated DB file: %v", err)
}

if hash(initialDBData) == hash(updatedDBData) {
return nil
}

if err := g.WriteFileInWorkspace(ctx, dbWorkspacePath, updatedDBData, gptscript.WriteFileInWorkspaceOptions{
WorkspaceID: workspaceID,
CreateRevision: &([]bool{true}[0]),
LatestRevisionID: revisionID,
}); err != nil {
return fmt.Errorf("Error writing updated DB file to workspace: %v", err)
}

// Delete old revisions after successfully writing the new revision
revisions, err := g.ListRevisionsForFileInWorkspace(ctx, dbWorkspacePath, gptscript.ListRevisionsForFileInWorkspaceOptions{
WorkspaceID: workspaceID,
})
if err != nil {
fmt.Fprintf(os.Stderr, "Error listing revisions: %v\n", err)
return nil
}

lastRevisionIndex := slices.IndexFunc(revisions, func(rev gptscript.FileInfo) bool {
return rev.RevisionID == revisionID
})

if lastRevisionIndex < 0 {
return nil
}

for _, rev := range revisions[:lastRevisionIndex+1] {
if err := g.DeleteRevisionForFileInWorkspace(ctx, dbWorkspacePath, rev.RevisionID, gptscript.DeleteRevisionForFileInWorkspaceOptions{
WorkspaceID: workspaceID,
}); err != nil {
fmt.Fprintf(os.Stderr, "Error deleting revision %s: %v\n", rev.RevisionID, err)
}
}

return nil
}

// hash computes the SHA-256 hash of the given data and returns it as a hexadecimal string
func hash(data []byte) string {
if data == nil {
return ""
}
hash := sha256.Sum256(data)
return hex.EncodeToString(hash[:])
}
143 changes: 125 additions & 18 deletions database/pkg/cmd/command.go
Original file line number Diff line number Diff line change
@@ -3,39 +3,146 @@ package cmd
import (
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"os"
"os/exec"
"regexp"
"strconv"
"strings"
)

// RunDatabaseCommand runs a sqlite3 command against the database and returns the output from the sqlite3 CLI.
func RunDatabaseCommand(ctx context.Context, dbFile *os.File, sql string, opts ...string) (string, error) {
// Remove the "sqlite3" prefix and trim whitespace
args := append(opts, dbFile.Name())
if arg := strings.TrimSpace(sql); arg != "" {
// Use strconv.Unquote to safely handle quotes and escape sequences
unquoted, err := strconv.Unquote(arg)
if err != nil {
// If unquoting fails (e.g. string wasn't quoted), use original
unquoted = arg
}
args = append(args, unquoted)
// RunDatabaseCommand executes a command against the Postgres database
func RunDatabaseCommand(ctx context.Context, dsn string, sql string, opts ...string) (string, error) {
if sql == "" {
return "", fmt.Errorf("SQL cannot be empty")
}

args := append([]string{dsn}, opts...)

unquoted, err := strconv.Unquote(sql)
if err != nil {
unquoted = sql
}
args = append(args, "-c", unquoted)

// Build the sqlite3 command
cmd := exec.CommandContext(ctx, "sqlite3", args...)
cmd := exec.CommandContext(ctx, "psql", args...)

// Redirect command output
var stdout, stderr bytes.Buffer
cmd.Stdout = &stdout
cmd.Stderr = &stderr

// Run the command and capture errors
if err := cmd.Run(); err != nil {
return "", fmt.Errorf("error executing sqlite3: %w, stderr: %s", err, stderr.String())
return "", fmt.Errorf("psql error: %w\nstderr: %s", err, stderr.String())
}

if stderr.Len() > 0 {
return stdout.String(), fmt.Errorf("psql stderr: %s", stderr.String())
}

return stdout.String(), nil
}

// EnsureTenantSchema creates a schema and role for a tenant with proper isolation
func EnsureTenantSchema(ctx context.Context, adminDSN, workspaceID string) (string, error) {
schemaName := workspaceSchemaName(workspaceID)
userName := schemaName
password := generatePassword(workspaceID)
dbName := "obot_db"

// Create shared database if it doesn't exist
checkDBSQL := fmt.Sprintf("SELECT 1 FROM pg_database WHERE datname = '%s'", dbName)
dbExistsCheck, err := RunDatabaseCommand(ctx, adminDSN, checkDBSQL, "-At")
if err != nil {
return "", fmt.Errorf("error checking for shared database: %w", err)
}
if strings.TrimSpace(dbExistsCheck) != "1" {
createDBSQL := fmt.Sprintf("CREATE DATABASE %s", dbName)
if _, err := RunDatabaseCommand(ctx, adminDSN, createDBSQL); err != nil {
return "", fmt.Errorf("error creating shared database: %w", err)
}

// Create tenant role
createRoleSQL := fmt.Sprintf(`CREATE ROLE %s WITH LOGIN PASSWORD '%s'`, userName, password)
if _, err := RunDatabaseCommand(ctx, adminDSN, createRoleSQL); err != nil && !strings.Contains(err.Error(), "already exists") {
return "", fmt.Errorf("error creating role: %w", err)
}

// Connect to shared database
dbDSN, err := dsnWithDatabase(adminDSN, dbName)
if err != nil {
return "", fmt.Errorf("error constructing DSN for shared database: %w", err)
}

// Create schema for tenant (owned by admin user)
createSchemaSQL := fmt.Sprintf(`CREATE SCHEMA IF NOT EXISTS %s`, schemaName)
if _, err := RunDatabaseCommand(ctx, dbDSN, createSchemaSQL); err != nil {
return "", fmt.Errorf("error creating schema: %w", err)
}

// Revoke PUBLIC access on public schema
revokePublicSchemaSQL := `REVOKE ALL ON SCHEMA public FROM PUBLIC`
if _, err := RunDatabaseCommand(ctx, dbDSN, revokePublicSchemaSQL); err != nil {
return "", fmt.Errorf("error revoking public schema privileges: %w", err)
}

// Set up tenant schema permissions
statements := []string{
fmt.Sprintf(`REVOKE ALL ON SCHEMA %s FROM PUBLIC`, schemaName),
fmt.Sprintf(`GRANT USAGE, CREATE ON SCHEMA %s TO %s`, schemaName, userName),
fmt.Sprintf(`GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA %s TO %s`, schemaName, userName),
fmt.Sprintf(`GRANT ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA %s TO %s`, schemaName, userName),
fmt.Sprintf(`GRANT ALL PRIVILEGES ON ALL FUNCTIONS IN SCHEMA %s TO %s`, schemaName, userName),
fmt.Sprintf(`ALTER DEFAULT PRIVILEGES IN SCHEMA %s GRANT ALL ON TABLES TO %s`, schemaName, userName),
fmt.Sprintf(`ALTER DEFAULT PRIVILEGES IN SCHEMA %s GRANT ALL ON SEQUENCES TO %s`, schemaName, userName),
fmt.Sprintf(`ALTER DEFAULT PRIVILEGES IN SCHEMA %s GRANT ALL ON FUNCTIONS TO %s`, schemaName, userName),
fmt.Sprintf(`ALTER ROLE %s SET search_path = %s`, userName, schemaName),
}

for _, stmt := range statements {
if _, err := RunDatabaseCommand(ctx, dbDSN, stmt); err != nil {
return "", fmt.Errorf("error executing statement '%s': %w", stmt, err)
}
}
}

userDSN := fmt.Sprintf("postgresql://%s:%s@%s/%s?sslmode=require", userName, password, extractHost(adminDSN), dbName)
return userDSN, nil
}

// generatePassword creates a hashed password using the workspaceID
func generatePassword(workspaceID string) string {
hash := sha256.Sum256([]byte(workspaceID))
return hex.EncodeToString(hash[:])
}

// extractHost extracts the host part from the DSN
func extractHost(dsn string) string {
re := regexp.MustCompile(`^(postgresql://[^:]+:[^@]+@)([^/]+)(/[^?]*)(\?.+)?$`)
matches := re.FindStringSubmatch(dsn)
if len(matches) >= 3 {
return matches[2]
}
return ""
}

// workspaceSchemaName converts a workspace ID into a valid PostgreSQL schema/role identifier
func workspaceSchemaName(workspaceID string) string {
hash := sha256.Sum256([]byte(workspaceID))
return "schema_" + hex.EncodeToString(hash[:16])
}

// dsnWithDatabase switches the database in the DSN string
func dsnWithDatabase(adminDSN, dbName string) (string, error) {
if strings.HasPrefix(adminDSN, "postgresql://") {
re := regexp.MustCompile(`^(postgresql://[^/]+/)([^?]*)(\?.+)?$`)
matches := re.FindStringSubmatch(adminDSN)
if len(matches) >= 3 {
if matches[3] != "" {
return matches[1] + dbName + matches[3], nil
}
return matches[1] + dbName, nil
}
}
return "", fmt.Errorf("invalid DSN format")
}
115 changes: 83 additions & 32 deletions database/pkg/cmd/context.go
Original file line number Diff line number Diff line change
@@ -3,52 +3,103 @@ package cmd
import (
"context"
"fmt"
"os"
"strings"
)

// DatabaseContext generates a markdown-formatted string with instructions
// and the database's current schemas.
func DatabaseContext(ctx context.Context, dbFile *os.File) (string, error) {
const getSchemasSQL = `
WITH table_columns AS (
SELECT
table_name,
ordinal_position,
column_name,
data_type,
is_nullable,
column_default
FROM information_schema.columns
WHERE table_schema = 'public'
),
constraints AS (
SELECT
conname,
contype,
conrelid::regclass::text AS table_name,
pg_get_constraintdef(oid, true) AS definition
FROM pg_constraint
WHERE connamespace = 'public'::regnamespace
),
indexes AS (
SELECT
tablename,
indexdef
FROM pg_indexes
WHERE schemaname = 'public'
)
SELECT format(
E'\nCREATE TABLE %I (\n%s%s\n);\n\n%s\n',
tc.table_name,
tc.table_name,
string_agg(
format(' %I %s%s%s',
tc.column_name,
tc.data_type,
CASE WHEN tc.column_default IS NOT NULL THEN ' DEFAULT ' || tc.column_default ELSE '' END,
CASE WHEN tc.is_nullable = 'NO' THEN ' NOT NULL' ELSE '' END
),
E',\n'
ORDER BY tc.ordinal_position
),
CASE
WHEN ct.constraint_defs IS NOT NULL THEN E',\n' || ct.constraint_defs
ELSE ''
END,
COALESCE(idx.index_defs, '')
)
FROM table_columns tc
LEFT JOIN (
SELECT
table_name,
string_agg(
format(' CONSTRAINT %I %s', conname, definition),
E',\n'
) AS constraint_defs
FROM constraints
GROUP BY table_name
) ct ON tc.table_name = ct.table_name
LEFT JOIN (
SELECT
tablename,
string_agg(indexdef, E'\n') AS index_defs
FROM indexes
GROUP BY tablename
) idx ON tc.table_name = idx.tablename
GROUP BY tc.table_name, ct.constraint_defs, idx.index_defs
ORDER BY tc.table_name;
`

// DatabaseContext returns markdown with database schema information
func DatabaseContext(ctx context.Context, dsn string) (string, error) {
var builder strings.Builder

// Add usage instructions
builder.WriteString(`# START INSTRUCTIONS: Run Database SQL tool
builder.WriteString(`# PostgreSQL Database Tool
You have access to tools for interacting with a SQLite database.
The "Run Database SQL" tool lets you run SQL against the SQLite3 database.
You have access to tools for interacting with a PostgreSQL database.
The "Run Database SQL" tool lets you run SQL against the PostgreSQL database.
Display all results from these tools and their schemas in markdown format.
If the user refers to creating or modifying tables, assume they mean a SQLite3 table and not writing a table in a markdown file.
If the user refers to creating or modifying tables, assume they mean a PostgreSQL table and not writing a table in a markdown file.
# END INSTRUCTIONS: Run Database SQL tool
`)

// Add the schemas section
schemas, err := getSchemas(ctx, dbFile)
schemas, err := RunDatabaseCommand(ctx, dsn, getSchemasSQL, "-At")
if err != nil {
return "", fmt.Errorf("failed to retrieve schemas: %w", err)
return "", fmt.Errorf("error getting schemas: %w", err)
}
if schemas != "" {
builder.WriteString("# START CURRENT DATABASE SCHEMAS\n")
builder.WriteString(schemas)
builder.WriteString("\n# END CURRENT DATABASE SCHEMAS\n")

if schemas == "" {
builder.WriteString("\n# No tables found in database\n")
} else {
builder.WriteString("# DATABASE HAS NO TABLES\n")
builder.WriteString("\n# Database Schema\n\n")
builder.WriteString(schemas)
}

return builder.String(), nil
}

// getSchemas retrieves all schemas from the database using the sqlite3 CLI.
func getSchemas(ctx context.Context, dbFile *os.File) (string, error) {
query := `SELECT sql FROM sqlite_master WHERE type IN ('table', 'index', 'view', 'trigger') AND name NOT LIKE 'sqlite_%' ORDER BY name;`

// Execute the query using the RunDatabaseCommand function
output, err := RunDatabaseCommand(ctx, dbFile, query)
if err != nil {
return "", fmt.Errorf("error querying schemas: %w", err)
}

// Return raw output as-is
return strings.TrimSpace(output), nil
}
74 changes: 18 additions & 56 deletions database/pkg/cmd/rows.go
Original file line number Diff line number Diff line change
@@ -2,67 +2,29 @@ package cmd

import (
"context"
"encoding/json"
"fmt"
"os"
)

type Output struct {
Columns []string `json:"columns"`
Rows []map[string]any `json:"rows"`
}

// ListDatabaseTableRows lists all rows from the specified table using RunDatabaseCommand and returns the JSON output directly.
func ListDatabaseTableRows(ctx context.Context, dbFile *os.File, table string) (string, error) {
const tableRowsSQL = `
SELECT json_build_object(
'columns', (
SELECT array_agg(column_name ORDER BY ordinal_position)
FROM information_schema.columns
WHERE table_schema = 'public' AND table_name = '%s'
),
'rows', COALESCE((
SELECT json_agg(row_to_json(t))
FROM %s t
), '[]')
)::text;
`

// ListDatabaseTableRows returns table contents with columns
func ListDatabaseTableRows(ctx context.Context, dsn string, table string) (string, error) {
if table == "" {
return "", fmt.Errorf("table name cannot be empty")
}

// Get column names using PRAGMA
columnsQuery := fmt.Sprintf("PRAGMA table_info(%q);", table)
columnsOutput, err := RunDatabaseCommand(ctx, dbFile, columnsQuery, "-json")
if err != nil {
return "", fmt.Errorf("error getting columns for table %q: %w", table, err)
}

// Parse column information
var columnInfo []struct {
Name string `json:"name"`
}
if err := json.Unmarshal([]byte(columnsOutput), &columnInfo); err != nil {
return "", fmt.Errorf("error parsing column information: %w", err)
}

columns := make([]string, len(columnInfo))
for i, col := range columnInfo {
columns[i] = col.Name
}

// Get all rows
rowsQuery := fmt.Sprintf("SELECT * FROM %q;", table)
rowsOutput, err := RunDatabaseCommand(ctx, dbFile, rowsQuery, "-json")
if err != nil {
return "", fmt.Errorf("error executing query for table %q: %w", table, err)
}

// Parse rows
var rows []map[string]any
if rowsOutput != "" {
if err := json.Unmarshal([]byte(rowsOutput), &rows); err != nil {
return "", fmt.Errorf("error parsing JSON output: %w", err)
}
}

// Create and marshal output
output := Output{
Columns: columns,
Rows: rows,
}

result, err := json.Marshal(output)
if err != nil {
return "", fmt.Errorf("error marshaling output: %w", err)
}

return string(result), nil
query := fmt.Sprintf(tableRowsSQL, table, table)
return RunDatabaseCommand(ctx, dsn, query, "-At")
}
47 changes: 17 additions & 30 deletions database/pkg/cmd/table.go
Original file line number Diff line number Diff line change
@@ -2,42 +2,29 @@ package cmd

import (
"context"
"encoding/json"
"fmt"
"os"
)

type tables struct {
Tables []Table `json:"tables"`
}

type Table struct {
Name string `json:"name,omitempty"`
}

// ListDatabaseTables returns a JSON string containing the list of tables in the database.
func ListDatabaseTables(ctx context.Context, dbFile *os.File) (string, error) {
// Query to fetch table names
query := "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%';"

// Execute the query using RunDatabaseCommand with JSON output
output, err := RunDatabaseCommand(ctx, dbFile, query, "-json")
const listTablesSQL = `SELECT COALESCE(
json_agg(json_build_object('name', table_name)), '[]'
)::text
FROM (
SELECT table_name
FROM information_schema.tables
WHERE table_schema = 'public' AND table_type = 'BASE TABLE'
ORDER BY table_name
) AS ordered_tables;`

// ListDatabaseTables returns a JSON string containing the list of tables
func ListDatabaseTables(ctx context.Context, dsn string) (string, error) {
output, err := RunDatabaseCommand(ctx, dsn, listTablesSQL, "-At")
if err != nil {
return "", fmt.Errorf("error executing query to list tables: %w", err)
}

var dbTables tables
if output != "" {
if err := json.Unmarshal([]byte(output), &(dbTables.Tables)); err != nil {
return "", fmt.Errorf("error parsing table names: %w", err)
}
return "", fmt.Errorf("error listing tables: %w", err)
}

// Marshal final result
data, err := json.Marshal(dbTables)
if err != nil {
return "", fmt.Errorf("error marshaling tables to JSON: %w", err)
if output == "" {
return `{"tables":[]}`, nil
}

return string(data), nil
return fmt.Sprintf(`{"tables":%s}`, output), nil
}
4 changes: 2 additions & 2 deletions database/tool.gpt
Original file line number Diff line number Diff line change
@@ -8,8 +8,8 @@ Share Tools: Run Database SQL
---
Name: Run Database SQL
Share Context: Database Context
Description: Run SQL against the SQLite3 database and return the results
Param: sql: SQL to run against the SQLite3 database (e.g. "SELECT * FROM users")
Description: Run SQL against the PostgreSQL database and return the results
Param: sql: SQL to run against the PostgreSQL database (e.g. "SELECT * FROM users")

#!${GPTSCRIPT_TOOL_DIR}/bin/gptscript-go-tool runDatabaseSQL