Skip to content

version number support and extended unit tests with testcontainer.mssql #2

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 31, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions DotPrompt.Sql.Cli/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@
var config = DatabaseConfigReader.ReadYamlConfig(args[1]);
var connector = new DatabaseConnector();
var connection = await connector.ConnectToDatabase(config);
var loader = new SqlPromptLoader(connection);
await loader.AddSqlPrompt(entity);
IPromptRepository sqlRepository = new SqlPromptRepository(connection);
var loader = new SqlPromptLoader(sqlRepository);
bool upVersioned = await loader.AddSqlPrompt(entity);

Check warning on line 17 in DotPrompt.Sql.Cli/Program.cs

View workflow job for this annotation

GitHub Actions / build

Possible null reference argument for parameter 'entity' in 'Task<bool> SqlPromptLoader.AddSqlPrompt(SqlPromptEntity entity)'.

Check warning on line 17 in DotPrompt.Sql.Cli/Program.cs

View workflow job for this annotation

GitHub Actions / build

Possible null reference argument for parameter 'entity' in 'Task<bool> SqlPromptLoader.AddSqlPrompt(SqlPromptEntity entity)'.
Console.WriteLine($"Done: {upVersioned}");

var prompts = loader.Load();
foreach (var prompt in prompts)
Expand Down
2 changes: 1 addition & 1 deletion DotPrompt.Sql.Cli/prompts/basic.prompt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
model: claude-3-5-sonnet-latest
config:
name: basic1
name: basic
outputFormat: text
temperature: 0.9
maxTokens: 500
Expand Down
6 changes: 6 additions & 0 deletions DotPrompt.Sql.Test/DotPrompt.Sql.Test.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,13 @@

<ItemGroup>
<PackageReference Include="coverlet.collector" Version="6.0.0"/>
<PackageReference Include="Microsoft.Data.Sqlite" Version="9.0.1" />
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.8.0"/>
<PackageReference Include="SQLitePCLRaw.bundle_e_sqlite3" Version="2.1.11-pre20241216174303" />
<PackageReference Include="System.Data.SQLite" Version="1.0.119" />
<PackageReference Include="System.Data.SQLite.Core" Version="1.0.119" />
<PackageReference Include="TestContainers.Container.Database.MsSql" Version="1.5.4" />
<PackageReference Include="Testcontainers.MsSql" Version="4.1.0" />
<PackageReference Include="xunit" Version="2.5.3"/>
<PackageReference Include="xunit.runner.visualstudio" Version="2.5.3"/>
</ItemGroup>
Expand Down
231 changes: 142 additions & 89 deletions DotPrompt.Sql.Test/TestSqlPromptEntity.cs
Original file line number Diff line number Diff line change
@@ -1,121 +1,174 @@
namespace DotPrompt.Sql.Test;

using System;
using System.Collections.Generic;
using System.IO;
using System.Data;
using System.Data.SQLite;
using System.Reflection;
using Microsoft.Data.Sqlite;
using System.Threading.Tasks;
using Dapper;
using DotPrompt.Sql;
using Microsoft.Data.SqlClient;
using Testcontainers.MsSql;
using Xunit;

public class SqlPromptEntityTests : IDisposable
public class SqlPromptRepositoryTests : IAsyncLifetime
{
private readonly List<string> _testFiles = new();
private readonly MsSqlContainer _sqlServerContainer;
private IDbConnection _connection;
private SqlPromptRepository _repository;

// Helper method to create and track test YAML files
private void CreateTestYamlFile(string filePath, string content)
public SqlPromptRepositoryTests()

Check warning on line 20 in DotPrompt.Sql.Test/TestSqlPromptEntity.cs

View workflow job for this annotation

GitHub Actions / build

Non-nullable field '_connection' must contain a non-null value when exiting constructor. Consider adding the 'required' modifier or declaring the field as nullable.

Check warning on line 20 in DotPrompt.Sql.Test/TestSqlPromptEntity.cs

View workflow job for this annotation

GitHub Actions / build

Non-nullable field '_repository' must contain a non-null value when exiting constructor. Consider adding the 'required' modifier or declaring the field as nullable.

Check warning on line 20 in DotPrompt.Sql.Test/TestSqlPromptEntity.cs

View workflow job for this annotation

GitHub Actions / build

Non-nullable field '_connection' must contain a non-null value when exiting constructor. Consider adding the 'required' modifier or declaring the field as nullable.

Check warning on line 20 in DotPrompt.Sql.Test/TestSqlPromptEntity.cs

View workflow job for this annotation

GitHub Actions / build

Non-nullable field '_repository' must contain a non-null value when exiting constructor. Consider adding the 'required' modifier or declaring the field as nullable.
{
File.WriteAllText(filePath, content);
_testFiles.Add(filePath); // Track the file for later cleanup
_sqlServerContainer = new MsSqlBuilder()
.WithImage("mcr.microsoft.com/mssql/server:2022-latest")
.WithPassword("YourStrong(!)Password")
.Build();
}

[Fact]
public void FromPromptFile_ValidYaml_ReturnsSqlPromptEntity()
public async Task InitializeAsync()
{
// Arrange
string filePath = "test_prompt.yaml";
string yamlContent = @"
model: gpt-4
config:
name: TestPrompt
outputFormat: json
maxTokens: 200
input:
parameters:
param1: value1
param2: value2
default:
param1: default1
param2: default2
prompts:
system: System message
user: User message";

CreateTestYamlFile(filePath, yamlContent);
await _sqlServerContainer.StartAsync();

// Act
var result = SqlPromptEntity.FromPromptFile(filePath);
_connection = new SqlConnection(_sqlServerContainer.GetConnectionString());
_connection.Open();

// Assert
Assert.NotNull(result);
Assert.Equal("gpt-4", result.Model);
Assert.Equal("TestPrompt", result.PromptName);
Assert.Equal("json", result.OutputFormat);
Assert.Equal(200, result.MaxTokens);
Assert.Equal("System message", result.SystemPrompt);
Assert.Equal("User message", result.UserPrompt);
Assert.Equal("value1", result.Parameters["param1"]);
Assert.Equal("default1", result.Default["param1"]);
_repository = new SqlPromptRepository(_connection);
await InitializeDatabase();
}

[Fact]
public void FromPromptFile_FileDoesNotExist_ThrowsFileNotFoundException()
public Task DisposeAsync()
{
// Arrange
string invalidPath = "non_existent.yaml";
_sqlServerContainer.StopAsync();
_connection?.Dispose();
return Task.CompletedTask;
}

private static string LoadSql(string resourceName)
{
// Get the assembly containing the embedded SQL files
var assembly = Assembly.Load("DotPrompt.Sql"); // Name of the referenced assembly

// Find the full resource name (includes namespace path)
string? fullResourceName = assembly.GetManifestResourceNames()
.FirstOrDefault(name => name.EndsWith(resourceName, StringComparison.OrdinalIgnoreCase));

if (fullResourceName == null)
{
throw new FileNotFoundException($"Resource {resourceName} not found in assembly {assembly.FullName}");
}

// Read the embedded resource stream
using var stream = assembly.GetManifestResourceStream(fullResourceName);
using var reader = new StreamReader(stream);

Check warning on line 62 in DotPrompt.Sql.Test/TestSqlPromptEntity.cs

View workflow job for this annotation

GitHub Actions / build

Possible null reference argument for parameter 'stream' in 'StreamReader.StreamReader(Stream stream)'.

Check warning on line 62 in DotPrompt.Sql.Test/TestSqlPromptEntity.cs

View workflow job for this annotation

GitHub Actions / build

Possible null reference argument for parameter 'stream' in 'StreamReader.StreamReader(Stream stream)'.
return reader.ReadToEnd();
}
private async Task InitializeDatabase()
{
string tables = LoadSql("CreateDefaultPromptTables.sql");
await _connection.ExecuteAsync(tables);

// Act & Assert
var exception = Assert.Throws<FileNotFoundException>(() => SqlPromptEntity.FromPromptFile(invalidPath));
Assert.Contains("The specified file was not found", exception.Message);
string procs = LoadSql("AddSqlPrompt.sql");
await _connection.ExecuteAsync(procs);
}

[Fact]
public void FromPromptFile_MissingMandatoryFields_ThrowsException()
public async Task AddSqlPrompt_ValidPrompt_InsertsSuccessfully()
{
// Arrange
string filePath = "test_missing_optional.yaml";
string yamlContent = @"
config:
name: TestPrompt
outputFormat: json
maxTokens: 100
prompts:
system: Default system prompt
user: Default user prompt";

CreateTestYamlFile(filePath, yamlContent);

// Act & Assert
Assert.Throws<ApplicationException>(() => SqlPromptEntity.FromPromptFile(filePath));
var entity = new SqlPromptEntity
{
PromptName = "myprompt",
Model = "gpt4",
OutputFormat = "json",
MaxTokens = 500,
SystemPrompt = "Optimize SQL queries.",
UserPrompt = "Suggest indexing improvements.",
Parameters = new Dictionary<string, string>
{
{ "Temperature", "0.7" },
{ "TopP", "0.9" }
},
Default = new Dictionary<string, object>
{
{ "Temperature", "0.5" }
}
};

// Act
bool result = await _repository.AddSqlPrompt(entity);

// Assert
Assert.True(result, "Expected new prompt version to be inserted.");
}

[Fact]
public void FromPromptFile_InvalidDataType_ThrowsException()
public async Task AddSqlPrompt_SamePromptNoChanges_DoesNotInsertNewVersion()
{
// Arrange
string filePath = "test_invalid_type.yaml";
string yamlContent = @"
model: gpt-4
config:
name: TestPrompt
outputFormat: json
maxTokens: not_a_number
prompts:
system: System prompt
user: User prompt";

CreateTestYamlFile(filePath, yamlContent);

// Act & Assert
Assert.Throws<FormatException>(() => SqlPromptEntity.FromPromptFile(filePath));
var entity = new SqlPromptEntity
{
PromptName = "myprompt",
Model = "gpt4",
OutputFormat = "json",
MaxTokens = 200,
SystemPrompt = "Optimize SQL queries.",
UserPrompt = "Suggest indexing improvements.",
Parameters = new Dictionary<string, string>
{
{ "Temperature", "0.7" },
{ "TopP", "0.9" }
},
Default = new Dictionary<string, object>
{
{ "Temperature", "0.5" }
}
};

await _repository.AddSqlPrompt(entity); // First insert

// Act
bool result = await _repository.AddSqlPrompt(entity); // Try inserting again with no changes

// Assert
Assert.False(result, "No new version should be inserted when nothing has changed.");
}

// Cleanup method called after each test
public void Dispose()
[Fact]
public async Task AddSqlPrompt_WhenMaxTokensChanges_ShouldInsertNewVersion()
{
foreach (var file in _testFiles)
// Arrange
var entity1 = new SqlPromptEntity
{
if (File.Exists(file))
{
File.Delete(file);
}
}
PromptName = "noprompt",
Model = "gpt4",
OutputFormat = "json",
MaxTokens = 500,
SystemPrompt = "Optimize SQL queries.",
UserPrompt = "Suggest indexing improvements.",
Parameters = new Dictionary<string, string> { { "Temperature", "0.7" } },
Default = new Dictionary<string, object> { { "Temperature", "0.5" } }
};

var entity2 = new SqlPromptEntity
{
PromptName = "noprompt", // Same prompt name
Model = "gpt4",
OutputFormat = "json",
MaxTokens = 512, // Changed value
SystemPrompt = "Optimize SQL queries.",
UserPrompt = "Suggest indexing improvements.",
Parameters = new Dictionary<string, string> { { "Temperature", "0.7" } },
Default = new Dictionary<string, object> { { "Temperature", "0.5" } }
};

await _repository.AddSqlPrompt(entity1); // Insert first version

// Act
bool result = await _repository.AddSqlPrompt(entity2);

// Assert
Assert.True(result, "A new version should be inserted when MaxTokens changes.");
}


}
9 changes: 9 additions & 0 deletions DotPrompt.Sql/DatabaseConnector.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ public async Task<IDbConnection> ConnectToDatabase(DatabaseConfig config)
var connection = new SqlConnection(connectionString);
await connection.OpenAsync();
await CreatePromptTables(connection);
await CreateStoredProcs(connection);
Console.WriteLine("Connected to the database successfully!");
return connection;
}
Expand All @@ -42,6 +43,14 @@ private async Task CreatePromptTables(SqlConnection connection)
await command.ExecuteNonQueryAsync();
}

private async Task CreateStoredProcs(SqlConnection connection)
{
// 1. Does the prompt table exist already
string? sqlCreate = DatabaseConfigReader.LoadQuery("AddSqlPrompt.sql");
await using SqlCommand command = new SqlCommand(sqlCreate, connection);
await command.ExecuteNonQueryAsync();
}

private static string BuildConnectionString(DatabaseConfig config)
{
if (config.AadAuthentication)
Expand Down
2 changes: 2 additions & 0 deletions DotPrompt.Sql/DotPrompt.Sql.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@
<EmbeddedResource Include="Resources\SqlQueries\InsertPromptDefaults.sql" />
<None Remove="Resources\SqlQueries\CreateDefaultPromptTables.sql" />
<EmbeddedResource Include="Resources\SqlQueries\CreateDefaultPromptTables.sql" />
<None Remove="Resources\SqlQueries\AddSqlPrompt.sql" />
<EmbeddedResource Include="Resources\SqlQueries\AddSqlPrompt.sql" />
</ItemGroup>

</Project>
19 changes: 19 additions & 0 deletions DotPrompt.Sql/IPromptRepository.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
namespace DotPrompt.Sql;

/// <summary>
/// Defines a prompt repository which will be injected into a loader
/// </summary>
public interface IPromptRepository
{
/// <summary>
/// Adds a SQL prompt and upversions the prompt if it's changed
/// </summary>
/// <param name="entity">The prompt entity that is being added or upversioned</param>
/// <returns>A boolean to denote whether it added the prompt or not</returns>
Task<bool> AddSqlPrompt(SqlPromptEntity entity);
/// <summary>
/// Loads all instances of the prompt from the catalog but only the latest versions
/// </summary>
/// <returns>An enumeration of prompts with different names</returns>
IEnumerable<SqlPromptEntity> Load();
}
Loading