diff --git a/LLama.Unittest/Constants.cs b/LLama.Unittest/Constants.cs
index 4852a335..d344974d 100644
--- a/LLama.Unittest/Constants.cs
+++ b/LLama.Unittest/Constants.cs
@@ -4,6 +4,7 @@ namespace LLama.Unittest
 {
     internal static class Constants
     {
+        public static readonly string ModelDirectory = "Models";
         public static readonly string GenerativeModelPath = "Models/llama-2-7b-chat.Q3_K_S.gguf";
         public static readonly string EmbeddingModelPath = "Models/all-MiniLM-L12-v2.Q8_0.gguf";
diff --git a/LLama.Unittest/LLama.Unittest.csproj b/LLama.Unittest/LLama.Unittest.csproj
index 5c29a851..087b1290 100644
--- a/LLama.Unittest/LLama.Unittest.csproj
+++ b/LLama.Unittest/LLama.Unittest.csproj
@@ -32,8 +32,6 @@
-
-
diff --git a/LLama.Unittest/Model/FileSystemModelRepoTests.cs b/LLama.Unittest/Model/FileSystemModelRepoTests.cs
new file mode 100644
index 00000000..c867ae62
--- /dev/null
+++ b/LLama.Unittest/Model/FileSystemModelRepoTests.cs
@@ -0,0 +1,104 @@
+using LLama.Model;
+
+namespace LLama.Unittest.Model;
+
+public class FileSystemModelRepoTests
+{
+    private readonly FileSystemModelRepo TestableRepo;
+
+    public FileSystemModelRepoTests()
+    {
+        TestableRepo = new([Constants.ModelDirectory]);
+    }
+
+    [Fact]
+    public void ModelDirectories_IsCorrect()
+    {
+        var dirs = TestableRepo.ListSources();
+        Assert.Single(dirs);
+
+        var expected = dirs.First()!.Contains(Constants.ModelDirectory);
+        Assert.True(expected);
+    }
+
+    [Fact]
+    public void AddDirectory_DoesntDuplicate()
+    {
+        for (var i = 0; i < 10; i++)
+        {
+            TestableRepo.AddSource(Constants.ModelDirectory);
+            TestableRepo.AddSource(Path.GetFullPath(Constants.ModelDirectory));
+
+            var dirs = TestableRepo.ListSources();
+            Assert.Single(dirs);
+            var expected = dirs.First()!.Contains(Constants.ModelDirectory);
+            Assert.True(expected);
+        }
+    }
+
+    [Fact]
+    public void RemoveDirectory()
+    {
+        var dirs = TestableRepo.ListSources();
+        Assert.Single(dirs);
+        var expected = dirs.First()!.Contains(Constants.ModelDirectory);
+        Assert.True(expected);
+
+        Assert.True(TestableRepo.RemoveSource(Constants.ModelDirectory));
+        Assert.Empty(TestableRepo.ListSources());
+        Assert.Empty(TestableRepo.GetAvailableModels());
+    }
+
+    [Fact]
+    public void RemoveDirectory_DoesNotExist()
+    {
+        var dirs = TestableRepo.ListSources();
+        Assert.Single(dirs);
+        var expected = dirs.First()!.Contains(Constants.ModelDirectory);
+        Assert.True(expected);
+
+        Assert.False(TestableRepo.RemoveSource("foo/boo/bar"));
+        Assert.Single(dirs);
+    }
+
+    [Fact]
+    public void RemoveAllDirectories()
+    {
+        var dirs = TestableRepo.ListSources();
+        Assert.Single(dirs);
+        var expected = dirs.First()!.Contains(Constants.ModelDirectory);
+        Assert.True(expected);
+
+        TestableRepo.RemoveAllSources();
+        Assert.Empty(TestableRepo.ListSources());
+        Assert.Empty(TestableRepo.GetAvailableModels());
+    }
+
+    [Fact]
+    public void ModelFiles_IsCorrect()
+    {
+        var files = TestableRepo.GetAvailableModels();
+        Assert.Equal(4, files.Count());
+    }
+
+    [Fact]
+    public void GetAvailableModelsFromDirectory()
+    {
+        var files = TestableRepo.GetAvailableModelsFromSource(Constants.ModelDirectory);
+        Assert.Equal(4, files.Count());
+
+        files = TestableRepo.GetAvailableModels();
+        Assert.Equal(4, files.Count());
+    }
+
+    [Fact]
+    public void TryGetModelFileMetadata_WhenExists()
+    {
+        var expectedFile = TestableRepo.GetAvailableModels().First();
+        var found = TestableRepo.TryGetModelFileMetadata(expectedFile.ModelFileUri, out var foundData);
+
+        Assert.True(found);
+        Assert.Equal(expectedFile.ModelFileUri, foundData.ModelFileUri);
+    }
+}
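For reviewers: a minimal sketch of how the repo exercised by these tests might be consumed, assuming only the API surface introduced in this diff (the `Models` directory matches the test constants):

```csharp
using System;
using LLama.Model;

// Scan a directory for .gguf files and list what was found.
var repo = new FileSystemModelRepo(["Models"]);

foreach (var model in repo.GetAvailableModels())
{
    // ModelFileSizeInBytes is populated from FileInfo.Length during the scan
    Console.WriteLine($"{model.ModelFileName}: {model.ModelFileSizeInBytes} bytes");
}
```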
diff --git a/LLama.Unittest/Model/ModelCacheTests.cs b/LLama.Unittest/Model/ModelCacheTests.cs
new file mode 100644
index 00000000..af9275d9
--- /dev/null
+++ b/LLama.Unittest/Model/ModelCacheTests.cs
@@ -0,0 +1,125 @@
+using LLama.Common;
+using LLama.Model;
+
+namespace LLama.Unittest.Model;
+
+public class ModelManagerTests
+{
+    private readonly IModelSourceRepo _testRepo = new FileSystemModelRepo([Constants.ModelDirectory]);
+
+    private readonly ModelCache TestableModelManager;
+
+    public ModelManagerTests()
+    {
+        TestableModelManager = new();
+    }
+
+    [Fact]
+    public async Task LoadModel_DisposesOnUnload()
+    {
+        const string modelId = "llama-2-7b";
+        var modelToLoad = _testRepo.GetAvailableModels()
+            .First(f => f.ModelFileName.Contains(modelId));
+
+        // Load success
+        var model = await TestableModelManager.LoadModelAsync(modelToLoad, modelId);
+        Assert.NotNull(model);
+        Assert.Equal(1, TestableModelManager.ModelsCached());
+
+        // Load with same Id throws
+        await Assert.ThrowsAsync<ArgumentException>(async () =>
+        {
+            await TestableModelManager.LoadModelAsync(modelToLoad, modelId);
+            return;
+        });
+        Assert.Equal(1, TestableModelManager.ModelsCached());
+
+        // unloaded and disposed
+        Assert.True(TestableModelManager.UnloadModel(modelId));
+        Assert.Throws<ObjectDisposedException>(() =>
+        {
+            _ = model.CreateContext(new ModelParams(modelToLoad.ModelFileUri));
+        });
+        Assert.Equal(0, TestableModelManager.ModelsCached());
+
+        // already unloaded and disposed
+        Assert.False(TestableModelManager.UnloadModel(modelId));
+        Assert.Throws<ObjectDisposedException>(() =>
+        {
+            _ = model.CreateContext(new ModelParams(modelToLoad.ModelFileUri));
+        });
+
+        // Can be reloaded after unload
+        model = await TestableModelManager.LoadModelAsync(modelToLoad, modelId);
+        Assert.NotNull(model);
+        Assert.Equal(1, TestableModelManager.ModelsCached());
+        Assert.True(TestableModelManager.UnloadModel(modelId));
+        Assert.Equal(0, TestableModelManager.ModelsCached());
+    }
+
+    [Fact]
+    public async Task TryCloneLoadedModel_ClonesAndCaches()
+    {
+        const string modelId = "llama-2-7b";
+        var modelToLoad = _testRepo.GetAvailableModels()
+            .First(f => f.ModelFileName.Contains(modelId));
+
+        var model = await TestableModelManager.LoadModelAsync(modelToLoad, modelId);
+        Assert.NotNull(model);
+        Assert.Equal(1, TestableModelManager.ModelsCached());
+
+        // clone it -- ref 2
+        const string cloneId = nameof(cloneId);
+        var isCachedAndCloned = TestableModelManager.TryCloneLoadedModel(modelId, cloneId, out var cachedModel);
+        Assert.True(isCachedAndCloned);
+        Assert.NotNull(cachedModel);
+        Assert.Equal(2, TestableModelManager.ModelsCached());
+
+        cachedModel.Dispose(); // -- ref 1
+        Assert.True(TestableModelManager.UnloadModel(modelId));
+        Assert.Equal(1, TestableModelManager.ModelsCached());
+
+        // unloaded and disposed -- ref 0
+        Assert.True(TestableModelManager.UnloadModel(cloneId));
+        Assert.Equal(0, TestableModelManager.ModelsCached());
+
+        Assert.False(TestableModelManager.UnloadModel(modelId));
+        Assert.False(TestableModelManager.UnloadModel(cloneId));
+        Assert.Throws<ObjectDisposedException>(() =>
+        {
+            _ = model.CreateContext(new ModelParams(modelToLoad.ModelFileUri));
+        });
+        Assert.Throws<ObjectDisposedException>(() =>
+        {
+            _ = cachedModel.CreateContext(new ModelParams(modelToLoad.ModelFileUri));
+        });
+    }
+
+    [Fact]
+    public async Task TryCloneLoadedModel_SameId_Throws()
+    {
+        const string modelId = "llama-2-7b";
+        var modelToLoad = _testRepo.GetAvailableModels()
+            .First(f => f.ModelFileName.Contains(modelId));
+
+        var model = await TestableModelManager.LoadModelAsync(modelToLoad, modelId);
+        Assert.NotNull(model);
+        Assert.Equal(1, TestableModelManager.ModelsCached());
+
+        // Same Id clone fails
+        Assert.Throws<ArgumentException>(() =>
+        {
+            TestableModelManager.TryCloneLoadedModel(modelId, modelId, out var cachedModel);
+        });
+        Assert.Equal(1, TestableModelManager.ModelsCached());
+
+        // Unload and dispose
+        Assert.True(TestableModelManager.UnloadModel(modelId));
+        Assert.Equal(0, TestableModelManager.ModelsCached());
+        Assert.False(TestableModelManager.UnloadModel(modelId));
+        Assert.Throws<ObjectDisposedException>(() =>
+        {
+            _ = model.CreateContext(new ModelParams(modelToLoad.ModelFileUri));
+        });
+    }
+}
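The tests above pin down the intended cache lifecycle. As a reviewer's sketch (assuming a local `Models` directory containing the test model), the happy path looks like:

```csharp
using System.Linq;
using LLama.Model;

var repo = new FileSystemModelRepo(["Models"]);
var meta = repo.GetAvailableModels().First(f => f.ModelFileName.Contains("llama-2-7b"));

using var cache = new ModelCache();

// Each alias is a distinct cache entry; loading the same alias twice throws.
var model = await cache.LoadModelAsync(meta, "llama-2-7b");

// A clone shares the native weights but can be unloaded independently.
if (cache.TryCloneLoadedModel("llama-2-7b", "clone", out var clone))
{
    cache.UnloadModel("clone");
}

cache.UnloadModel("llama-2-7b"); // disposes the remaining wrapper
```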
diff --git a/LLama/LLamaWeights.cs b/LLama/LLamaWeights.cs
index 8646e4d9..aeaf3f5f 100644
--- a/LLama/LLamaWeights.cs
+++ b/LLama/LLamaWeights.cs
@@ -16,12 +16,23 @@ namespace LLama
     public sealed class LLamaWeights
         : IDisposable
     {
+        private bool _disposed = false;
+
         /// <summary>
         /// The native handle, which is used in the native APIs
         /// </summary>
         /// <remarks>Be careful how you use this!</remarks>
         public SafeLlamaModelHandle NativeHandle { get; }
 
+        #region Properties
+        /// <summary>
+        /// The model's name as specified in its metadata
+        /// </summary>
+        /// <returns></returns>
+        public string ModelName => Metadata.TryGetValue("general.name", out var name)
+            ? name
+            : string.Empty;
+
         /// <summary>
         /// Total number of tokens in vocabulary of this model
         /// </summary>
@@ -56,11 +67,53 @@ public sealed class LLamaWeights
         /// All metadata keys in this model
         /// </summary>
         public IReadOnlyDictionary<string, string> Metadata { get; set; }
+        #endregion
+
+        private LLamaWeights(SafeLlamaModelHandle handle)
+        {
+            NativeHandle = handle;
+            Metadata = handle.ReadMetadata();
+
+            // Increment the model reference count while this weight exists.
+            // DangerousAddRef throws if it fails, so there is no need to check "success"
+            var success = false;
+            NativeHandle.DangerousAddRef(ref success);
+        }
+
+        /// <summary>
+        /// Create an instance of the model using the supplied handle and metadata.
+        /// Metadata will not be re-read from the handle.
+        /// </summary>
+        /// <param name="handle"></param>
+        /// <param name="metadata"></param>
+        private LLamaWeights(SafeLlamaModelHandle handle, IReadOnlyDictionary<string, string> metadata)
+        {
+            NativeHandle = handle;
+            Metadata = metadata;
+
+            // Increment the model reference count while this weight exists.
+            // DangerousAddRef throws if it fails, so there is no need to check "success"
+            var success = false;
+            NativeHandle.DangerousAddRef(ref success);
+        }
+
+        /// <inheritdoc />
+        ~LLamaWeights()
+        {
+            // Ensure the handle is released even if users don't explicitly call Dispose
+            Dispose();
+        }
 
-        private LLamaWeights(SafeLlamaModelHandle weights)
+        #region Load
+        /// <summary>
+        /// Create a new instance of the model using same NativeHandle as this model.
+        /// Metadata is also copied from the existing model rather than read from the handle directly.
+        /// The `SafeLlamaModelHandle` will not be disposed and the model will not be unloaded until ALL such handles have been disposed.
+        /// </summary>
+        /// <returns></returns>
+        public LLamaWeights CloneFromHandleWithMetadata()
         {
-            NativeHandle = weights;
-            Metadata = weights.ReadMetadata();
+            return new LLamaWeights(NativeHandle, Metadata);
         }
 
         /// <summary>
@@ -71,19 +124,19 @@ private LLamaWeights(SafeLlamaModelHandle weights)
         public static LLamaWeights LoadFromFile(IModelParams @params)
         {
             using var pin = @params.ToLlamaModelParams(out var lparams);
-            var weights = SafeLlamaModelHandle.LoadFromFile(@params.ModelPath, lparams);
+            var model = SafeLlamaModelHandle.LoadFromFile(@params.ModelPath, lparams);
 
             foreach (var adapter in @params.LoraAdapters)
             {
-                if (string.IsNullOrEmpty(adapter.Path))
-                    continue;
-                if (adapter.Scale <= 0)
+                if (string.IsNullOrEmpty(adapter.Path) || adapter.Scale <= 0)
+                {
                     continue;
+                }
 
-                weights.ApplyLoraFromFile(adapter.Path, adapter.Scale, @params.LoraBase);
+                model.ApplyLoraFromFile(adapter.Path, adapter.Scale, @params.LoraBase);
             }
 
-            return new LLamaWeights(weights);
+            return new LLamaWeights(model);
         }
 
         /// <summary>
@@ -103,15 +156,15 @@ public static async Task<LLamaWeights> LoadFromFileAsync(IModelParams @params, C
             var loraBase = @params.LoraBase;
             var loraAdapters = @params.LoraAdapters.ToArray();
 
-            // Determine the range to report for model loading. llama.cpp reports 0-1, but we'll remap that into a
-            // slightly smaller range to allow some space for reporting LoRA loading too.
-            var modelLoadProgressRange = 1f;
-            if (loraAdapters.Length > 0)
-                modelLoadProgressRange = 0.9f;
-
             using (@params.ToLlamaModelParams(out var lparams))
             {
 #if !NETSTANDARD2_0
+                // Determine the range to report for model loading. llama.cpp reports 0-1, but we'll remap that into a
+                // slightly smaller range to allow some space for reporting LoRA loading too.
+                var modelLoadProgressRange = 1f;
+                if (loraAdapters.Length > 0)
+                    modelLoadProgressRange = 0.9f;
+
                 // Overwrite the progress callback with one which polls the cancellation token and updates the progress object
                 if (token.CanBeCanceled || progressReporter != null)
                 {
@@ -125,11 +178,7 @@ public static async Task<LLamaWeights> LoadFromFileAsync(IModelParams @params, C
                         if (internalCallback != null && !internalCallback(progress, ctx))
                             return false;
 
-                        // Check the cancellation token
-                        if (token.IsCancellationRequested)
-                            return false;
-
-                        return true;
+                        // Continue loading only while cancellation has NOT been requested
+                        return !token.IsCancellationRequested;
                     };
                 }
 #endif
@@ -183,11 +232,19 @@ public static async Task<LLamaWeights> LoadFromFileAsync(IModelParams @params, C
                 return model;
             }
         }
+        #endregion
 
         /// <inheritdoc />
         public void Dispose()
        {
-            NativeHandle.Dispose();
+            if (!_disposed)
+            {
+                NativeHandle.DangerousRelease();
+                NativeHandle.Dispose();
+                _disposed = true;
+            }
+
+            GC.SuppressFinalize(this);
         }
 
         /// <inheritdoc />
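The constructor/Dispose changes above turn `LLamaWeights` into a ref-counted wrapper: every instance calls `DangerousAddRef` on construction and `DangerousRelease` on dispose, so the native model is only freed when the last wrapper goes. A sketch of the resulting semantics (the model path is illustrative):

```csharp
using LLama;
using LLama.Common;

var original = LLamaWeights.LoadFromFile(new ModelParams("Models/llama-2-7b-chat.Q3_K_S.gguf"));
var clone = original.CloneFromHandleWithMetadata();

// Disposal order does not matter: each Dispose releases one reference,
// and llama_free_model only runs after the second one.
original.Dispose();
var name = clone.ModelName; // still valid; clone holds the last reference
clone.Dispose();
```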
diff --git a/LLama/Model/FileSystemModelRepo.cs b/LLama/Model/FileSystemModelRepo.cs
new file mode 100644
index 00000000..70987cb4
--- /dev/null
+++ b/LLama/Model/FileSystemModelRepo.cs
@@ -0,0 +1,119 @@
+using System;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.IO;
+using System.Linq;
+
+namespace LLama.Model;
+
+/// <summary>
+/// A model repository that uses a file system to search for available models
+/// </summary>
+public class FileSystemModelRepo : IModelSourceRepo
+{
+    /// <summary>
+    /// Supported model file types
+    /// </summary>
+    public static readonly string[] ExpectedModelFileTypes = [
+        ".gguf"
+    ];
+
+    // keys are directories, values are applicable models
+    private readonly Dictionary<string, List<ModelFileMetadata>> _availableModels = [];
+
+    /// <summary>
+    /// Create a model repo that scans the filesystem to find models
+    /// </summary>
+    /// <param name="directories"></param>
+    public FileSystemModelRepo(string[] directories)
+    {
+        GetModelsFromDirectories(directories);
+    }
+
+    #region Sources
+    /// <inheritdoc />
+    public IEnumerable<string> ListSources() => _availableModels.Keys;
+
+    private void GetModelsFromDirectories(params string[] dirs)
+    {
+        foreach (var dir in dirs)
+        {
+            var fullDirectoryPath = Path.GetFullPath(dir);
+
+            if (!Directory.Exists(fullDirectoryPath))
+            {
+                Trace.TraceError($"Model directory '{fullDirectoryPath}' does not exist");
+                continue;
+            }
+
+            if (_availableModels.ContainsKey(fullDirectoryPath))
+            {
+                Trace.TraceWarning($"Model directory '{fullDirectoryPath}' already probed");
+                continue;
+            }
+
+            // find models in current dir that are of expected type
+            List<ModelFileMetadata> directoryModelFiles = [];
+            foreach (var file in Directory.EnumerateFiles(fullDirectoryPath))
+            {
+                if (!ExpectedModelFileTypes.Contains(Path.GetExtension(file)))
+                {
+                    continue;
+                }
+
+                // expected model file
+                // TODO: handle symbolic links
+                var fi = new FileInfo(file);
+                directoryModelFiles.Add(new ModelFileMetadata
+                {
+                    ModelFileName = fi.Name,
+                    ModelFileUri = fi.FullName,
+                    ModelType = ModelFileType.GGUF,
+                    ModelFileSizeInBytes = fi.Length,
+                });
+            }
+
+            _availableModels.Add(fullDirectoryPath, directoryModelFiles);
+        }
+    }
+
+    /// <inheritdoc />
+    public void AddSource(string directory)
+    {
+        GetModelsFromDirectories(directory);
+    }
+
+    /// <inheritdoc />
+    public bool RemoveSource(string directory)
+    {
+        return _availableModels.Remove(Path.GetFullPath(directory));
+    }
+
+    /// <inheritdoc />
+    public void RemoveAllSources()
+    {
+        _availableModels.Clear();
+    }
+    #endregion Sources
+
+    /// <inheritdoc />
+    public IEnumerable<ModelFileMetadata> GetAvailableModels()
+        => _availableModels.SelectMany(x => x.Value);
+
+    /// <inheritdoc />
+    public IEnumerable<ModelFileMetadata> GetAvailableModelsFromSource(string directory)
+    {
+        var dirPath = Path.GetFullPath(directory);
+        return _availableModels.TryGetValue(dirPath, out var dirModels)
+            ? dirModels
+            : [];
+    }
+
+    /// <inheritdoc />
+    public bool TryGetModelFileMetadata(string modelFileName, out ModelFileMetadata modelMeta)
+    {
+        var filePath = Path.GetFullPath(modelFileName);
+        modelMeta = GetAvailableModels().FirstOrDefault(f => f.ModelFileUri == filePath)!;
+        return modelMeta != null;
+    }
+}
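Worth noting for review: every public entry point normalizes through `Path.GetFullPath`, which is what makes the dedup and lookup tests pass. A quick sketch of the behaviour this buys:

```csharp
using System;
using System.IO;
using System.Linq;
using LLama.Model;

var repo = new FileSystemModelRepo(["Models"]);
repo.AddSource("Models");                   // same directory, relative form
repo.AddSource(Path.GetFullPath("Models")); // same directory, absolute form
Console.WriteLine(repo.ListSources().Count()); // still 1

// Lookup succeeds because the argument normalizes to the stored full URI
var firstUri = repo.GetAvailableModels().First().ModelFileUri;
Console.WriteLine(repo.TryGetModelFileMetadata(firstUri, out _)); // True
```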
diff --git a/LLama/Model/HuggingFaceModelRepo.cs b/LLama/Model/HuggingFaceModelRepo.cs
new file mode 100644
index 00000000..072e87e6
--- /dev/null
+++ b/LLama/Model/HuggingFaceModelRepo.cs
@@ -0,0 +1,57 @@
+using System;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Net.Http;
+using Microsoft.Extensions.Logging;
+
+namespace LLama.Model;
+
+// This is for demo purposes - it can be finalized later
+internal class HuggingFaceModelRepo(ILogger logger,
+    HttpClient hfClient) : IModelSourceRepo
+{
+    private readonly ILogger _logger = logger;
+    private readonly HttpClient _hfClient = hfClient;
+
+    // https://huggingface.co/leliuga/all-MiniLM-L12-v2-GGUF/resolve/main/all-MiniLM-L12-v2.Q8_0.gguf
+    private readonly HashSet<string> _hfModelUri = [];
+
+    public void AddSource(string source)
+    {
+        if (!Uri.IsWellFormedUriString(source, UriKind.Absolute))
+        {
+            Trace.TraceWarning("URI is not a valid HuggingFace URL");
+        }
+
+        // TODO: call HF to check model exists
+        // TODO: Get metadata about model
+        _hfModelUri.Add(source);
+    }
+
+    public IEnumerable<string> ListSources() => _hfModelUri;
+
+    public void RemoveAllSources()
+    {
+        _hfModelUri.Clear();
+    }
+
+    public bool RemoveSource(string source)
+    {
+        return _hfModelUri.Remove(source);
+    }
+
+    public bool TryGetModelFileMetadata(string modelFileName, out ModelFileMetadata modelMeta)
+    {
+        throw new NotImplementedException();
+    }
+
+    public IEnumerable<ModelFileMetadata> GetAvailableModels()
+    {
+        throw new NotImplementedException();
+    }
+
+    public IEnumerable<ModelFileMetadata> GetAvailableModelsFromSource(string source)
+    {
+        throw new NotImplementedException();
+    }
+}
diff --git a/LLama/Model/IModelCache.cs b/LLama/Model/IModelCache.cs
new file mode 100644
index 00000000..09327222
--- /dev/null
+++ b/LLama/Model/IModelCache.cs
@@ -0,0 +1,61 @@
+using System;
+using System.Threading;
+using System.Threading.Tasks;
+using LLama.Common;
+
+namespace LLama.Model;
+
+/// <summary>
+/// A class that helps organize and load local models
+/// </summary>
+public interface IModelCache : IDisposable
+{
+    /// <summary>
+    /// The current number of file handles in cache.
+    /// </summary>
+    /// <returns>Number of cached models</returns>
+    public int ModelsCached();
+
+    /// <summary>
+    /// Load a model file to be used for inference.
+    /// The caller assumes responsibility for disposing this model and MUST call UnloadModel
+    /// </summary>
+    /// <param name="metadata">The metadata about the model file to be loaded</param>
+    /// <param name="modelId">A required alias to uniquely identify this model</param>
+    /// <param name="modelConfigurator">An optional function to further configure the model parameters beyond default</param>
+    /// <param name="cancellationToken"></param>
+    /// <returns>An instance of the newly loaded model. This MUST be disposed or unloaded via UnloadModel</returns>
+    public Task<LLamaWeights> LoadModelAsync(ModelFileMetadata metadata,
+        string modelId,
+        Action<ModelParams>? modelConfigurator = null,
+        CancellationToken cancellationToken = default);
+
+    /// <summary>
+    /// Attempt to get a reference to a model that's already loaded
+    /// </summary>
+    /// <param name="modelId">Identifier of the loaded model</param>
+    /// <param name="cachedModel">Will be populated with the reference if the model is cached</param>
+    /// <returns>A SHARED instance to a model that's already loaded. Disposing or unloading this model will affect all references</returns>
+    public bool TryGetLoadedModel(string modelId, out LLamaWeights cachedModel);
+
+    /// <summary>
+    /// Attempt to clone and cache a new unique model instance
+    /// </summary>
+    /// <param name="loadedModelId">Model that's expected to be loaded and cloned</param>
+    /// <param name="cloneId">Identifier for the newly cloned model</param>
+    /// <param name="model">If cloning is successful, this model will be available for use</param>
+    /// <returns>True if cloning is successful</returns>
+    public bool TryCloneLoadedModel(string loadedModelId, string cloneId, out LLamaWeights model);
+
+    /// <summary>
+    /// Unload and dispose of a model with the given id
+    /// </summary>
+    /// <param name="modelId"></param>
+    /// <returns></returns>
+    public bool UnloadModel(string modelId);
+
+    /// <summary>
+    /// Unload all currently loaded models
+    /// </summary>
+    public void UnloadAllModels();
+}
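Since both abstractions are interface-first, they slot into a standard container. One possible wiring, assuming Microsoft.Extensions.DependencyInjection (not part of this PR):

```csharp
using LLama.Model;
using Microsoft.Extensions.DependencyInjection;

var services = new ServiceCollection();

// One repo scanning a known directory, one shared cache for the app's lifetime
services.AddSingleton<IModelSourceRepo>(_ => new FileSystemModelRepo(["Models"]));
services.AddSingleton<IModelCache, ModelCache>();

using var provider = services.BuildServiceProvider();
var cache = provider.GetRequiredService<IModelCache>();
```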
diff --git a/LLama/Model/IModelSourceRepo.cs b/LLama/Model/IModelSourceRepo.cs
new file mode 100644
index 00000000..c502cbe8
--- /dev/null
+++ b/LLama/Model/IModelSourceRepo.cs
@@ -0,0 +1,58 @@
+using System.Collections.Generic;
+
+namespace LLama.Model;
+
+/// <summary>
+/// A source for models
+/// </summary>
+public interface IModelSourceRepo
+{
+    #region Source
+    /// <summary>
+    /// Configured set of sources that are scanned to find models
+    /// </summary>
+    /// <returns></returns>
+    public IEnumerable<string> ListSources();
+
+    /// <summary>
+    /// Add a source containing one or more files
+    /// </summary>
+    /// <param name="source"></param>
+    public void AddSource(string source);
+
+    /// <summary>
+    /// Remove a source from being scanned and having model files made available
+    /// </summary>
+    /// <param name="source"></param>
+    /// <returns></returns>
+    public bool RemoveSource(string source);
+
+    /// <summary>
+    /// Remove all model sources
+    /// </summary>
+    public void RemoveAllSources();
+    #endregion
+
+    #region AvailableModels
+    /// <summary>
+    /// Get all of the model files that are available to be loaded
+    /// </summary>
+    /// <returns></returns>
+    public IEnumerable<ModelFileMetadata> GetAvailableModels();
+
+    /// <summary>
+    /// Only get the models associated with a specific source
+    /// </summary>
+    /// <param name="source"></param>
+    /// <returns>The files, if any, associated with a given source</returns>
+    public IEnumerable<ModelFileMetadata> GetAvailableModelsFromSource(string source);
+
+    /// <summary>
+    /// Get the file data for a given model
+    /// </summary>
+    /// <param name="modelFileName"></param>
+    /// <param name="modelMeta"></param>
+    /// <returns>If a model with the given file name is present</returns>
+    public bool TryGetModelFileMetadata(string modelFileName, out ModelFileMetadata modelMeta);
+    #endregion
+}
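To sanity-check the interface shape, here is a hypothetical minimal implementation backed by explicitly registered files rather than directory scans (illustrative only, not part of this PR):

```csharp
using System.Collections.Generic;
using System.IO;
using LLama.Model;

// Hypothetical repo where each source is a single model file.
public class SingleFileModelRepo : IModelSourceRepo
{
    private readonly Dictionary<string, ModelFileMetadata> _files = [];

    public IEnumerable<string> ListSources() => _files.Keys;

    public void AddSource(string source)
    {
        var fi = new FileInfo(source);
        _files[fi.FullName] = new ModelFileMetadata
        {
            ModelFileName = fi.Name,
            ModelFileUri = fi.FullName,
            ModelType = ModelFileType.GGUF,
            ModelFileSizeInBytes = fi.Length,
        };
    }

    public bool RemoveSource(string source) => _files.Remove(Path.GetFullPath(source));

    public void RemoveAllSources() => _files.Clear();

    public IEnumerable<ModelFileMetadata> GetAvailableModels() => _files.Values;

    public IEnumerable<ModelFileMetadata> GetAvailableModelsFromSource(string source)
        => _files.TryGetValue(Path.GetFullPath(source), out var meta) ? [meta] : [];

    public bool TryGetModelFileMetadata(string modelFileName, out ModelFileMetadata modelMeta)
        => _files.TryGetValue(Path.GetFullPath(modelFileName), out modelMeta!);
}
```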
diff --git a/LLama/Model/ModelCache.cs b/LLama/Model/ModelCache.cs
new file mode 100644
index 00000000..933156a3
--- /dev/null
+++ b/LLama/Model/ModelCache.cs
@@ -0,0 +1,147 @@
+using System;
+using System.Collections.Generic;
+using System.Threading;
+using System.Threading.Tasks;
+using LLama.Common;
+
+namespace LLama.Model;
+
+internal class CachedModelReference
+{
+    public LLamaWeights Model { get; init; } = null!;
+    public int RefCount { get; set; } = 0;
+}
+
+/// <inheritdoc />
+public class ModelCache : IModelCache
+{
+    private bool _disposed = false;
+
+    // model id/alias, to loaded model
+    private readonly Dictionary<string, LLamaWeights> _loadedModelCache = [];
+
+    /// <inheritdoc />
+    public int ModelsCached()
+        => _loadedModelCache.Count;
+
+    /// <inheritdoc />
+    public bool TryCloneLoadedModel(string loadedModelId,
+        string cloneId,
+        out LLamaWeights model)
+    {
+        var isCached = _loadedModelCache.TryGetValue(loadedModelId, out var cachedModel);
+
+        model = null!;
+        if (isCached)
+        {
+            TryAddModel(cloneId, cachedModel!.CloneFromHandleWithMetadata);
+            model = _loadedModelCache[cloneId];
+            return true;
+        }
+        return false;
+    }
+
+    /// <inheritdoc />
+    public bool TryGetLoadedModel(string modelId, out LLamaWeights cachedModel)
+    {
+        return _loadedModelCache.TryGetValue(modelId, out cachedModel!);
+    }
+
+    private void TryAddModel(string modelId, Func<LLamaWeights> modelCreator)
+    {
+        if (IsModelIdInvalid(modelId))
+        {
+            throw new ArgumentException("Model identifier is not unique");
+        }
+
+        _loadedModelCache.Add(modelId, modelCreator());
+    }
+
+    private async Task TryAddModelAsync(string modelId, Func<Task<LLamaWeights>> modelCreator)
+    {
+        if (IsModelIdInvalid(modelId))
+        {
+            throw new ArgumentException("Model identifier is not unique");
+        }
+
+        _loadedModelCache.Add(modelId, await modelCreator());
+    }
+
+    private bool IsModelIdInvalid(string modelId) =>
+        string.IsNullOrWhiteSpace(modelId) || _loadedModelCache.ContainsKey(modelId);
+
+    /// <inheritdoc />
+    public async Task<LLamaWeights> LoadModelAsync(ModelFileMetadata metadata,
+        string modelId,
+        Action<ModelParams>? modelConfigurator = null,
+        CancellationToken cancellationToken = default)
+    {
+        await TryAddModelAsync(modelId, async () =>
+        {
+            return await ModelCreator(metadata.ModelFileUri, modelConfigurator, cancellationToken);
+        });
+        return _loadedModelCache[modelId];
+
+        // Helper to create the model
+        static async Task<LLamaWeights> ModelCreator(string fileUri,
+            Action<ModelParams>? modelConfigurator,
+            CancellationToken cancellationToken)
+        {
+            var modelParams = new ModelParams(fileUri);
+            modelConfigurator?.Invoke(modelParams);
+
+            return await LLamaWeights.LoadFromFileAsync(modelParams, cancellationToken);
+        }
+    }
+
+    #region Unload
+    /// <inheritdoc />
+    public bool UnloadModel(string modelId)
+    {
+        if (_loadedModelCache.TryGetValue(modelId, out var cachedModel))
+        {
+            cachedModel.Dispose();
+            return _loadedModelCache.Remove(modelId);
+        }
+        return false;
+    }
+
+    /// <inheritdoc />
+    public void UnloadAllModels()
+    {
+        foreach (var model in _loadedModelCache.Values)
+        {
+            model.Dispose();
+        }
+        _loadedModelCache.Clear();
+    }
+    #endregion
+
+    #region Dispose
+    /// <inheritdoc />
+    public void Dispose()
+    {
+        Dispose(true);
+        GC.SuppressFinalize(this);
+    }
+
+    /// <summary>
+    /// Unload all models when called explicitly via dispose
+    /// </summary>
+    /// <param name="disposing">Whether or not this call is made explicitly (true) or via GC</param>
+    protected virtual void Dispose(bool disposing)
+    {
+        if (_disposed)
+        {
+            return;
+        }
+
+        if (disposing)
+        {
+            UnloadAllModels();
+        }
+
+        _disposed = true;
+    }
+    #endregion
+}
diff --git a/LLama/Model/ModelFileMetadata.cs b/LLama/Model/ModelFileMetadata.cs
new file mode 100644
index 00000000..3b2c1814
--- /dev/null
+++ b/LLama/Model/ModelFileMetadata.cs
@@ -0,0 +1,23 @@
+namespace LLama.Model;
+
+#pragma warning disable CS1591 // Missing XML comment for publicly visible type or member
+/// <summary>
+/// Types of supported model files
+/// </summary>
+public enum ModelFileType
+{
+    GGUF
+}
+
+/// <summary>
+/// Metadata about available models
+/// </summary>
+public class ModelFileMetadata
+{
+    public string ModelFileName { get; init; } = string.Empty;
+    public string ModelFileUri { get; init; } = string.Empty;
+    public ModelFileType ModelType { get; init; }
+    public long ModelFileSizeInBytes { get; init; } = 0;
+}
+#pragma warning restore CS1591 // Missing XML comment for publicly visible type or member
diff --git a/LLama/Native/SafeLlamaModelHandle.cs b/LLama/Native/SafeLlamaModelHandle.cs
index 1597908e..c25f0b4f 100644
--- a/LLama/Native/SafeLlamaModelHandle.cs
+++ b/LLama/Native/SafeLlamaModelHandle.cs
@@ -16,6 +16,7 @@ namespace LLama.Native
     public sealed class SafeLlamaModelHandle
         : SafeLLamaHandleBase
     {
+        #region Properties
         /// <summary>
         /// Total number of tokens in vocabulary of this model
         /// </summary>
@@ -61,6 +62,7 @@ public sealed class SafeLlamaModelHandle
         /// </summary>
         public int LayerCount => llama_n_embd(this);
 
+        private string _modelDescription = null!;
         /// <summary>
         /// Get a description of this model
         /// </summary>
@@ -68,17 +70,22 @@ public string Description
         {
             get
             {
-                unsafe
+                if (_modelDescription is null)
                 {
-                    // Get description length
-                    var size = llama_model_desc(this, null, 0);
-                    var buf = new byte[size + 1];
-                    fixed (byte* bufPtr = buf)
+                    unsafe
                     {
-                        size = llama_model_desc(this, bufPtr, buf.Length);
-                        return Encoding.UTF8.GetString(buf, 0, size);
+                        // Get description length
+                        var size = llama_model_desc(this, null, 0);
+                        var buf = new byte[size + 1];
+                        fixed (byte* bufPtr = buf)
+                        {
+                            size = llama_model_desc(this, bufPtr, buf.Length);
+                            _modelDescription = Encoding.UTF8.GetString(buf, 0, size) ?? string.Empty;
+                        }
                     }
                 }
+
+                return _modelDescription;
             }
         }
 
@@ -94,6 +101,7 @@ public string Description
         /// Get the special tokens of this model
         /// </summary>
         public ModelTokens Tokens => _tokens ??= new ModelTokens(this);
+        #endregion
 
         /// <inheritdoc />
         protected override bool ReleaseHandle()
@@ -101,7 +109,7 @@ protected override bool ReleaseHandle()
             llama_free_model(handle);
             return true;
         }
-
+
         /// <summary>
         /// Load a model from the given file path into memory
         /// </summary>
@@ -116,12 +124,18 @@ public static SafeLlamaModelHandle LoadFromFile(string modelPath, LLamaModelPara
             // - File is readable (explicit check)
            // This provides better error messages that llama.cpp, which would throw an access violation exception in both cases.
             using (var fs = new FileStream(modelPath, FileMode.Open))
+            {
                 if (!fs.CanRead)
+                {
                     throw new InvalidOperationException($"Model file '{modelPath}' is not readable");
+                }
+            }
 
             var handle = llama_load_model_from_file(modelPath, lparams);
             if (handle.IsInvalid)
+            {
                 throw new LoadWeightsFailedException(modelPath);
+            }
 
             return handle;
         }
@@ -244,7 +258,6 @@ private static int llama_model_meta_val_str(SafeLlamaModelHandle model, string k
             static extern unsafe int llama_model_meta_val_str_native(SafeLlamaModelHandle model, byte* key, byte* buf, long buf_size);
         }
-
         /// <summary>
         /// Get the number of tokens in the model vocabulary
         /// </summary>
@@ -545,7 +558,7 @@ public SafeLLamaContextHandle CreateContext(LLamaContextParams @params)
             keyLength = llama_model_meta_val_str(this, key, buffer);
             Debug.Assert(keyLength >= 0);
 
-            return buffer.AsMemory().Slice(0,keyLength);
+            return buffer.AsMemory().Slice(0, keyLength);
         }
 
@@ -632,12 +645,12 @@ internal ModelTokens(SafeLlamaModelHandle model)
             const int buffSize = 32;
             Span<byte> buff = stackalloc byte[buffSize];
             var tokenLength = _model.TokenToSpan(token ?? LLamaToken.InvalidToken, buff, special: isSpecialToken);
-            
+
             if (tokenLength <= 0)
             {
                 return null;
             }
-            
+
             // if the original buffer wasn't large enough, create a new one
             if (tokenLength > buffSize)
             {
@@ -663,7 +676,7 @@ internal ModelTokens(SafeLlamaModelHandle model)
         /// <summary>
         /// Get the End of Sentence token for this model
         /// </summary>
         public LLamaToken? EOS => Normalize(llama_token_eos(_model));
-        
+
         /// <summary>
         /// The textual representation of the end of speech special token for this model
         /// </summary>
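Finally, a note on the `Description` change above: the native `llama_model_desc` call now runs at most once per handle, with later reads served from `_modelDescription`. Illustrative fragment (assumes `weights` is a loaded `LLamaWeights`):

```csharp
// First read marshals the description out of native memory...
var d1 = weights.NativeHandle.Description;
// ...subsequent reads return the cached string with no further native calls.
var d2 = weights.NativeHandle.Description;
```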