diff --git a/LLama.Unittest/Constants.cs b/LLama.Unittest/Constants.cs
index 4852a335..d344974d 100644
--- a/LLama.Unittest/Constants.cs
+++ b/LLama.Unittest/Constants.cs
@@ -4,6 +4,7 @@ namespace LLama.Unittest
{
internal static class Constants
{
+ public static readonly string ModelDirectory = "Models";
public static readonly string GenerativeModelPath = "Models/llama-2-7b-chat.Q3_K_S.gguf";
public static readonly string EmbeddingModelPath = "Models/all-MiniLM-L12-v2.Q8_0.gguf";
diff --git a/LLama.Unittest/LLama.Unittest.csproj b/LLama.Unittest/LLama.Unittest.csproj
index 5c29a851..087b1290 100644
--- a/LLama.Unittest/LLama.Unittest.csproj
+++ b/LLama.Unittest/LLama.Unittest.csproj
@@ -32,8 +32,6 @@
-
-
diff --git a/LLama.Unittest/Model/FileSystemModelRepoTests.cs b/LLama.Unittest/Model/FileSystemModelRepoTests.cs
new file mode 100644
index 00000000..c867ae62
--- /dev/null
+++ b/LLama.Unittest/Model/FileSystemModelRepoTests.cs
@@ -0,0 +1,104 @@
+using LLama.Model;
+
+namespace LLama.Unittest.Model;
+
+public class FileSystemModelRepoTests
+{
+ private readonly FileSystemModelRepo TestableRepo;
+
+ public FileSystemModelRepoTests()
+ {
+ TestableRepo = new([Constants.ModelDirectory]);
+ }
+
+ [Fact]
+ public void ModelDirectories_IsCorrect()
+ {
+ var dirs = TestableRepo.ListSources();
+ Assert.Single(dirs);
+
+ var expected = dirs.First()!.Contains(Constants.ModelDirectory);
+ Assert.True(expected);
+ }
+
+ [Fact]
+ public void AddDirectory_DoesntDuplicate()
+ {
+ for (var i = 0; i < 10; i++)
+ {
+ TestableRepo.AddSource(Constants.ModelDirectory);
+ TestableRepo.AddSource(Path.GetFullPath(Constants.ModelDirectory));
+
+ var dirs = TestableRepo.ListSources();
+ Assert.Single(dirs);
+ var expected = dirs.First()!.Contains(Constants.ModelDirectory);
+ Assert.True(expected);
+ }
+ }
+
+ [Fact]
+ public void RemoveDirectory()
+ {
+ var dirs = TestableRepo.ListSources();
+ Assert.Single(dirs);
+ var expected = dirs.First()!.Contains(Constants.ModelDirectory);
+ Assert.True(expected);
+
+ Assert.True(TestableRepo.RemoveSource(Constants.ModelDirectory));
+ Assert.Empty(TestableRepo.ListSources());
+ Assert.Empty(TestableRepo.GetAvailableModels());
+ }
+
+ [Fact]
+ public void RemoveDirectory_DoesNotExist()
+ {
+ var dirs = TestableRepo.ListSources();
+ Assert.Single(dirs);
+ var expected = dirs.First()!.Contains(Constants.ModelDirectory);
+ Assert.True(expected);
+
+ Assert.False(TestableRepo.RemoveSource("foo/boo/bar"));
+ Assert.Single(dirs);
+ }
+
+ [Fact]
+ public void RemoveAllDirectories()
+ {
+ var dirs = TestableRepo.ListSources();
+ Assert.Single(dirs);
+ var expected = dirs.First()!.Contains(Constants.ModelDirectory);
+ Assert.True(expected);
+
+ TestableRepo.RemoveAllSources();
+ Assert.Empty(TestableRepo.ListSources());
+ Assert.Empty(TestableRepo.GetAvailableModels());
+ }
+
+ [Fact]
+ public void ModelFiles_IsCorrect()
+ {
+ var files = TestableRepo.GetAvailableModels();
+ Assert.Equal(4, files.Count());
+ }
+
+ [Fact]
+ public void GetAvailableModelsFromDirectory()
+ {
+ var files = TestableRepo.GetAvailableModelsFromSource(Constants.ModelDirectory);
+ Assert.Equal(4, files.Count());
+
+ files = TestableRepo.GetAvailableModels();
+ Assert.Equal(4, files.Count());
+ }
+
+ [Fact]
+ public void TryGetModelFileMetadata_WhenExists()
+ {
+ var expectedFile = TestableRepo.GetAvailableModels().First();
+ var found = TestableRepo.TryGetModelFileMetadata(expectedFile.ModelFileUri, out var foundData);
+
+ Assert.True(found);
+ Assert.Equal(expectedFile.ModelFileUri, foundData.ModelFileUri);
+ }
+
+}
diff --git a/LLama.Unittest/Model/ModelCacheTests.cs b/LLama.Unittest/Model/ModelCacheTests.cs
new file mode 100644
index 00000000..af9275d9
--- /dev/null
+++ b/LLama.Unittest/Model/ModelCacheTests.cs
@@ -0,0 +1,125 @@
+using LLama.Common;
+using LLama.Model;
+
+namespace LLama.Unittest.Model;
+
+public class ModelManagerTests
+{
+ private readonly IModelSourceRepo _testRepo = new FileSystemModelRepo([Constants.ModelDirectory]);
+
+ private readonly ModelCache TestableModelManager;
+
+ public ModelManagerTests()
+ {
+ TestableModelManager = new();
+ }
+
+ [Fact]
+ public async Task LoadModel_DisposesOnUnload()
+ {
+ const string modelId = "llama-2-7b";
+ var modelToLoad = _testRepo.GetAvailableModels()
+ .First(f => f.ModelFileName.Contains(modelId));
+
+ // Load success
+ var model = await TestableModelManager.LoadModelAsync(modelToLoad, modelId);
+ Assert.NotNull(model);
+ Assert.Equal(1, TestableModelManager.ModelsCached());
+
+ // Load with same Id throws
+ await Assert.ThrowsAsync<ArgumentException>(async () =>
+ {
+ await TestableModelManager.LoadModelAsync(modelToLoad, modelId);
+ });
+ Assert.Equal(1, TestableModelManager.ModelsCached());
+
+ // unloaded and disposed
+ Assert.True(TestableModelManager.UnloadModel(modelId));
+ Assert.Throws<ObjectDisposedException>(() =>
+ {
+ _ = model.CreateContext(new ModelParams(modelToLoad.ModelFileUri));
+ });
+ Assert.Equal(0, TestableModelManager.ModelsCached());
+
+ // already unloaded and disposed
+ Assert.False(TestableModelManager.UnloadModel(modelId));
+ Assert.Throws<ObjectDisposedException>(() =>
+ {
+ _ = model.CreateContext(new ModelParams(modelToLoad.ModelFileUri));
+ });
+
+ // Can be reloaded after unload
+ model = await TestableModelManager.LoadModelAsync(modelToLoad, modelId);
+ Assert.NotNull(model);
+ Assert.Equal(1, TestableModelManager.ModelsCached());
+ Assert.True(TestableModelManager.UnloadModel(modelId));
+ Assert.Equal(0, TestableModelManager.ModelsCached());
+ }
+
+ [Fact]
+ public async Task TryCloneLoadedModel_ClonesAndCaches()
+ {
+ const string modelId = "llama-2-7b";
+ var modelToLoad = _testRepo.GetAvailableModels()
+ .First(f => f.ModelFileName.Contains(modelId));
+
+ var model = await TestableModelManager.LoadModelAsync(modelToLoad, modelId);
+ Assert.NotNull(model);
+ Assert.Equal(1, TestableModelManager.ModelsCached());
+
+ // clone it -- Ref 2
+ const string cloneId = nameof(cloneId);
+ var isCachedAndCloned = TestableModelManager.TryCloneLoadedModel(modelId, cloneId, out var cachedModel);
+ Assert.True(isCachedAndCloned);
+ Assert.NotNull(cachedModel);
+ Assert.Equal(2, TestableModelManager.ModelsCached());
+
+ cachedModel.Dispose(); //-- ref 1
+ Assert.True(TestableModelManager.UnloadModel(modelId));
+ Assert.Equal(1, TestableModelManager.ModelsCached());
+
+ // unloaded and disposed -- ref 2
+ Assert.True(TestableModelManager.UnloadModel(cloneId));
+ Assert.Equal(0, TestableModelManager.ModelsCached());
+
+ Assert.False(TestableModelManager.UnloadModel(modelId));
+ Assert.False(TestableModelManager.UnloadModel(cloneId));
+ Assert.Throws<ObjectDisposedException>(() =>
+ {
+ _ = model.CreateContext(new ModelParams(modelToLoad.ModelFileUri));
+ });
+ Assert.Throws<ObjectDisposedException>(() =>
+ {
+ _ = cachedModel.CreateContext(new ModelParams(modelToLoad.ModelFileUri));
+ });
+ }
+
+ [Fact]
+ public async Task TryCloneLoadedModel_SameId_Throws()
+ {
+ const string modelId = "llama-2-7b";
+ var modelToLoad = _testRepo.GetAvailableModels()
+ .First(f => f.ModelFileName.Contains(modelId));
+
+ var model = await TestableModelManager.LoadModelAsync(modelToLoad, modelId);
+ Assert.NotNull(model);
+ Assert.Equal(1, TestableModelManager.ModelsCached());
+
+ // Same Id clone fails
+ Assert.Throws<ArgumentException>(() =>
+ {
+ TestableModelManager.TryCloneLoadedModel(modelId, modelId, out var cachedModel);
+ });
+ Assert.Equal(1, TestableModelManager.ModelsCached());
+
+ // Unload and dispose
+ Assert.True(TestableModelManager.UnloadModel(modelId));
+ Assert.Equal(0, TestableModelManager.ModelsCached());
+ Assert.False(TestableModelManager.UnloadModel(modelId));
+ Assert.Throws<ObjectDisposedException>(() =>
+ {
+ _ = model.CreateContext(new ModelParams(modelToLoad.ModelFileUri));
+ });
+ }
+}
diff --git a/LLama/LLamaWeights.cs b/LLama/LLamaWeights.cs
index 8646e4d9..aeaf3f5f 100644
--- a/LLama/LLamaWeights.cs
+++ b/LLama/LLamaWeights.cs
@@ -16,12 +16,23 @@ namespace LLama
public sealed class LLamaWeights
: IDisposable
{
+ private bool _disposed = false;
+
/// <summary>
/// The native handle, which is used in the native APIs
/// </summary>
/// <remarks>Be careful how you use this!</remarks>
public SafeLlamaModelHandle NativeHandle { get; }
+ #region Properties
+ /// <summary>
+ /// The model's name as specified in its metadata
+ /// </summary>
+ /// <value></value>
+ public string ModelName => Metadata.TryGetValue("general.name", out var name)
+ ? name
+ : string.Empty;
+
/// <summary>
/// Total number of tokens in vocabulary of this model
/// </summary>
@@ -56,11 +67,53 @@ public sealed class LLamaWeights
/// All metadata keys in this model
/// </summary>
public IReadOnlyDictionary<string, string> Metadata { get; set; }
+ #endregion
+
+ private LLamaWeights(SafeLlamaModelHandle handle)
+ {
+ NativeHandle = handle;
+ Metadata = handle.ReadMetadata();
+
+ // Increment the model reference count while this weight exists.
+ // DangerousAddRef throws if it fails, so there is no need to check "success"
+ var success = false;
+ NativeHandle.DangerousAddRef(ref success);
+ }
+
+ /// <summary>
+ /// Create an instance of the model using the supplied handle and metadata.
+ /// Metadata will not be re-read from the handle.
+ /// </summary>
+ /// <param name="handle"></param>
+ /// <param name="metadata"></param>
+ private LLamaWeights(SafeLlamaModelHandle handle, IReadOnlyDictionary<string, string> metadata)
+ {
+ NativeHandle = handle;
+ Metadata = metadata;
+
+ // Increment the model reference count while this weight exists.
+ // DangerousAddRef throws if it fails, so there is no need to check "success"
+ var success = false;
+ NativeHandle.DangerousAddRef(ref success);
+ }
+
+ /// <summary>
+ /// Finalizer
+ /// </summary>
+ ~LLamaWeights()
+ {
+ // Ensure the handle is released even if users don't explicitly call Dispose
+ Dispose();
+ }
- private LLamaWeights(SafeLlamaModelHandle weights)
+ #region Load
+ /// <summary>
+ /// Create a new instance of the model using the same NativeHandle as this model.
+ /// Metadata is also copied from the existing model rather than read from the handle directly.
+ /// The `SafeLlamaModelHandle` will not be disposed and the model will not be unloaded until ALL such handles have been disposed.
+ /// </summary>
+ /// <returns></returns>
+ public LLamaWeights CloneFromHandleWithMetadata()
{
- NativeHandle = weights;
- Metadata = weights.ReadMetadata();
+ return new LLamaWeights(NativeHandle, Metadata);
}
/// <summary>
@@ -71,19 +124,19 @@ private LLamaWeights(SafeLlamaModelHandle weights)
public static LLamaWeights LoadFromFile(IModelParams @params)
{
using var pin = @params.ToLlamaModelParams(out var lparams);
- var weights = SafeLlamaModelHandle.LoadFromFile(@params.ModelPath, lparams);
+ var model = SafeLlamaModelHandle.LoadFromFile(@params.ModelPath, lparams);
foreach (var adapter in @params.LoraAdapters)
{
- if (string.IsNullOrEmpty(adapter.Path))
- continue;
- if (adapter.Scale <= 0)
+ if (string.IsNullOrEmpty(adapter.Path) || adapter.Scale <= 0)
+ {
continue;
+ }
- weights.ApplyLoraFromFile(adapter.Path, adapter.Scale, @params.LoraBase);
+ model.ApplyLoraFromFile(adapter.Path, adapter.Scale, @params.LoraBase);
}
- return new LLamaWeights(weights);
+ return new LLamaWeights(model);
}
/// <summary>
@@ -103,15 +156,15 @@ public static async Task LoadFromFileAsync(IModelParams @params, C
var loraBase = @params.LoraBase;
var loraAdapters = @params.LoraAdapters.ToArray();
- // Determine the range to report for model loading. llama.cpp reports 0-1, but we'll remap that into a
- // slightly smaller range to allow some space for reporting LoRA loading too.
- var modelLoadProgressRange = 1f;
- if (loraAdapters.Length > 0)
- modelLoadProgressRange = 0.9f;
-
using (@params.ToLlamaModelParams(out var lparams))
{
#if !NETSTANDARD2_0
+ // Determine the range to report for model loading. llama.cpp reports 0-1, but we'll remap that into a
+ // slightly smaller range to allow some space for reporting LoRA loading too.
+ var modelLoadProgressRange = 1f;
+ if (loraAdapters.Length > 0)
+ modelLoadProgressRange = 0.9f;
+
// Overwrite the progress callback with one which polls the cancellation token and updates the progress object
if (token.CanBeCanceled || progressReporter != null)
{
@@ -125,11 +178,7 @@ public static async Task LoadFromFileAsync(IModelParams @params, C
if (internalCallback != null && !internalCallback(progress, ctx))
return false;
- // Check the cancellation token
- if (token.IsCancellationRequested)
- return false;
-
- return true;
+ return !token.IsCancellationRequested;
};
}
#endif
@@ -183,11 +232,19 @@ public static async Task LoadFromFileAsync(IModelParams @params, C
return model;
}
}
+ #endregion
/// <inheritdoc />
public void Dispose()
{
- NativeHandle.Dispose();
+ if (!_disposed)
+ {
+ NativeHandle.DangerousRelease();
+ NativeHandle.Dispose();
+ _disposed = true;
+ }
+
+ GC.SuppressFinalize(this);
}
/// <summary>
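The net effect of the `LLamaWeights` changes above: every instance, original or clone, takes its own reference on the shared `SafeLlamaModelHandle` via `DangerousAddRef`, and the native model is only freed once the last instance is disposed. A minimal sketch of that contract (the model path is the one used by the unit tests, so treat it as an assumption):

```csharp
using LLama;
using LLama.Common;

var @params = new ModelParams("Models/llama-2-7b-chat.Q3_K_S.gguf");

var original = LLamaWeights.LoadFromFile(@params);
var clone = original.CloneFromHandleWithMetadata(); // shares the native handle, no second load

original.Dispose(); // releases one reference; the native model is NOT freed yet

// The clone is still fully usable
using (var context = clone.CreateContext(@params))
{
    // ... run inference ...
}

clone.Dispose(); // last reference released; llama_free_model runs here
```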
diff --git a/LLama/Model/FileSystemModelRepo.cs b/LLama/Model/FileSystemModelRepo.cs
new file mode 100644
index 00000000..70987cb4
--- /dev/null
+++ b/LLama/Model/FileSystemModelRepo.cs
@@ -0,0 +1,119 @@
+using System;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.IO;
+using System.Linq;
+
+namespace LLama.Model;
+
+///
+/// A model repository that uses a file system to search for available models
+///
+public class FileSystemModelRepo : IModelSourceRepo
+{
+ /// <summary>
+ /// Supported model file types
+ /// </summary>
+ public static readonly string[] ExpectedModelFileTypes = [
+ ".gguf"
+ ];
+
+ // keys are directories, values are applicable models
+ private readonly Dictionary<string, IEnumerable<ModelFileMetadata>> _availableModels = [];
+
+ /// <summary>
+ /// Create a model repo that scans the filesystem to find models
+ /// </summary>
+ /// <param name="directories"></param>
+ public FileSystemModelRepo(string[] directories)
+ {
+ GetModelsFromDirectories(directories);
+ }
+
+ #region Sources
+ /// <inheritdoc/>
+ public IEnumerable<string> ListSources() => _availableModels.Keys;
+
+ private void GetModelsFromDirectories(params string[] dirs)
+ {
+ foreach (var dir in dirs)
+ {
+ var fullDirectoryPath = Path.GetFullPath(dir);
+
+ if (!Directory.Exists(fullDirectoryPath))
+ {
+ Trace.TraceError($"Model directory '{fullDirectoryPath}' does not exist");
+ continue;
+ }
+
+ if (_availableModels.ContainsKey(fullDirectoryPath))
+ {
+ Trace.TraceWarning($"Model directory '{fullDirectoryPath}' already probed");
+ continue;
+ }
+
+ // find models in current dir that are of expected type
+ List<ModelFileMetadata> directoryModelFiles = [];
+ foreach (var file in Directory.EnumerateFiles(fullDirectoryPath))
+ {
+ if (!ExpectedModelFileTypes.Contains(Path.GetExtension(file)))
+ {
+ continue;
+ }
+
+ // expected model file
+ // TODO: handle symbolic links
+ var fi = new FileInfo(file);
+ directoryModelFiles.Add(new ModelFileMetadata
+ {
+ ModelFileName = fi.Name,
+ ModelFileUri = fi.FullName,
+ ModelType = ModelFileType.GGUF,
+ ModelFileSizeInBytes = fi.Length,
+ });
+ }
+
+ _availableModels.Add(fullDirectoryPath, directoryModelFiles);
+ }
+ }
+
+ /// <inheritdoc/>
+ public void AddSource(string directory)
+ {
+ GetModelsFromDirectories(directory);
+ }
+
+ /// <inheritdoc/>
+ public bool RemoveSource(string directory)
+ {
+ return _availableModels.Remove(Path.GetFullPath(directory));
+ }
+
+ /// <inheritdoc/>
+ public void RemoveAllSources()
+ {
+ _availableModels.Clear();
+ }
+ #endregion Sources
+
+ /// <inheritdoc/>
+ public IEnumerable<ModelFileMetadata> GetAvailableModels()
+ => _availableModels.SelectMany(x => x.Value);
+
+ /// <inheritdoc/>
+ public IEnumerable<ModelFileMetadata> GetAvailableModelsFromSource(string directory)
+ {
+ var dirPath = Path.GetFullPath(directory);
+ return _availableModels.TryGetValue(dirPath, out var dirModels)
+ ? dirModels
+ : [];
+ }
+
+ /// <inheritdoc/>
+ public bool TryGetModelFileMetadata(string modelFileName, out ModelFileMetadata modelMeta)
+ {
+ var filePath = Path.GetFullPath(modelFileName);
+ modelMeta = GetAvailableModels().FirstOrDefault(f => f.ModelFileUri == filePath)!;
+ return modelMeta != null;
+ }
+}
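A short usage sketch for `FileSystemModelRepo` (directory and file names are illustrative; only extensions in `ExpectedModelFileTypes`, i.e. `.gguf`, are picked up):

```csharp
using System;
using LLama.Model;

// Scan one or more directories for *.gguf files.
// A directory that doesn't exist is traced and skipped, not thrown.
var repo = new FileSystemModelRepo(["Models"]);

foreach (var model in repo.GetAvailableModels())
{
    Console.WriteLine($"{model.ModelFileName} ({model.ModelFileSizeInBytes} bytes)");
}

// Sources are keyed by full path, so re-adding the same directory is a no-op
repo.AddSource("Models");

// Lookup resolves the argument with Path.GetFullPath and matches on ModelFileUri
if (repo.TryGetModelFileMetadata("Models/llama-2-7b-chat.Q3_K_S.gguf", out var meta))
{
    Console.WriteLine(meta.ModelFileUri);
}
```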
diff --git a/LLama/Model/HuggingFaceModelRepo.cs b/LLama/Model/HuggingFaceModelRepo.cs
new file mode 100644
index 00000000..072e87e6
--- /dev/null
+++ b/LLama/Model/HuggingFaceModelRepo.cs
@@ -0,0 +1,57 @@
+using System;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Net.Http;
+using Microsoft.Extensions.Logging;
+
+namespace LLama.Model;
+
+// This is for demo purposes - it can be finalized later
+internal class HuggingFaceModelRepo(ILogger logger,
+ HttpClient hfClient) : IModelSourceRepo
+{
+ private readonly ILogger _logger = logger;
+ private readonly HttpClient _hfClient = hfClient;
+
+ // https://huggingface.co/leliuga/all-MiniLM-L12-v2-GGUF/resolve/main/all-MiniLM-L12-v2.Q8_0.gguf
+ private readonly HashSet<string> _hfModelUri = [];
+
+ public void AddSource(string source)
+ {
+ if (!Uri.IsWellFormedUriString(source, UriKind.Absolute))
+ {
+ Trace.TraceWarning("URI is not a valid HuggingFace URL");
+ }
+
+ // TODO: call HF to check model exists
+ // TODO: Get metadata about model
+ _hfModelUri.Add(source);
+ }
+
+ public IEnumerable<string> ListSources() => _hfModelUri;
+
+ public void RemoveAllSources()
+ {
+ _hfModelUri.Clear();
+ }
+
+ public bool RemoveSource(string source)
+ {
+ return _hfModelUri.Remove(source);
+ }
+
+ public bool TryGetModelFileMetadata(string modelFileName, out ModelFileMetadata modelMeta)
+ {
+ throw new NotImplementedException();
+ }
+
+ public IEnumerable<ModelFileMetadata> GetAvailableModels()
+ {
+ throw new NotImplementedException();
+ }
+
+ public IEnumerable<ModelFileMetadata> GetAvailableModelsFromSource(string source)
+ {
+ throw new NotImplementedException();
+ }
+}
diff --git a/LLama/Model/IModelCache.cs b/LLama/Model/IModelCache.cs
new file mode 100644
index 00000000..09327222
--- /dev/null
+++ b/LLama/Model/IModelCache.cs
@@ -0,0 +1,61 @@
+using System;
+using System.Threading;
+using System.Threading.Tasks;
+using LLama.Common;
+
+namespace LLama.Model;
+
+///
+/// A class that helps organize and load local models
+///
+public interface IModelCache : IDisposable
+{
+ /// <summary>
+ /// The current number of file handles in cache.
+ /// </summary>
+ /// <returns>Number of cached models</returns>
+ public int ModelsCached();
+
+ /// <summary>
+ /// Load a model file to be used for inference.
+ /// The caller assumes responsibility for disposing this model and MUST call UnloadModel.
+ /// </summary>
+ /// <param name="metadata">The metadata about the model file to be loaded</param>
+ /// <param name="modelId">A required alias to uniquely identify this model</param>
+ /// <param name="modelConfigurator">An optional function to further configure the model parameters beyond defaults</param>
+ /// <param name="cancellationToken"></param>
+ /// <returns>An instance of the newly loaded model. This MUST be disposed or unloaded via UnloadModel</returns>
+ public Task<LLamaWeights> LoadModelAsync(ModelFileMetadata metadata,
+ string modelId,
+ Action<ModelParams>? modelConfigurator = null!,
+ CancellationToken cancellationToken = default);
+
+ /// <summary>
+ /// Attempt to get a reference to a model that's already loaded
+ /// </summary>
+ /// <param name="modelId">Identifier of the loaded model</param>
+ /// <param name="cachedModel">Will be populated with the reference if the model is cached</param>
+ /// <returns>A SHARED instance of a model that's already loaded. Disposing or unloading this model will affect all references</returns>
+ public bool TryGetLoadedModel(string modelId, out LLamaWeights cachedModel);
+
+ /// <summary>
+ /// Attempt to clone and cache a new unique model instance
+ /// </summary>
+ /// <param name="loadedModelId">Model that's expected to be loaded and cloned</param>
+ /// <param name="cloneId">Identifier for the newly cloned model</param>
+ /// <param name="model">If cloning is successful, this model will be available for use</param>
+ /// <returns>True if cloning is successful</returns>
+ public bool TryCloneLoadedModel(string loadedModelId, string cloneId, out LLamaWeights model);
+
+ /// <summary>
+ /// Unload and dispose of a model with the given id
+ /// </summary>
+ /// <param name="modelId"></param>
+ /// <returns></returns>
+ public bool UnloadModel(string modelId);
+
+ /// <summary>
+ /// Unload all currently loaded models
+ /// </summary>
+ public void UnloadAllModels();
+}
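A sketch of the intended call pattern, combining this interface with the `FileSystemModelRepo` from this PR (the aliases and the name filter are illustrative):

```csharp
using System.Linq;
using LLama.Model;

IModelSourceRepo repo = new FileSystemModelRepo(["Models"]);
using IModelCache cache = new ModelCache();

var metadata = repo.GetAvailableModels()
    .First(m => m.ModelFileName.Contains("llama-2-7b"));

// Load under a caller-chosen alias; reusing an alias throws ArgumentException
var model = await cache.LoadModelAsync(metadata, "chat-model");

// Other callers can fetch the SHARED instance...
cache.TryGetLoadedModel("chat-model", out var shared);

// ...or take an independently disposable clone under a new alias
cache.TryCloneLoadedModel("chat-model", "chat-worker", out var worker);

cache.UnloadModel("chat-worker"); // disposes the clone
cache.UnloadModel("chat-model");  // disposes the original; native weights free here
```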
diff --git a/LLama/Model/IModelSourceRepo.cs b/LLama/Model/IModelSourceRepo.cs
new file mode 100644
index 00000000..c502cbe8
--- /dev/null
+++ b/LLama/Model/IModelSourceRepo.cs
@@ -0,0 +1,58 @@
+using System.Collections.Generic;
+
+namespace LLama.Model;
+
+///
+/// A source for models
+///
+public interface IModelSourceRepo
+{
+ #region Source
+ /// <summary>
+ /// Configured set of sources that are scanned to find models
+ /// </summary>
+ /// <returns></returns>
+ public IEnumerable<string> ListSources();
+
+ /// <summary>
+ /// Add a source containing one or more files
+ /// </summary>
+ /// <param name="source"></param>
+ public void AddSource(string source);
+
+ /// <summary>
+ /// Remove a source from being scanned and having model files made available
+ /// </summary>
+ /// <param name="source"></param>
+ /// <returns></returns>
+ public bool RemoveSource(string source);
+
+ /// <summary>
+ /// Remove all model sources
+ /// </summary>
+ public void RemoveAllSources();
+ #endregion
+
+ #region AvailableModels
+ /// <summary>
+ /// Get all of the model files that are available to be loaded
+ /// </summary>
+ /// <returns></returns>
+ public IEnumerable<ModelFileMetadata> GetAvailableModels();
+
+ /// <summary>
+ /// Only get the models associated with a specific source
+ /// </summary>
+ /// <param name="source"></param>
+ /// <returns>The files, if any, associated with a given source</returns>
+ public IEnumerable<ModelFileMetadata> GetAvailableModelsFromSource(string source);
+
+ /// <summary>
+ /// Get the file data for the given model
+ /// </summary>
+ /// <param name="modelFileName"></param>
+ /// <param name="modelMeta"></param>
+ /// <returns>If a model with the given file name is present</returns>
+ public bool TryGetModelFileMetadata(string modelFileName, out ModelFileMetadata modelMeta);
+ #endregion
+}
diff --git a/LLama/Model/ModelCache.cs b/LLama/Model/ModelCache.cs
new file mode 100644
index 00000000..933156a3
--- /dev/null
+++ b/LLama/Model/ModelCache.cs
@@ -0,0 +1,147 @@
+using System;
+using System.Collections.Generic;
+using System.Threading;
+using System.Threading.Tasks;
+using LLama.Common;
+
+namespace LLama.Model;
+
+internal class CachedModelReference
+{
+ public LLamaWeights Model { get; init; } = null!;
+ public int RefCount { get; set; } = 0;
+}
+
+///
+public class ModelCache : IModelCache
+{
+ private bool _disposed = false;
+
+ // model id/alias, to loaded model
+ private readonly Dictionary<string, LLamaWeights> _loadedModelCache = [];
+
+ /// <inheritdoc/>
+ public int ModelsCached()
+ => _loadedModelCache.Count;
+
+ /// <inheritdoc/>
+ public bool TryCloneLoadedModel(string loadedModelId,
+ string cloneId,
+ out LLamaWeights model)
+ {
+ var isCached = _loadedModelCache.TryGetValue(loadedModelId, out var cachedModel);
+
+ model = null!;
+ if (isCached)
+ {
+ TryAddModel(cloneId, cachedModel.CloneFromHandleWithMetadata);
+ model = _loadedModelCache[cloneId];
+ return true;
+ }
+ return false;
+ }
+
+ /// <inheritdoc/>
+ public bool TryGetLoadedModel(string modelId, out LLamaWeights cachedModel)
+ {
+ return _loadedModelCache.TryGetValue(modelId, out cachedModel);
+ }
+
+ private void TryAddModel(string modelId, Func<LLamaWeights> modelCreator)
+ {
+ if (IsModelIdInvalid(modelId))
+ {
+ throw new ArgumentException("Model identifier is not unique");
+ }
+
+ _loadedModelCache.Add(modelId, modelCreator());
+ }
+
+ private async Task TryAddModelAsync(string modelId, Func<Task<LLamaWeights>> modelCreator)
+ {
+ if (IsModelIdInvalid(modelId))
+ {
+ throw new ArgumentException("Model identifier is not unique");
+ }
+
+ _loadedModelCache.Add(modelId, await modelCreator());
+ }
+
+ private bool IsModelIdInvalid(string modelId) =>
+ string.IsNullOrWhiteSpace(modelId) || _loadedModelCache.ContainsKey(modelId);
+
+ /// <inheritdoc/>
+ public async Task<LLamaWeights> LoadModelAsync(ModelFileMetadata metadata,
+ string modelId,
+ Action<ModelParams>? modelConfigurator = null!,
+ CancellationToken cancellationToken = default)
+ {
+ await TryAddModelAsync(modelId, async () =>
+ {
+ return await ModelCreator(metadata.ModelFileUri, modelConfigurator, cancellationToken);
+ });
+ return _loadedModelCache[modelId];
+
+ // Helper to create the model
+ static async Task<LLamaWeights> ModelCreator(string fileUri,
+ Action<ModelParams>? modelConfigurator,
+ CancellationToken cancellationToken)
+ {
+ var modelParams = new ModelParams(fileUri);
+ modelConfigurator?.Invoke(modelParams);
+
+ return await LLamaWeights.LoadFromFileAsync(modelParams, cancellationToken);
+ }
+ }
+
+ #region Unload
+ /// <inheritdoc/>
+ public bool UnloadModel(string modelId)
+ {
+ if (_loadedModelCache.TryGetValue(modelId, out var cachedModel))
+ {
+ cachedModel.Dispose();
+ return _loadedModelCache.Remove(modelId);
+ }
+ return false;
+ }
+
+ /// <inheritdoc/>
+ public void UnloadAllModels()
+ {
+ foreach (var model in _loadedModelCache.Values)
+ {
+ model.Dispose();
+ }
+ _loadedModelCache.Clear();
+ }
+ #endregion
+
+ #region Dispose
+ /// <inheritdoc/>
+ public void Dispose()
+ {
+ Dispose(true);
+ GC.SuppressFinalize(this);
+ }
+
+ /// <summary>
+ /// Unload all models when called explicitly via dispose
+ /// </summary>
+ /// <param name="disposing">Whether or not this call is made explicitly (true) or via GC</param>
+ protected virtual void Dispose(bool disposing)
+ {
+ if (_disposed)
+ {
+ return;
+ }
+
+ if (disposing)
+ {
+ UnloadAllModels();
+ }
+
+ _disposed = true;
+ }
+ #endregion
+}
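The `modelConfigurator` hook mutates the default `ModelParams` built from `metadata.ModelFileUri` just before `LLamaWeights.LoadFromFileAsync` runs. A small sketch (the property values are illustrative; the properties are standard `LLama.Common.ModelParams` members):

```csharp
using System.Linq;
using LLama.Model;

using IModelCache cache = new ModelCache();
IModelSourceRepo repo = new FileSystemModelRepo(["Models"]);
var metadata = repo.GetAvailableModels().First();

var model = await cache.LoadModelAsync(metadata, "gpu-model", p =>
{
    p.GpuLayerCount = 32;  // offload 32 layers to the GPU
    p.ContextSize = 4096;  // override the model's default context length
});
```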
diff --git a/LLama/Model/ModelFileMetadata.cs b/LLama/Model/ModelFileMetadata.cs
new file mode 100644
index 00000000..3b2c1814
--- /dev/null
+++ b/LLama/Model/ModelFileMetadata.cs
@@ -0,0 +1,23 @@
+namespace LLama.Model;
+
+#pragma warning disable CS1591 // Missing XML comment for publicly visible type or member
+///
+/// Types of supported model files
+///
+public enum ModelFileType
+{
+ GGUF
+}
+
+///
+/// Metadata about available models
+///
+public class ModelFileMetadata
+{
+ public string ModelFileName { get; init; } = string.Empty;
+ public string ModelFileUri { get; init; } = string.Empty;
+ public ModelFileType ModelType { get; init; }
+ public long ModelFileSizeInBytes { get; init; } = 0;
+}
+#pragma warning restore CS1591 // Missing XML comment for publicly visible type or member
+
diff --git a/LLama/Native/SafeLlamaModelHandle.cs b/LLama/Native/SafeLlamaModelHandle.cs
index 1597908e..c25f0b4f 100644
--- a/LLama/Native/SafeLlamaModelHandle.cs
+++ b/LLama/Native/SafeLlamaModelHandle.cs
@@ -16,6 +16,7 @@ namespace LLama.Native
public sealed class SafeLlamaModelHandle
: SafeLLamaHandleBase
{
+ #region Properties
/// <summary>
/// Total number of tokens in vocabulary of this model
/// </summary>
@@ -61,6 +62,7 @@ public sealed class SafeLlamaModelHandle
/// </summary>
public int LayerCount => llama_n_embd(this);
+ private string _modelDescription = null!;
/// <summary>
/// Get a description of this model
/// </summary>
@@ -68,17 +70,22 @@ public string Description
{
get
{
- unsafe
+ if (_modelDescription is null)
{
- // Get description length
- var size = llama_model_desc(this, null, 0);
- var buf = new byte[size + 1];
- fixed (byte* bufPtr = buf)
+ unsafe
{
- size = llama_model_desc(this, bufPtr, buf.Length);
- return Encoding.UTF8.GetString(buf, 0, size);
+ // Get description length
+ var size = llama_model_desc(this, null, 0);
+ var buf = new byte[size + 1];
+ fixed (byte* bufPtr = buf)
+ {
+ size = llama_model_desc(this, bufPtr, buf.Length);
+ _modelDescription = Encoding.UTF8.GetString(buf, 0, size);
+ }
}
}
+
+ return _modelDescription;
}
}
@@ -94,6 +101,7 @@ public string Description
/// Get the special tokens of this model
/// </summary>
public ModelTokens Tokens => _tokens ??= new ModelTokens(this);
+ #endregion
/// <inheritdoc />
protected override bool ReleaseHandle()
@@ -101,7 +109,7 @@ protected override bool ReleaseHandle()
llama_free_model(handle);
return true;
}
-
+
/// <summary>
/// Load a model from the given file path into memory
/// </summary>
@@ -116,12 +124,18 @@ public static SafeLlamaModelHandle LoadFromFile(string modelPath, LLamaModelPara
// - File is readable (explicit check)
// This provides better error messages that llama.cpp, which would throw an access violation exception in both cases.
using (var fs = new FileStream(modelPath, FileMode.Open))
+ {
if (!fs.CanRead)
+ {
throw new InvalidOperationException($"Model file '{modelPath}' is not readable");
+ }
+ }
var handle = llama_load_model_from_file(modelPath, lparams);
if (handle.IsInvalid)
+ {
throw new LoadWeightsFailedException(modelPath);
+ }
return handle;
}
@@ -244,7 +258,6 @@ private static int llama_model_meta_val_str(SafeLlamaModelHandle model, string k
static extern unsafe int llama_model_meta_val_str_native(SafeLlamaModelHandle model, byte* key, byte* buf, long buf_size);
}
-
/// <summary>
/// Get the number of tokens in the model vocabulary
/// </summary>
@@ -545,7 +558,7 @@ public SafeLLamaContextHandle CreateContext(LLamaContextParams @params)
keyLength = llama_model_meta_val_str(this, key, buffer);
Debug.Assert(keyLength >= 0);
- return buffer.AsMemory().Slice(0,keyLength);
+ return buffer.AsMemory().Slice(0, keyLength);
}
/// <summary>
@@ -632,12 +645,12 @@ internal ModelTokens(SafeLlamaModelHandle model)
const int buffSize = 32;
Span<byte> buff = stackalloc byte[buffSize];
var tokenLength = _model.TokenToSpan(token ?? LLamaToken.InvalidToken, buff, special: isSpecialToken);
-
+
if (tokenLength <= 0)
{
return null;
}
-
+
// if the original buffer wasn't large enough, create a new one
if (tokenLength > buffSize)
{
@@ -663,7 +676,7 @@ internal ModelTokens(SafeLlamaModelHandle model)
/// Get the End of Sentence token for this model
/// </summary>
public LLamaToken? EOS => Normalize(llama_token_eos(_model));
-
+
/// <summary>
/// The textual representation of the end of speech special token for this model
/// </summary>