Model File Manager #789

Open · wants to merge 11 commits into master
52 changes: 39 additions & 13 deletions LLama.Unittest/Model/ModelCacheTests.cs
@@ -14,26 +14,52 @@ public ModelManagerTests()
TestableModelManager = new();
}

[Fact]
public async void LoadModel_DisposesOnUnload()
{
var modelToLoad = _testRepo.GetAvailableModels()
.First(f => f.ModelFileName.Contains("llama-2-7b"));

var model = await TestableModelManager.LoadModelAsync(modelToLoad);
Assert.NotNull(model);

// unloaded and disposed
Assert.True(TestableModelManager.UnloadModel(model.ModelName));
Assert.Throws<ObjectDisposedException>(() =>
{
_ = model.CreateContext(new ModelParams(modelToLoad.ModelFileUri));
});

// won't unload again; already unloaded and disposed
Assert.False(TestableModelManager.UnloadModel(model.ModelName));
Assert.Throws<ObjectDisposedException>(() =>
{
_ = model.CreateContext(new ModelParams(modelToLoad.ModelFileUri));
});
}

[Fact]
public async void LoadModel_LoadsAndCaches()
{
var modelToLoad = _testRepo.GetAvailableModels()
.First(f => f.ModelFileName.Contains("llama-2-7b"));

// Create Model -- Ref 1
var model = await TestableModelManager.LoadModelAsync(modelToLoad);
var isLoaded = TestableModelManager.TryGetLoadedModel(model.ModelName, out var cachedModel);
Assert.True(isLoaded);
Assert.NotNull(model);

// clone it -- Ref 2
var isCachedAndCloned = TestableModelManager.TryCloneLoadedModel(model.ModelName, out var cachedModel);
Assert.True(isCachedAndCloned);
Assert.NotNull(cachedModel);

// unload the newly acquired model even though it was cached
cachedModel.Dispose(); //-- ref 1
Assert.True(TestableModelManager.UnloadModel(model.ModelName));
//cachedModel.Dispose(); // this does effectively nothing

// unload "original"
model.Dispose(); // need to explicitly dispose the model that the caller (us) owns
// unloaded and disposed -- ref 2
Assert.True(TestableModelManager.UnloadModel(model.ModelName));

Assert.False(TestableModelManager.UnloadModel(model.ModelName));

Assert.Throws<ObjectDisposedException>(() =>
{
_ = model.CreateContext(new ModelParams(modelToLoad.ModelFileUri));
@@ -51,7 +77,7 @@ public async void LoadModel_AlreadyLoaded_ReturnsFromCache()
var model = await TestableModelManager.LoadModelAsync(modelToLoad);
Assert.NotNull(model);
Assert.Equal("LLaMA v2", model.ModelName);
var isLoaded = TestableModelManager.TryGetLoadedModel(model.ModelName, out var cachedModel);
var isLoaded = TestableModelManager.TryCloneLoadedModel(model.ModelName, out var cachedModel);
Assert.True(isLoaded);
Assert.NotNull(cachedModel);
Assert.Equal("LLaMA v2", cachedModel.ModelName);
@@ -67,20 +93,20 @@ public async void TryGetLoadedModel_AlreadyDisposed_ReturnsFalse()
using (var model = await TestableModelManager.LoadModelAsync(modelToLoad))
{
Assert.NotNull(model);
Assert.Equal("LLaMA v2", model.ModelName);
var isLoaded = TestableModelManager.TryGetLoadedModel(model.ModelName, out var cachedModel);
Assert.Equal(model.ModelName, model.ModelName);
var isLoaded = TestableModelManager.TryCloneLoadedModel(model.ModelName, out var cachedModel);
Assert.True(isLoaded);
Assert.NotNull(cachedModel);
Assert.Equal("LLaMA v2", cachedModel.ModelName);
Assert.Equal(model.ModelName, cachedModel.ModelName);

// unload from the last check
Assert.True(TestableModelManager.UnloadModel("LLaMA v2"));
Assert.True(TestableModelManager.UnloadModel(model.ModelName));

} // end of scope: Dispose is called on the model, but since it's in the model cache it should stick around until unloaded
Assert.True(TestableModelManager.UnloadModel("LLaMA v2"));

// Model is still loaded due to cache
var isDisposedLoaded = TestableModelManager.TryGetLoadedModel("LLaMA v2", out var disposedModel);
var isDisposedLoaded = TestableModelManager.TryCloneLoadedModel("LLaMA v2", out var disposedModel);
Assert.False(isDisposedLoaded);
Assert.Null(disposedModel);
}
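Taken together, these tests pin down the ownership contract a caller has to respect: the `LLamaWeights` returned by `LoadModelAsync` is caller-owned, and once `UnloadModel` releases the last reference the instance is disposed and must not be used. A sketch of a call pattern that respects that contract (a hedged illustration using the APIs from the tests above, not code from this PR):

```csharp
// Illustrative only; assumes the semantics asserted by the tests above.
var model = await manager.LoadModelAsync(modelToLoad);
try
{
    using var context = model.CreateContext(new ModelParams(modelToLoad.ModelFileUri));
    // ... run inference with the context ...
}
finally
{
    // Release the cache's reference. After this, the model instance is disposed
    // and any further CreateContext call throws ObjectDisposedException.
    manager.UnloadModel(model.ModelName);
}
```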
63 changes: 36 additions & 27 deletions LLama/LLamaWeights.cs
@@ -1,5 +1,6 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
@@ -18,12 +19,6 @@
{
private bool _disposed = false;

///
~LLamaWeights()
{
Dispose(false);
}

/// <summary>
/// The native handle, which is used in the native APIs
/// </summary>
@@ -86,15 +81,43 @@
NativeHandle.DangerousAddRef(ref success);
}

#region Load
/// <summary>
/// Create from a "shared" handle. The `SafeLlamaModelHandle` will not be disposed and the model will not be unloaded until <b>all</b> such handles have been disposed.
/// Create an instance of the model using the supplied handle and metadata.
/// Metadata will <b>not</b> be re-read from the handle.
/// </summary>
/// <param name="handle"></param>
/// <param name="metadata"></param>
private LLamaWeights(SafeLlamaModelHandle handle, IReadOnlyDictionary<string, string> metadata)
{
NativeHandle = handle;
Metadata = metadata;

// Increment the model reference count while this weight exists.
// DangerousAddRef throws if it fails, so there is no need to check "success"
var success = false;
NativeHandle.DangerousAddRef(ref success);
}

/// <inheritdoc />
~LLamaWeights()
{
// Ensure the handle is released even if users don't explicitly call Dispose
Dispose();
}

#region Load
/// <summary>
/// Create a new instance of the model using the same NativeHandle as this model.
/// Metadata is also copied from the existing model rather than read from the handle directly.
/// The `SafeLlamaModelHandle` will not be disposed and the model will not be unloaded until <b>ALL</b> such handles have been disposed.
/// </summary>
/// <returns></returns>
public static LLamaWeights FromSafeModelHandle(SafeLlamaModelHandle handle)
public LLamaWeights CloneFromHandleWithMetadata()
{
return new LLamaWeights(handle);
var metadataClone = Metadata
.Select(x => x)
.ToDictionary(x => x.Key, x => x.Value);
return new LLamaWeights(NativeHandle, metadataClone);
}

/// <summary>
@@ -218,28 +241,14 @@
/// <inheritdoc />
public void Dispose()
{
Dispose(true);
GC.SuppressFinalize(this);
}

/// <summary>
/// Unload all models when called explicitly via dispose
/// </summary>
/// <param name="disposing">Whether or not this call is made explicitly(true) or via GC</param>
internal void Dispose(bool disposing)
{
if (_disposed)
{
return;
}

if (disposing)
if (!_disposed)
{
NativeHandle.DangerousRelease();
NativeHandle.Dispose();
_disposed = true;
}

_disposed = true;
GC.SuppressFinalize(this);
}

/// <summary>
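The net effect of the `LLamaWeights` changes is that every instance, original or clone, holds exactly one reference on the shared native handle and releases it exactly once, via `Dispose` or the finalizer as a backstop. A minimal sketch of the sharing behaviour, assuming the `CloneFromHandleWithMetadata` API from this diff (the load call is standard LLamaSharp):

```csharp
// Sketch: two LLamaWeights instances sharing one native model.
using var weights = LLamaWeights.LoadFromFile(new ModelParams("model.gguf"));

// The clone shares the native handle (plus a copy of the metadata)
// and takes its own reference count on that handle.
using var clone = weights.CloneFromHandleWithMetadata();

// Disposing either instance releases only its own reference; the native
// model is unloaded only when the last reference is gone.
```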
2 changes: 1 addition & 1 deletion LLama/Model/HuggingFaceModelRepo.cs
Contributor Author: Essentially for demo purposes; I wanted to see how abstract the interface is.

Member: If this is just a test, do we want to keep it in this PR? I think @AsakusaRinne was working on HF integrations for model loading, so you might want to check what the status of that work is and add something in a separate PR?

Contributor Author (patrick-hovsepian, Jun 17, 2024): Curious what the thoughts are on that.

Collaborator: Sorry for the late reply. I agree that it could be removed from this PR and added in a separate PR to the LLama.Experimental project. With HuggingfaceHub, it's easy to download a model from Hugging Face. It's easy to implement a remote model manager for gguf files, but the APIs might change in the future if we want to support other formats (.safetensors, .bin) based on GGMLSharp. So I would recommend putting it in LLama.Experimental first.
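To make that concrete, below is a rough sketch of what a minimal remote gguf fetcher could look like. It is purely illustrative: the class name and method shape are hypothetical (this PR does not include them), and only Hugging Face's public `resolve` URL scheme is assumed.

```csharp
using System.IO;
using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;

// Hypothetical sketch; not part of this PR.
public class HuggingFaceGgufDownloader
{
    private static readonly HttpClient Http = new();

    // Streams a single gguf file to disk via Hugging Face's public "resolve" URLs.
    // repoId is e.g. "TheBloke/Llama-2-7B-GGUF"; fileName is a gguf file in that repo.
    public async Task DownloadAsync(string repoId, string fileName, string destinationPath, CancellationToken ct = default)
    {
        var url = $"https://huggingface.co/{repoId}/resolve/main/{fileName}";
        using var response = await Http.GetAsync(url, HttpCompletionOption.ResponseHeadersRead, ct);
        response.EnsureSuccessStatusCode();

        // Stream rather than buffer: gguf model files are typically multiple GB.
        await using var source = await response.Content.ReadAsStreamAsync(ct);
        await using var destination = File.Create(destinationPath);
        await source.CopyToAsync(destination, ct);
    }
}
```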

@@ -24,7 +24,7 @@ public void AddSource(string source)
}

// TODO: call HF to check model exists
// TODO: Get metadta about model an
// TODO: Get metadata about model
_hfModelUri.Add(source);
}

2 changes: 1 addition & 1 deletion LLama/Model/IModelCache.cs
@@ -23,7 +23,7 @@ public interface IModelCache : IDisposable
/// </summary>
/// <param name="metadata"></param>
/// <param name="modelConfigurator"></param>
/// <param name="modelId">An alias to uniquely identify this model's underyling handle. If none is supplied, the model's name is used.'</param>
/// <param name="modelId">An alias to uniquely identify this model's underlying handle. If none is supplied, the model's name is used.'</param>
/// <param name="cancellationToken"></param>
/// <returns>The loaded model on success</returns>
public Task<LLamaWeights> LoadModelAsync(ModelFileMetadata metadata,
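The `modelId` alias matters when the same weights should be cached under a caller-chosen name. A hypothetical usage sketch of the interface above (it assumes `modelConfigurator` and `cancellationToken` are optional; their defaults are not visible in this diff):

```csharp
// Load once under an explicit alias, then hand out clones by that alias.
var weights = await cache.LoadModelAsync(metadata, modelId: "chat-7b");

if (cache.TryCloneLoadedModel("chat-7b", out var clone))
{
    // ... use the clone ...
    clone.Dispose(); // release the clone's reference when done
}
```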
54 changes: 32 additions & 22 deletions LLama/Model/ModelCache.cs
@@ -10,25 +10,35 @@

namespace LLama.Model;

internal class CachedModelReference
{
public LLamaWeights Model { get; init; } = null!;
public int RefCount { get; set; } = 0;
}

/// <inheritdoc />
public class ModelCache : IModelCache
{
private bool _disposed = false;

// model id/alias, to loaded model
private readonly Dictionary<string, SafeLlamaModelHandle> _loadedModelCache = [];
private readonly Dictionary<string, CachedModelReference> _loadedModelCache = [];

/// <inheritdoc />
public int ModelsCached()
public int ModelsCached()
=> _loadedModelCache.Count;

/// <inheritdoc />
public bool TryGetLoadedModel(string modelId, out LLamaWeights model)
public bool TryCloneLoadedModel(string modelId, out LLamaWeights model)
{
var isCached = _loadedModelCache.TryGetValue(modelId, out var handle);
model = isCached
? LLamaWeights.FromSafeModelHandle(handle)
: null!;
var isCached = _loadedModelCache.TryGetValue(modelId, out var cachedModel);

model = null!;
if (isCached)
{
model = cachedModel.Model.CloneFromHandleWithMetadata();
cachedModel.RefCount++;
}
return isCached;
}

@@ -40,7 +50,7 @@ public async Task<LLamaWeights> LoadModelAsync(ModelFileMetadata metadata,
{
// is the model already loaded? alias could be different but it's up to the caller to be consistent
if (!string.IsNullOrEmpty(modelId)
&& TryGetLoadedModel(modelId, out var loadedModel))
&& TryCloneLoadedModel(modelId, out var loadedModel))
{
return loadedModel;
}
@@ -58,33 +68,31 @@ public async Task<LLamaWeights> LoadModelAsync(ModelFileMetadata metadata,
{
modelId = model.ModelName;

if (TryGetLoadedModel(modelId, out loadedModel))
if (TryCloneLoadedModel(modelId, out loadedModel))
{
model.Dispose();
return loadedModel;
}
}

// Increment the model reference count while this model exists (newly created)
// DangerousAddRef throws if it fails, so there is no need to check "success"
// Do this here since we're passing this to the caller to own and it's not done as part of the normal weight creation
var refSuccess = false;
model.NativeHandle.DangerousAddRef(ref refSuccess);

_loadedModelCache.Add(modelId, model.NativeHandle);
_loadedModelCache.Add(modelId, new CachedModelReference
{
Model = model,
RefCount = 1
});
return model;
}

#region Unload
/// <inheritdoc />
public bool UnloadModel(string modelId)
{
if (_loadedModelCache.TryGetValue(modelId, out var handle))
if (_loadedModelCache.TryGetValue(modelId, out var cachedModel))
{
// Decrement refcount on model
handle.DangerousRelease();
handle.Dispose();
if (handle.IsClosed || handle.IsInvalid)
cachedModel.Model.Dispose(); // this only disposes the original model...
cachedModel.RefCount--;
if (cachedModel.RefCount == 0)
{
return _loadedModelCache.Remove(modelId);
}
@@ -98,8 +106,10 @@ public void UnloadAllModels()
{
foreach (var handle in _loadedModelCache.Values)
{
handle.DangerousRelease();
handle.Dispose();
for (var i = 0; i < handle.RefCount; i++)
{
handle.Model.Dispose();
}
}
_loadedModelCache.Clear();
}
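For reference, the reference-counting contract these changes aim for, mirroring the `LoadModel_LoadsAndCaches` test earlier in the diff (a sketch, not code from this PR):

```csharp
var model = await cache.LoadModelAsync(metadata);            // cached; RefCount == 1
cache.TryCloneLoadedModel(model.ModelName, out var clone);   // RefCount == 2

clone.Dispose();                                             // release the clone's native reference
cache.UnloadModel(model.ModelName);                          // true: RefCount -> 1
model.Dispose();                                             // release the caller-owned reference
cache.UnloadModel(model.ModelName);                          // true: RefCount -> 0, entry removed
cache.UnloadModel(model.ModelName);                          // false: nothing cached under that id
```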