diff --git a/src/DllManipulator.ReflectionCache.cs b/src/DllManipulator.ReflectionCache.cs index a7ebdaa..8d89f2e 100644 --- a/src/DllManipulator.ReflectionCache.cs +++ b/src/DllManipulator.ReflectionCache.cs @@ -1,4 +1,5 @@ using System; +using System.Collections.Concurrent; using System.Reflection; using System.Threading; using DllManipulator.Internal; @@ -22,6 +23,9 @@ public partial class DllManipulator private static readonly Lazy WriteNativeCrashLogMethod = new Lazy( () => typeof(DllManipulator).GetMethod(nameof(WriteNativeCrashLog), BindingFlags.NonPublic | BindingFlags.Static)); + private static readonly Lazy ThreadsCallingNativesField = new Lazy( + () => typeof(DllManipulator).GetField(nameof(_threadsCallingNatives), BindingFlags.NonPublic | BindingFlags.Static)); + /// /// ReaderWriterLockSlim.EnterReadLock() /// @@ -33,5 +37,17 @@ public partial class DllManipulator /// private static readonly Lazy RwlsExitReadLockMethod = new Lazy( () => typeof(ReaderWriterLockSlim).GetMethod(nameof(ReaderWriterLockSlim.ExitReadLock), BindingFlags.Public | BindingFlags.Instance)); + + /// + /// Thread.get_CurrentThread() + /// + private static readonly Lazy Thread_getCurrentThreadMethod = new Lazy( + () => typeof(Thread).GetProperty(nameof(Thread.CurrentThread), BindingFlags.Public | BindingFlags.Static).GetGetMethod()); + + /// + /// ConcurrentDictionary.TryAdd(Thread key, int value) + /// + private static readonly Lazy ConcurrentDictionaryThreadIntTryAddMethod = new Lazy( + () => typeof(ConcurrentDictionary).GetMethod(nameof(ConcurrentDictionary.TryAdd), BindingFlags.Public | BindingFlags.Instance)); } } diff --git a/src/DllManipulator.cs b/src/DllManipulator.cs index 601dfb9..8123152 100644 --- a/src/DllManipulator.cs +++ b/src/DllManipulator.cs @@ -1,5 +1,6 @@ using System; using System.Collections.Generic; +using System.Collections.Concurrent; using System.Linq; using System.Reflection; using System.Reflection.Emit; @@ -9,6 +10,7 @@ using System.IO; using UnityEngine; using DllManipulator.Internal; +using System.Threading.Tasks; namespace DllManipulator { @@ -36,6 +38,7 @@ public partial class DllManipulator : MonoBehaviour loadingMode = DllLoadingMode.Lazy, unixDlopenFlags = UnixDlopenFlags.Lazy, threadSafe = false, + waitForThreads = true, crashLogs = false, crashLogsDir = "{assets}/", crashLogsStackTrace = false, @@ -43,6 +46,7 @@ public partial class DllManipulator : MonoBehaviour }; public static TimeSpan? InitializationTime { get; private set; } = null; + public static int AwaitedThreads = 0; //Use with synchronization private static DllManipulatorOptions _options; private static DllManipulator _singletonInstance = null; private static int _unityMainThreadId; @@ -56,6 +60,7 @@ public partial class DllManipulator : MonoBehaviour private static int _nativeFunctionsCount = 0; private static int _createdDelegateTypes = 0; private static int _lastNativeCallIndex = 0; //Use with synchronization + private static readonly ConcurrentDictionary _threadsCallingNatives = new ConcurrentDictionary(); //Used as set, walkaround to the lack of concurrent set in .NET private void OnEnable() @@ -80,14 +85,48 @@ private void OnEnable() Initialize(); } - private void OnApplicationQuit() + private async Task OnApplicationQuit() { - //FIXME: Because we don't wait for other threads to finish, we might be stealing function delegates from under their nose if Unity doesn't happen to close them yet. - //On Preloaded mode this leads to NullReferenceException, but on Lazy mode the DLL and function are just reloaded so we end up with loaded DLL after game exit. + Debug.Log("OnApplicationQuit() before join thread: " + Thread.CurrentThread.ManagedThreadId); + + if (_options.waitForThreads) + { + var threadsToAwait = _threadsCallingNatives.Keys.Where(t => t.ManagedThreadId != _unityMainThreadId).ToList(); + Interlocked.Exchange(ref AwaitedThreads, threadsToAwait.Count); + + await Task.Factory.StartNew(() => + { + while (true) + { + int aliveThreads = 0; + foreach (var t in threadsToAwait) + { + if(t.IsAlive) + { + aliveThreads++; + } + } + + Interlocked.Exchange(ref AwaitedThreads, aliveThreads); + + if (aliveThreads == 0) + { + return; + } + + Thread.Sleep(100); + } + }, TaskCreationOptions.LongRunning); + } + + Debug.Log("OnApplicationQuit() after join thread: " + Thread.CurrentThread.ManagedThreadId); UnloadAll(); + Debug.Log("OnApplicationQuit() after UnloadAll()"); ForgetAllDlls(); + Debug.Log("OnApplicationQuit() after ForgetAllDlls()"); ClearCrashLogs(); + Debug.Log("OnApplicationQuit() after ClearCrashLogs()"); } private static void Initialize() @@ -177,6 +216,25 @@ public static void UnloadAll() } } + public static void AbortAwaitedThreads() + { + var awaitingThreads = _threadsCallingNatives.Keys.Where(t => t.ManagedThreadId != _unityMainThreadId).ToList(); + foreach (var t in awaitingThreads) + { + if (t.IsAlive) + { + try + { + t.Abort(); + } + catch (ThreadAbortException) + { + + } + } + } + } + private static void ForgetAllDlls() { _dlls.Clear(); @@ -287,6 +345,16 @@ private static void GenerateNativeFunctionMockBody(ILGenerator il, int parameter il.DeclareLocal(delegateInvokeMethod.ReturnType); //Local 0: returnValue } + if(_options.waitForThreads) + { + //Store current thread so at application quit event it can be waited for + il.Emit(OpCodes.Ldsfld, ThreadsCallingNativesField.Value); + il.Emit(OpCodes.Call, Thread_getCurrentThreadMethod.Value); + il.Emit(OpCodes.Ldc_I4_0); + il.Emit(OpCodes.Call, ConcurrentDictionaryThreadIntTryAddMethod.Value); + il.Emit(OpCodes.Pop); //Not interested whether actually added element + } + il.Emit(OpCodes.Ldsfld, NativeFunctionLoadLockField.Value); il.Emit(OpCodes.Call, RwlsEnterReadLocKMethod.Value); il.BeginExceptionBlock(); @@ -701,6 +769,7 @@ public class DllManipulatorOptions public DllLoadingMode loadingMode; public UnixDlopenFlags unixDlopenFlags; public bool threadSafe; + public bool waitForThreads; public bool crashLogs; public string crashLogsDir; public bool crashLogsStackTrace; diff --git a/src/Editor/DllManipulatorEditor.cs b/src/Editor/DllManipulatorEditor.cs index 5c7bf41..c5b3817 100644 --- a/src/Editor/DllManipulatorEditor.cs +++ b/src/Editor/DllManipulatorEditor.cs @@ -2,6 +2,7 @@ using System.Linq; using UnityEngine; using UnityEditor; +using System.Threading; namespace DllManipulator { @@ -22,6 +23,9 @@ public class DllManipulatorEditor : Editor private readonly GUIContent THREAD_SAFE_GUI_CONTENT = new GUIContent("Thread safe", "Ensures synchronization required for native calls from any other than Unity main thread. Overhead might be few times higher, with uncontended locks.\n\n" + "Only in Preloaded mode."); + private readonly GUIContent WAIT_FOR_THREADS_GUI_CONTENT = new GUIContent("Wait for threads", + "Whether to wait before unloading DLLs for all threads that use native functions. Not doing so may cause some weirdness at stopping game.\n\n" + + "Quite small overhead."); private readonly GUIContent CRASH_LOGS_GUI_CONTENT = new GUIContent("Crash logs", "Logs each native call to file. In case of crash or hang caused by native function, you can than see what function was that, along with arguments and, optionally, stack trace.\n\n" + "In multi-threaded scenario there will be one file for each thread and you'll have to guess the right one (call index will be a hint).\n\n" + @@ -158,6 +162,19 @@ public override void OnInspectorGUI() var time = DllManipulator.InitializationTime.Value; EditorGUILayout.LabelField($"Initialized in: {(int)time.TotalSeconds}.{time.Milliseconds.ToString("D3")}s"); } + + var awaitedTheads = Interlocked.CompareExchange(ref DllManipulator.AwaitedThreads, 0, 0); + if(awaitedTheads != 0) + { + EditorGUILayout.Space(); + EditorGUILayout.LabelField($"Waiting for threads: {awaitedTheads}"); + if(GUILayout.Button("Abort awaited threads")) + { + DllManipulator.AbortAwaitedThreads(); + } + + EditorUtility.SetDirty(target); //Cause this editor to reprint next frame (or quite soon) + } } private void DrawOptions(DllManipulatorOptions options) @@ -186,6 +203,16 @@ private void DrawOptions(DllManipulatorOptions options) options.threadSafe = EditorGUILayout.Toggle(THREAD_SAFE_GUI_CONTENT, options.threadSafe); GUI.enabled = guiEnabledStack.Pop(); + if(options.threadSafe) + { + var prevIndent = EditorGUI.indentLevel; + + EditorGUI.indentLevel += 1; + options.waitForThreads = EditorGUILayout.Toggle(WAIT_FOR_THREADS_GUI_CONTENT, options.waitForThreads); + + EditorGUI.indentLevel = prevIndent; + } + options.crashLogs = EditorGUILayout.Toggle(CRASH_LOGS_GUI_CONTENT, options.crashLogs); if (options.crashLogs)