Skip to content

Commit

Permalink
Add an API to get the loaded native library.
Browse files Browse the repository at this point in the history
  • Loading branch information
AsakusaRinne committed May 1, 2024
1 parent 2c19b8b commit a86f14d
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 14 deletions.
20 changes: 10 additions & 10 deletions LLama/Native/Load/NativeLibraryConfig.cs
Original file line number Diff line number Diff line change
Expand Up @@ -299,8 +299,8 @@ public sealed partial class NativeLibraryConfig

static NativeLibraryConfig()
{
LLama = new(NativeLibraryName.Llama);
LLava = new(NativeLibraryName.LlavaShared);
LLama = new(NativeLibraryName.LLama);
LLava = new(NativeLibraryName.LLava);
All = new(LLama, LLava);
}

Expand Down Expand Up @@ -401,11 +401,11 @@ public NativeLibraryConfigContainer WithLibrary(string? llamaPath, string? llava
{
foreach(var config in _configs)
{
if(config.NativeLibraryName == NativeLibraryName.Llama && llamaPath is not null)
if(config.NativeLibraryName == NativeLibraryName.LLama && llamaPath is not null)
{
config.WithLibrary(llamaPath);
}
if(config.NativeLibraryName == NativeLibraryName.LlavaShared && llavaPath is not null)
if(config.NativeLibraryName == NativeLibraryName.LLava && llavaPath is not null)
{
config.WithLibrary(llavaPath);
}
Expand Down Expand Up @@ -567,11 +567,11 @@ public bool DryRun(out INativeLibrary? loadedLLamaNativeLibrary, out INativeLibr
foreach(var config in _configs)
{
success &= config.DryRun(out var loadedLibrary);
if(config.NativeLibraryName == NativeLibraryName.Llama)
if(config.NativeLibraryName == NativeLibraryName.LLama)
{
loadedLLamaNativeLibrary = loadedLibrary;
}
else if(config.NativeLibraryName == NativeLibraryName.LlavaShared)
else if(config.NativeLibraryName == NativeLibraryName.LLava)
{
loadedLLavaNativeLibrary = loadedLibrary;
}
Expand All @@ -593,11 +593,11 @@ public enum NativeLibraryName
/// <summary>
/// The native library compiled from llama.cpp.
/// </summary>
Llama,
LLama,
/// <summary>
/// The native library compiled from the LLaVA example of llama.cpp.
/// </summary>
LlavaShared
LLava
}

internal static class LibraryNameExtensions
Expand All @@ -606,9 +606,9 @@ public static string GetLibraryName(this NativeLibraryName name)
{
switch (name)
{
case NativeLibraryName.Llama:
case NativeLibraryName.LLama:
return NativeApi.libraryName;
case NativeLibraryName.LlavaShared:
case NativeLibraryName.LLava:
return NativeApi.llavaLibraryName;
default:
throw new ArgumentOutOfRangeException(nameof(name), name, null);
Expand Down
27 changes: 23 additions & 4 deletions LLama/Native/NativeApi.Load.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System.Runtime.InteropServices;
using System.Text.Json;
using System.Collections.Generic;
using LLama.Abstractions;

namespace LLama.Native
{
Expand Down Expand Up @@ -65,7 +66,7 @@ private static void SetDllImportResolver()
return _loadedLlamaHandle;

// Try to load a preferred library, based on CPU feature detection
_loadedLlamaHandle = NativeLibraryUtils.TryLoadLibrary(NativeLibraryConfig.LLama, out var _);
_loadedLlamaHandle = NativeLibraryUtils.TryLoadLibrary(NativeLibraryConfig.LLama, out _loadedLLamaLibrary);
return _loadedLlamaHandle;
}

Expand All @@ -76,7 +77,7 @@ private static void SetDllImportResolver()
return _loadedLlavaSharedHandle;

// Try to load a preferred library, based on CPU feature detection
_loadedLlavaSharedHandle = NativeLibraryUtils.TryLoadLibrary(NativeLibraryConfig.LLava, out var _);
_loadedLlavaSharedHandle = NativeLibraryUtils.TryLoadLibrary(NativeLibraryConfig.LLava, out _loadedLLavaLibrary);
return _loadedLlavaSharedHandle;
}

Expand All @@ -86,8 +87,26 @@ private static void SetDllImportResolver()
#endif
}

/// <summary>
/// Get the loaded native library. If you are using netstandard2.0, it will always return null.
/// </summary>
/// <param name="name"></param>
/// <returns></returns>
/// <exception cref="ArgumentException"></exception>
public static INativeLibrary? GetLoadedNativeLibrary(NativeLibraryName name)
{
return name switch
{
NativeLibraryName.LLama => _loadedLLamaLibrary,
NativeLibraryName.LLava => _loadedLLavaLibrary,
_ => throw new ArgumentException($"Library name {name} is not found.")
};
}

internal const string libraryName = "llama";
internal const string llavaLibraryName = "llava_shared";
private const string cudaVersionFile = "version.json";
internal const string llavaLibraryName = "llava_shared";

private static INativeLibrary? _loadedLLamaLibrary = null;
private static INativeLibrary? _loadedLLavaLibrary = null;
}
}

0 comments on commit a86f14d

Please sign in to comment.