Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions sd_csharp_test.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Runtime.InteropServices;
using System.Text;
using System.Threading.Tasks;
using Xunit;
using Xunit.Abstractions;
using Microsoft.Extensions.AI;
/*
namespace Microsoft.ML.OnnxRuntimeGenAI.Tests
{
public class OnnxRuntimeGenAITests
{
[Fact(DisplayName = "TestStableDiffusion")]
public void TestStableDiffusion()
{
string modelPath = "C:\\Users\\yangselena\\onnxruntime-genai\\onnxruntime-genai\\test\\test_models\\sd";
using (var model = new Model(modelPath))
{

using ImageGeneratorParams imageGeneratorParams = new ImageGeneratorParams(model);
Assert.NotNull(imageGeneratorParams);

imageGeneratorParams.SetPrompts("a photo of a cat");

var imageTensor = Generator.GenerateImage(model, imageGeneratorParams);

Assert.NotNull(imageTensor);
}
}
}
}*/
8 changes: 8 additions & 0 deletions src/csharp/Generator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,14 @@ public Generator(Model model, GeneratorParams generatorParams)
Result.VerifySuccess(NativeMethods.OgaCreateGenerator(model.Handle, generatorParams.Handle, out _generatorHandle));
}

/// <summary>
/// Calls the native <c>OgaGenerateImage</c> entry point with the handles of
/// <paramref name="model"/> and <paramref name="imageGeneratorParams"/> and wraps the
/// tensor handle it produces in a <see cref="Tensor"/>.
/// </summary>
/// <param name="model">Model whose native handle is passed to the native call.</param>
/// <param name="imageGeneratorParams">Image generation parameters whose native handle is passed to the native call.</param>
/// <returns>A <see cref="Tensor"/> constructed from the handle returned by the native call.</returns>
/// <exception cref="ArgumentNullException">
/// <paramref name="model"/> or <paramref name="imageGeneratorParams"/> is null.
/// </exception>
public static Tensor GenerateImage(Model model, ImageGeneratorParams imageGeneratorParams)
{
    // Fail with a descriptive exception instead of a NullReferenceException on .Handle.
    if (model == null)
    {
        throw new ArgumentNullException(nameof(model));
    }
    if (imageGeneratorParams == null)
    {
        throw new ArgumentNullException(nameof(imageGeneratorParams));
    }

    Result.VerifySuccess(NativeMethods.OgaGenerateImage(model.Handle,
                                                        imageGeneratorParams.Handle,
                                                        out IntPtr outputTensor));
    return new Tensor(outputTensor);
}

public bool IsDone()
{
return NativeMethods.OgaGenerator_IsDone(_generatorHandle) != 0;
Expand Down
118 changes: 118 additions & 0 deletions src/csharp/ImageGeneratorParams.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System;
using System.Runtime.InteropServices;

namespace Microsoft.ML.OnnxRuntimeGenAI
{
    /// <summary>
    /// Managed wrapper around a native image-generator-params handle obtained from
    /// <c>OgaCreateImageGeneratorParams</c>. The handle is released through
    /// <c>OgaDestroyImageGeneratorParams</c> when the instance is disposed or finalized.
    /// </summary>
    public class ImageGeneratorParams : IDisposable
    {
        private IntPtr _imageGeneratorParamsHandle;
        private bool _disposed = false;

        /// <summary>
        /// Creates the native parameters object for <paramref name="model"/>.
        /// </summary>
        /// <param name="model">Model whose native handle is passed to the native create call.</param>
        /// <exception cref="ArgumentNullException"><paramref name="model"/> is null.</exception>
        public ImageGeneratorParams(Model model)
        {
            if (model == null)
            {
                throw new ArgumentNullException(nameof(model));
            }

            Result.VerifySuccess(NativeMethods.OgaCreateImageGeneratorParams(model.Handle, out _imageGeneratorParamsHandle));
        }

        /// <summary>
        /// Raw native handle; <see cref="IntPtr.Zero"/> once the instance has been disposed.
        /// </summary>
        internal IntPtr Handle { get { return _imageGeneratorParamsHandle; } }

        /// <summary>
        /// Sets a single prompt on the native parameters object.
        /// </summary>
        /// <remarks>
        /// The native entry point also takes a negative prompt and a prompt count; this
        /// overload always passes <c>null</c> and <c>1</c> for them.
        /// </remarks>
        /// <param name="prompt">Prompt text; converted with <c>StringUtils.ToUtf8</c> before the call.</param>
        /// <exception cref="ArgumentNullException"><paramref name="prompt"/> is null.</exception>
        /// <exception cref="ObjectDisposedException">The instance has already been disposed.</exception>
        public void SetPrompts(string prompt)
        {
            if (prompt == null)
            {
                throw new ArgumentNullException(nameof(prompt));
            }
            // After Dispose the handle is IntPtr.Zero; never hand that to native code.
            if (_disposed)
            {
                throw new ObjectDisposedException(nameof(ImageGeneratorParams));
            }

            Result.VerifySuccess(NativeMethods.OgaImageGeneratorParamsSetPrompts(
                _imageGeneratorParamsHandle,
                StringUtils.ToUtf8(prompt),
                null,
                1));
        }

        ~ImageGeneratorParams()
        {
            Dispose(false);
        }

        /// <summary>
        /// Releases the native handle and suppresses finalization.
        /// </summary>
        public void Dispose()
        {
            Dispose(true);
            GC.SuppressFinalize(this);
        }

        /// <summary>
        /// Releases the native handle at most once.
        /// </summary>
        /// <param name="disposing">
        /// True when called from <see cref="Dispose()"/>, false when called from the finalizer.
        /// Not consulted: the only resource held is the native handle.
        /// </param>
        protected virtual void Dispose(bool disposing)
        {
            if (_disposed)
            {
                return;
            }

            // The finalizer still runs when the constructor threw before a handle was
            // assigned, so only destroy a handle that was actually created.
            if (_imageGeneratorParamsHandle != IntPtr.Zero)
            {
                NativeMethods.OgaDestroyImageGeneratorParams(_imageGeneratorParamsHandle);
                _imageGeneratorParamsHandle = IntPtr.Zero;
            }
            _disposed = true;
        }
    }
}
21 changes: 21 additions & 0 deletions src/csharp/NativeMethods.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the MIT License.

using System;
using System.Drawing;
using System.Runtime.InteropServices;

namespace Microsoft.ML.OnnxRuntimeGenAI
Expand Down Expand Up @@ -322,5 +323,25 @@ public static extern UIntPtr OgaSequencesGetSequenceCount(IntPtr /* const OgaSeq
// P/Invoke declaration for the native OgaUnloadAdapter export in NativeLib.DllName.
// Takes an adapters handle (annotated OgaAdapters*) and the adapter name as a byte array
// marshalled to const char*; the returned pointer is annotated as OgaResult*.
// NOTE(review): adapterName is presumably expected to be null-terminated UTF-8 - confirm.
[DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)]
public static extern IntPtr /* OgaResult* */ OgaUnloadAdapter(IntPtr /* OgaAdapters* */ adapters,
byte[] /* const char* */ adapterName);

// Image generation functions

// P/Invoke declaration for the native OgaCreateImageGeneratorParams export.
// Takes a model handle (annotated const OgaModel*) and writes the resulting
// OgaImageGeneratorParams* handle to the out parameter; the returned pointer is
// annotated as OgaResult*.
[DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)]
public static extern IntPtr /* OgaResult* */ OgaCreateImageGeneratorParams(IntPtr /* const OgaModel* */ model,
out IntPtr /* OgaImageGeneratorParams* */ imageGeneratorParams);


// P/Invoke declaration for the native OgaImageGeneratorParamsSetPrompts export.
// prompt and negativePrompt are byte arrays marshalled to const char*; promptCount is a
// C int. The returned pointer is annotated as OgaResult*.
// NOTE(review): the strings are presumably expected to be null-terminated UTF-8, and
// negativePrompt presumably may be null - confirm against the native header.
[DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)]
public static extern IntPtr /* OgaResult* */ OgaImageGeneratorParamsSetPrompts(IntPtr /* OgaImageGeneratorParams* */ imageGeneratorParams,
byte[] /* const char* */ prompt,
byte[] /* const char* */ negativePrompt,
int /* int */ promptCount);

// P/Invoke declaration for the native OgaGenerateImage export.
// Takes a model handle and an image-generator-params handle and writes an OgaTensor*
// handle to the out parameter; the returned pointer is annotated as OgaResult*.
// NOTE(review): ownership of the returned tensor handle presumably passes to the caller - confirm.
[DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)]
public static extern IntPtr /* OgaResult* */ OgaGenerateImage(IntPtr /* const OgaModel* */ model,
IntPtr /* const OgaImageGeneratorParams* */ imageGeneratorParams,
out IntPtr /* OgaTensor* */ tensor);

// P/Invoke declaration for the native OgaDestroyImageGeneratorParams export.
// Takes the OgaImageGeneratorParams* handle to release and returns nothing, so there is
// no result to verify.
[DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)]
public static extern void OgaDestroyImageGeneratorParams(IntPtr /* OgaImageGeneratorParams* */ imageGeneratorParams);
}
}
4 changes: 4 additions & 0 deletions src/generators.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,10 @@ DeviceInterface* GetDeviceInterface(DeviceType type) {
}
}

// Binds the `config` reference member to the Config object held in model.config_.
ImageGeneratorParams::ImageGeneratorParams(const Model& model)
    : config{*model.config_} {
}


GeneratorParams::GeneratorParams(const Config& config)
: config{config},
p_device{GetDeviceInterface(DeviceType::CPU)} {
Expand Down
9 changes: 9 additions & 0 deletions src/generators.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,15 @@ using TokenSequences = std::vector<std::vector<int32_t>>;
std::string to_string(DeviceType device_type);
DeviceInterface* GetDeviceInterface(DeviceType type);

// Parameters for image generation, constructed from a Model.
// Holds a reference to the model's Config plus two lists of strings for prompts and
// negative prompts.
struct ImageGeneratorParams : std::enable_shared_from_this<ImageGeneratorParams>, LeakChecked<ImageGeneratorParams>, ExternalRefCounted<ImageGeneratorParams> {
// Binds `config` to the Config owned by `model`.
ImageGeneratorParams(const Model& model);

const Config& config; // The model outlives the ImageGeneratorParams

// NOTE(review): presumably prompts[i] pairs with negative_prompts[i] when the latter is
// non-empty - confirm against the code that fills and consumes these vectors.
std::vector<std::string> prompts;
std::vector<std::string> negative_prompts;
};

struct GeneratorParams : std::enable_shared_from_this<GeneratorParams>, LeakChecked<GeneratorParams>, ExternalRefCounted<GeneratorParams> {
GeneratorParams(const Config& config); // This constructor is only used for internal generator benchmarks
GeneratorParams(const Model& model);
Expand Down
23 changes: 23 additions & 0 deletions src/java/src/main/java/ai/onnxruntime/genai/Generator.java
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,26 @@ public Integer next() {
}
}

/**
 * Runs image generation for the given model and parameters and wraps the native result in a
 * {@link Tensor}.
 *
 * @param model The diffusion model to use for generation.
 * @param params Parameters controlling the image generation process.
 * @return A Tensor containing the generated image data.
 * @throws GenAIException If the call to the GenAI native API fails.
 */
public static Tensor generateImage(Model model, ImageGeneratorParams params)
    throws GenAIException {
  // A zero handle means the wrapper has already released its native object.
  final long modelHandle = model.nativeHandle();
  if (modelHandle == 0) {
    throw new IllegalStateException("model has been freed and is invalid");
  }

  final long paramsHandle = params.nativeHandle();
  if (paramsHandle == 0) {
    throw new IllegalStateException("params has been freed and is invalid");
  }

  final long tensorHandle = generateImageNative(modelHandle, paramsHandle);
  return new Tensor(tensorHandle);
}

static {
try {
GenAI.init();
Expand Down Expand Up @@ -240,4 +260,7 @@ private native void setActiveAdapter(
long nativeHandle, long adaptersNativeHandle, String adapterName) throws GenAIException;

// JNI-backed lookup of an output by name on the native generator identified by nativeHandle.
// The native side returns a tensor pointer cast to a long.
private native long getOutputNative(long nativeHandle, String outputName) throws GenAIException;

// JNI entry point behind generateImage. Takes the raw native handles of the model and the
// image generator params; the native side calls OgaGenerateImage and returns the resulting
// tensor pointer cast to a long (0 when it has raised a Java exception instead).
private static native long generateImageNative(long modelHandle, long paramsHandle)
throws GenAIException;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
/*
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License.
*/
package ai.onnxruntime.genai;

/**
 * Represents the parameters used for generating images with a diffusion model.
 *
 * <p>Wraps a native handle that is created in the constructor and released by {@link #close()}.
 * After {@code close()} the handle is zero and the prompt setters throw {@link
 * IllegalStateException}.
 */
public final class ImageGeneratorParams implements AutoCloseable {
  private long nativeHandle = 0;

  /**
   * Creates ImageGeneratorParams from the given model.
   *
   * @param model The model to use for image generation.
   * @throws NullPointerException If {@code model} is null.
   * @throws IllegalStateException If {@code model} has already been freed.
   * @throws GenAIException If the call to the GenAI native API fails.
   */
  public ImageGeneratorParams(Model model) throws GenAIException {
    java.util.Objects.requireNonNull(model, "model must not be null");
    if (model.nativeHandle() == 0) {
      throw new IllegalStateException("model has been freed and is invalid");
    }

    nativeHandle = createImageGeneratorParams(model.nativeHandle());
  }

  /**
   * Sets the prompt and optional negative prompt for image generation.
   *
   * <p>Equivalent to calling {@link #setPrompts(String[], String[])} with one-element arrays.
   *
   * @param prompt The text prompt describing the desired image.
   * @param negativePrompt Optional text describing what to avoid in the image. Can be null.
   * @throws NullPointerException If {@code prompt} is null.
   * @throws IllegalStateException If this instance has already been closed.
   * @throws GenAIException If the call to the GenAI native API fails.
   */
  public void setPrompts(String prompt, String negativePrompt) throws GenAIException {
    java.util.Objects.requireNonNull(prompt, "prompt must not be null");

    String[] negativePrompts = negativePrompt != null ? new String[] {negativePrompt} : null;
    setPrompts(new String[] {prompt}, negativePrompts);
  }

  /**
   * Sets multiple prompts and optional negative prompts for batch image generation.
   *
   * @param prompts Array of text prompts describing the desired images.
   * @param negativePrompts Optional array of texts describing what to avoid in each image. Can be
   *     null.
   * @throws NullPointerException If {@code prompts} or any of its elements is null.
   * @throws IllegalArgumentException If both arrays are given but differ in length.
   * @throws IllegalStateException If this instance has already been closed.
   * @throws GenAIException If the call to the GenAI native API fails.
   */
  public void setPrompts(String[] prompts, String[] negativePrompts) throws GenAIException {
    if (nativeHandle == 0) {
      throw new IllegalStateException("Instance has been freed and is invalid");
    }

    // Reject null prompts here so they are never forwarded to the native layer.
    java.util.Objects.requireNonNull(prompts, "prompts must not be null");
    for (int i = 0; i < prompts.length; i++) {
      java.util.Objects.requireNonNull(prompts[i], "prompts[" + i + "] must not be null");
    }

    if (negativePrompts != null && prompts.length != negativePrompts.length) {
      throw new IllegalArgumentException(
          "prompts and negativePrompts arrays must be the same length");
    }

    setPrompts(nativeHandle, prompts, negativePrompts, prompts.length);
  }

  /** Releases the native handle. Calling this more than once has no further effect. */
  @Override
  public void close() {
    if (nativeHandle != 0) {
      destroyImageGeneratorParams(nativeHandle);
      nativeHandle = 0;
    }
  }

  /**
   * Returns the native handle for this object.
   *
   * @return The native handle, or 0 once this instance has been closed.
   */
  long nativeHandle() {
    return nativeHandle;
  }

  static {
    try {
      GenAI.init();
    } catch (Exception e) {
      throw new RuntimeException("Failed to load onnxruntime-genai native libraries", e);
    }
  }

  private native long createImageGeneratorParams(long modelHandle) throws GenAIException;

  private native void destroyImageGeneratorParams(long nativeHandle);

  private native void setPrompts(
      long nativeHandle, String[] prompts, String[] negativePrompts, long promptCount)
      throws GenAIException;
}
13 changes: 13 additions & 0 deletions src/java/src/main/native/ai_onnxruntime_genai_Generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,3 +119,16 @@ Java_ai_onnxruntime_genai_Generator_getOutputNative(JNIEnv* env, jobject thiz, j
}
return reinterpret_cast<jlong>(tensor);
}

// JNI bridge for Generator.generateImageNative: reinterprets the two Java longs as native
// handles, calls OgaGenerateImage, and returns the produced tensor pointer as a jlong.
// Returns 0 when ThrowIfError reports a failure.
JNIEXPORT jlong JNICALL
Java_ai_onnxruntime_genai_Generator_generateImageNative(JNIEnv* env, jclass cls, jlong model_handle, jlong params_handle) {
  const auto* oga_model = reinterpret_cast<const OgaModel*>(model_handle);
  const auto* oga_params = reinterpret_cast<const OgaImageGeneratorParams*>(params_handle);

  OgaTensor* image_tensor = nullptr;
  if (!ThrowIfError(env, OgaGenerateImage(oga_model, oga_params, &image_tensor))) {
    return reinterpret_cast<jlong>(image_tensor);
  }

  return 0;
}
Loading