Skip to content

Commit

Permalink
fix: race condition between renders, FindComponents, and WaitForHelper
Browse files Browse the repository at this point in the history
fixes bUnit-dev#577.

The problem is that FindComponents traverses down the render tree when invoked, and this ensures that
no renders happens while it does so, without using locks like previous, which could result in deadlocks.

fix: aways wrap FindComponentsInternal in Dispatcher

fix: optimize wait for logging

fix: ensure failure tasks in WaitForHelper run on Renderer schedular
  • Loading branch information
egil committed May 21, 2022
1 parent 6527ad4 commit bf35912
Show file tree
Hide file tree
Showing 9 changed files with 196 additions and 162 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@ All notable changes to **bUnit** will be documented in this file. The project ad

## [Unreleased]

### Fixes

- A race condition existed between `WaitForState` / `WaitForAssertion` and `FindComponents`, if the first used the latter. Reported by [@rmihael](https://github.com/rmihael), [@SviatoslavK](https://github.com/SviatoslavK), and [@RaphaelMarcouxCTRL](https://github.com/RaphaelMarcouxCTRL). Fixed by [@egil](https://github.com/egil) and [@linkdotnet](https://github.com/linkdotnet).

## [1.8.15] - 2022-05-19

### Added
Expand Down
209 changes: 109 additions & 100 deletions src/bunit.core/Extensions/WaitForHelpers/WaitForHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@ namespace Bunit.Extensions.WaitForHelpers;
/// </summary>
public abstract class WaitForHelper<T> : IDisposable
{
private readonly object lockObject = new();
private readonly Timer timer;
private readonly TaskCompletionSource<T> checkPassedCompletionSource;
private readonly Func<(bool CheckPassed, T Content)> completeChecker;
private readonly IRenderedFragmentBase renderedFragment;
Expand Down Expand Up @@ -40,143 +38,154 @@ public abstract class WaitForHelper<T> : IDisposable
/// </summary>
public Task<T> WaitTask { get; }


/// <summary>
/// Initializes a new instance of the <see cref="WaitForHelper{T}"/> class.
/// </summary>
[System.Diagnostics.CodeAnalysis.SuppressMessage("Performance", "CA1849:Call async methods when in an async method", Justification = "Using x.Result inside a ContinueWith is safe.")]
protected WaitForHelper(IRenderedFragmentBase renderedFragment, Func<(bool CheckPassed, T Content)> completeChecker, TimeSpan? timeout = null)
protected WaitForHelper(
IRenderedFragmentBase renderedFragment,
Func<(bool CheckPassed, T Content)> completeChecker,
TimeSpan? timeout = null)
{
this.renderedFragment = renderedFragment ?? throw new ArgumentNullException(nameof(renderedFragment));
this.completeChecker = completeChecker ?? throw new ArgumentNullException(nameof(completeChecker));

logger = renderedFragment.Services.CreateLogger<WaitForHelper<T>>();
checkPassedCompletionSource = new TaskCompletionSource<T>();
WaitTask = CreateWaitTask(renderedFragment, timeout);

var renderer = renderedFragment.Services.GetRequiredService<ITestRenderer>();
var renderException = renderer
.UnhandledException
.ContinueWith(x => Task.FromException<T>(x.Result), CancellationToken.None, TaskContinuationOptions.OnlyOnRanToCompletion, TaskScheduler.Current)
.Unwrap();
InitializeWaiting();
}

checkPassedCompletionSource = new TaskCompletionSource<T>();
WaitTask = Task.WhenAny(checkPassedCompletionSource.Task, renderException).Unwrap();
/// <summary>
/// Disposes the wait helper and cancels the any ongoing waiting, if it is not
/// already in one of the other completed states.
/// </summary>
public void Dispose()
{
Dispose(disposing: true);
GC.SuppressFinalize(this);
}

timer = new Timer(OnTimeout, this, Timeout.InfiniteTimeSpan, Timeout.InfiniteTimeSpan);
/// <summary>
/// Disposes of the wait task and related logic.
/// </summary>
/// <remarks>
/// The disposing parameter should be false when called from a finalizer, and true when called from the
/// <see cref="Dispose()"/> method. In other words, it is true when deterministically called and false when non-deterministically called.
/// </remarks>
/// <param name="disposing">Set to true if called from <see cref="Dispose()"/>, false if called from a finalizer.f.</param>
protected virtual void Dispose(bool disposing)
{
if (isDisposed || !disposing)
return;

isDisposed = true;
checkPassedCompletionSource.TrySetCanceled();
renderedFragment.OnAfterRender -= OnAfterRender;
logger.LogWaiterDisposed(renderedFragment.ComponentId);
}

private void InitializeWaiting()
{
if (!WaitTask.IsCompleted)
{
var renderCountAtSubscribeTime = renderedFragment.RenderCount;

// Before subscribing to renderedFragment.OnAfterRender,
// we need to make sure that the desired state has not already been reached.
OnAfterRender(this, EventArgs.Empty);
this.renderedFragment.OnAfterRender += OnAfterRender;
OnAfterRender(this, EventArgs.Empty);
StartTimer(timeout);

SubscribeToOnAfterRender();

// If the render count from before subscribing has changes
// till now, we need to do trigger another check, since
// the render may have happened asynchronously and before
// the subscription was set up.
if (renderCountAtSubscribeTime < renderedFragment.RenderCount)
{
OnAfterRender(this, EventArgs.Empty);
}
}
}

private void StartTimer(TimeSpan? timeout)
private Task<T> CreateWaitTask(IRenderedFragmentBase renderedFragment, TimeSpan? timeout)
{
if (isDisposed)
return;
var renderer = renderedFragment.Services.GetRequiredService<ITestRenderer>();

lock (lockObject)
// Two to failure conditions, that the renderer captures an unhandled
// exception from a component or itself, or that the timeout is reached,
// are executed on the renderes schedular, to ensure that OnAfterRender
// and the continuations does not happen at the same time.
var failureTask = renderer.Dispatcher.InvokeAsync(() =>
{
if (isDisposed)
return;

timer.Change(GetRuntimeTimeout(timeout), Timeout.InfiniteTimeSpan);
}
var taskScheduler = TaskScheduler.FromCurrentSynchronizationContext();

var renderException = renderer
.UnhandledException
.ContinueWith(
x => Task.FromException<T>(x.Result),
CancellationToken.None,
TaskContinuationOptions.OnlyOnRanToCompletion | TaskContinuationOptions.ExecuteSynchronously,
taskScheduler);

var timeoutTask = Task.Delay(GetRuntimeTimeout(timeout))
.ContinueWith(
x =>
{
logger.LogWaiterTimedOut(renderedFragment.ComponentId);
return Task.FromException<T>(new WaitForFailedException(TimeoutErrorMessage, capturedException));
},
CancellationToken.None,
TaskContinuationOptions.OnlyOnRanToCompletion | TaskContinuationOptions.ExecuteSynchronously,
taskScheduler);

return Task.WhenAny(renderException, timeoutTask).Unwrap();
}).Unwrap();

return Task.WhenAny(failureTask, checkPassedCompletionSource.Task).Unwrap();
}

private void OnAfterRender(object? sender, EventArgs args)
{
if (isDisposed)
if (isDisposed || WaitTask.IsCompleted)
return;

lock (lockObject)
try
{
if (isDisposed)
return;
logger.LogCheckingWaitCondition(renderedFragment.ComponentId);

try
var checkResult = completeChecker();
if (checkResult.CheckPassed)
{
logger.LogCheckingWaitCondition(renderedFragment.ComponentId);

var checkResult = completeChecker();
if (checkResult.CheckPassed)
{
checkPassedCompletionSource.TrySetResult(checkResult.Content);
logger.LogCheckCompleted(renderedFragment.ComponentId);
Dispose();
}
else
{
logger.LogCheckFailed(renderedFragment.ComponentId);
}
checkPassedCompletionSource.TrySetResult(checkResult.Content);
logger.LogCheckCompleted(renderedFragment.ComponentId);
Dispose();
}
catch (Exception ex)
else
{
capturedException = ex;
logger.LogCheckThrow(renderedFragment.ComponentId, ex);

if (StopWaitingOnCheckException)
{
checkPassedCompletionSource.TrySetException(new WaitForFailedException(CheckThrowErrorMessage, capturedException));
Dispose();
}
logger.LogCheckFailed(renderedFragment.ComponentId);
}
}
}

private void OnTimeout(object? state)
{
if (isDisposed)
return;

lock (lockObject)
catch (Exception ex)
{
if (isDisposed)
return;

logger.LogWaiterTimedOut(renderedFragment.ComponentId);
capturedException = ex;
logger.LogCheckThrow(renderedFragment.ComponentId, ex);

checkPassedCompletionSource.TrySetException(new WaitForFailedException(TimeoutErrorMessage, capturedException));

Dispose();
if (StopWaitingOnCheckException)
{
checkPassedCompletionSource.TrySetException(new WaitForFailedException(CheckThrowErrorMessage, capturedException));
Dispose();
}
}
}

/// <summary>
/// Disposes the wait helper and cancels the any ongoing waiting, if it is not
/// already in one of the other completed states.
/// </summary>
public void Dispose()
{
Dispose(disposing: true);
GC.SuppressFinalize(this);
}

/// <summary>
/// Disposes of the wait task and related logic.
/// </summary>
/// <remarks>
/// The disposing parameter should be false when called from a finalizer, and true when called from the
/// <see cref="Dispose()"/> method. In other words, it is true when deterministically called and false when non-deterministically called.
/// </remarks>
/// <param name="disposing">Set to true if called from <see cref="Dispose()"/>, false if called from a finalizer.f.</param>
protected virtual void Dispose(bool disposing)
private void SubscribeToOnAfterRender()
{
if (isDisposed || !disposing)
return;

lock (lockObject)
{
if (isDisposed)
return;

isDisposed = true;
renderedFragment.OnAfterRender -= OnAfterRender;
timer.Dispose();
checkPassedCompletionSource.TrySetCanceled();
logger.LogWaiterDisposed(renderedFragment.ComponentId);
}
// There might not be a need to subscribe if the WaitTask has already
// been completed, perhaps due to an unhandled exception from the
// renderer or from the initial check by the checker.
if (!isDisposed && !WaitTask.IsCompleted)
renderedFragment.OnAfterRender += OnAfterRender;
}

private static TimeSpan GetRuntimeTimeout(TimeSpan? timeout)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,20 +23,50 @@ private static readonly Action<ILogger, int, Exception> CheckThrow
= LoggerMessage.Define<int>(LogLevel.Debug, new EventId(20, "OnTimeout"), "The waiter for component {Id} disposed.");

internal static void LogCheckingWaitCondition<T>(this ILogger<WaitForHelper<T>> logger, int componentId)
=> CheckingWaitCondition(logger, componentId, null);
{
if (logger.IsEnabled(LogLevel.Debug))
{
CheckingWaitCondition(logger, componentId, null);
}
}

internal static void LogCheckCompleted<T>(this ILogger<WaitForHelper<T>> logger, int componentId)
=> CheckCompleted(logger, componentId, null);
{
if (logger.IsEnabled(LogLevel.Debug))
{
CheckCompleted(logger, componentId, null);
}
}

internal static void LogCheckFailed<T>(this ILogger<WaitForHelper<T>> logger, int componentId)
=> CheckFailed(logger, componentId, null);
{
if (logger.IsEnabled(LogLevel.Debug))
{
CheckFailed(logger, componentId, null);
}
}

internal static void LogCheckThrow<T>(this ILogger<WaitForHelper<T>> logger, int componentId, Exception exception)
=> CheckThrow(logger, componentId, exception);
{
if (logger.IsEnabled(LogLevel.Debug))
{
CheckThrow(logger, componentId, exception);
}
}

internal static void LogWaiterTimedOut<T>(this ILogger<WaitForHelper<T>> logger, int componentId)
=> WaiterTimedOut(logger, componentId, null);
{
if (logger.IsEnabled(LogLevel.Debug))
{
WaiterTimedOut(logger, componentId, null);
}
}

internal static void LogWaiterDisposed<T>(this ILogger<WaitForHelper<T>> logger, int componentId)
=> WaiterDisposed(logger, componentId, null);
{
if (logger.IsEnabled(LogLevel.Debug))
{
WaiterDisposed(logger, componentId, null);
}
}
}
Loading

0 comments on commit bf35912

Please sign in to comment.