Skip to content

Commit

Permalink
Refactor worker pattern (#545)
Browse files Browse the repository at this point in the history
  • Loading branch information
sebastienros authored May 3, 2024
1 parent 5f82900 commit 99b3995
Show file tree
Hide file tree
Showing 5 changed files with 148 additions and 76 deletions.
90 changes: 90 additions & 0 deletions src/YesSql.Core/Data/WorkDispatcher.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
using System;
using System.Collections.Concurrent;
using System.Threading.Tasks;

#nullable enable

namespace YesSql.Data;

internal sealed class WorkDispatcher<TKey, TValue> where TKey : notnull
{
private readonly ConcurrentDictionary<TKey, Task<TValue?>> _workers = new();

public async Task<TValue?> ScheduleAsync(TKey key, Func<TKey, Task<TValue?>> valueFactory)
{
ArgumentNullException.ThrowIfNull(key);

while (true)
{
if (_workers.TryGetValue(key, out var task))
{
return await task;
}

// This is the task that we'll return to all waiters. We'll complete it when the factory is complete
var tcs = new TaskCompletionSource<TValue?>(TaskCreationOptions.RunContinuationsAsynchronously);

if (_workers.TryAdd(key, tcs.Task))
{
try
{
var value = await valueFactory(key);
tcs.TrySetResult(value);
return await tcs.Task;
}
catch (Exception ex)
{
// Make sure all waiters see the exception
tcs.SetException(ex);

throw;
}
finally
{
// We remove the entry if the factory failed so it's not a permanent failure
// and future gets can retry (this could be a pluggable policy)
_workers.TryRemove(key, out _);
}
}
}
}

public async Task<TValue?> ScheduleAsync<TState>(TKey key, TState state, Func<TKey, TState, Task<TValue?>> valueFactory)
{
ArgumentNullException.ThrowIfNull(key);

while (true)
{
if (_workers.TryGetValue(key, out var task))
{
return await task;
}

// This is the task that we'll return to all waiters. We'll complete it when the factory is complete
var tcs = new TaskCompletionSource<TValue?>(TaskCreationOptions.RunContinuationsAsynchronously);

if (_workers.TryAdd(key, tcs.Task))
{
try
{
var value = await valueFactory(key, state);
tcs.TrySetResult(value);
return await tcs.Task;
}
catch (Exception ex)
{
// Make sure all waiters see the exception
tcs.SetException(ex);

throw;
}
finally
{
// We remove the entry if the factory failed so it's not a permanent failure
// and future gets can retry (this could be a pluggable policy)
_workers.TryRemove(key, out _);
}
}
}
}
}
66 changes: 41 additions & 25 deletions src/YesSql.Core/Data/WorkerQueryKey.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,39 +6,44 @@ namespace YesSql.Data
/// <summary>
/// An instance of <see cref="WorkerQueryKey"/> represents the state of <see cref="WorkerQueryKey"/>.
/// </summary>
public readonly struct WorkerQueryKey : IEquatable<WorkerQueryKey>
public class WorkerQueryKey : IEquatable<WorkerQueryKey>
{
private readonly string _prefix;
private readonly long _id;
private readonly long[] _ids;
private readonly Dictionary<string, object> _parameters;
private readonly int _hashcode;
private readonly int _hashCode;

public WorkerQueryKey(string prefix, long[] ids)
{
if (prefix == null)
{
throw new ArgumentNullException(nameof(prefix));
}

if (ids == null)
{
throw new ArgumentNullException(nameof(ids));
}
ArgumentNullException.ThrowIfNull(prefix);
ArgumentNullException.ThrowIfNull(ids);

_prefix = prefix;
_parameters = null;
_ids = ids;
_hashcode = 0;
_hashcode = BuildHashCode();
_hashCode = BuildHashCode();
}

public WorkerQueryKey(string prefix, long id)
{
ArgumentNullException.ThrowIfNull(prefix);

_prefix = prefix;
_parameters = null;
_id = id;
_hashCode = BuildHashCode();
}

public WorkerQueryKey(string prefix, Dictionary<string, object> parameters)
{
ArgumentNullException.ThrowIfNull(prefix);
ArgumentNullException.ThrowIfNull(parameters);

_prefix = prefix;
_parameters = parameters;
_ids = null;
_hashcode = 0;
_hashcode = BuildHashCode();
_hashCode = BuildHashCode();
}

/// <inheritdoc />
Expand Down Expand Up @@ -75,45 +80,46 @@ public bool Equals(WorkerQueryKey other)

private int BuildHashCode()
{
var combinedHash = 5381;
combinedHash = ((combinedHash << 5) + combinedHash) ^ _prefix.GetHashCode();
var hashCode = new HashCode();

hashCode.Add(_prefix);

if (_parameters != null)
{
foreach (var parameter in _parameters)
{
if (parameter.Key != null)
{
combinedHash = ((combinedHash << 5) + combinedHash) ^ parameter.Key.GetHashCode();
hashCode.Add(parameter.Key);
}

if (parameter.Value != null)
{
combinedHash = ((combinedHash << 5) + combinedHash) ^ parameter.Value.GetHashCode();
hashCode.Add(parameter.Value);
}
}

return combinedHash;
}

if (_ids != null)
{
foreach (var id in _ids)
{
combinedHash = ((combinedHash << 5) + combinedHash) ^ (int)id;
hashCode.Add(id);
}
}

return combinedHash;
if (_id != 0)
{
hashCode.Add(_id);
}

return default;
return hashCode.ToHashCode();
}

/// <inheritdoc />
public override int GetHashCode()
{
return _hashcode;
return _hashCode;
}

private static bool SameParameters(Dictionary<string, object> values1, Dictionary<string, object> values2)
Expand Down Expand Up @@ -181,5 +187,15 @@ private static bool SameIds(long[] values1, long[] values2)

return true;
}

public static bool operator ==(WorkerQueryKey left, WorkerQueryKey right)
{
return left.Equals(right);
}

public static bool operator !=(WorkerQueryKey left, WorkerQueryKey right)
{
return !(left == right);
}
}
}
10 changes: 5 additions & 5 deletions src/YesSql.Core/Services/DefaultQuery.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1110,7 +1110,7 @@ public async Task<int> CountAsync()

try
{
return await _session._store.ProduceAsync(key, static (state) =>
return await _session._store.ProduceAsync(key, static (key, state) =>
{
var logger = state.Session._store.Configuration.Logger;

Expand Down Expand Up @@ -1221,7 +1221,7 @@ protected async Task<T> FirstOrDefaultImpl()
_query._queryState._sqlBuilder.Selector("*");
var sql = _query._queryState._sqlBuilder.ToSqlString();
var key = new WorkerQueryKey(sql, _query._queryState._sqlBuilder.Parameters);
return (await _query._session._store.ProduceAsync(key, static (state) =>
return (await _query._session._store.ProduceAsync(key, static (key, state) =>
{
var logger = state.Query._session._store.Configuration.Logger;

Expand All @@ -1239,7 +1239,7 @@ protected async Task<T> FirstOrDefaultImpl()
_query._queryState._sqlBuilder.Selector(_query._queryState._documentTable, "*", _query._queryState._store.Configuration.Schema);
var sql = _query._queryState._sqlBuilder.ToSqlString();
var key = new WorkerQueryKey(sql, _query._queryState._sqlBuilder.Parameters);
var documents = await _query._session._store.ProduceAsync(key, static (state) =>
var documents = await _query._session._store.ProduceAsync(key, static (key, state) =>
{
var logger = state.Query._session._store.Configuration.Logger;

Expand Down Expand Up @@ -1326,7 +1326,7 @@ internal async Task<IEnumerable<T>> ListImpl()
var sql = sqlBuilder.ToSqlString();
var key = new WorkerQueryKey(sql, _query._queryState._sqlBuilder.Parameters);

return await _query._session._store.ProduceAsync(key, static (state) =>
return await _query._session._store.ProduceAsync(key, static (key, state) =>
{
var logger = state.Query._session._store.Configuration.Logger;

Expand Down Expand Up @@ -1356,7 +1356,7 @@ internal async Task<IEnumerable<T>> ListImpl()

var key = new WorkerQueryKey(sql, sqlBuilder.Parameters);

var documents = await _query._session._store.ProduceAsync(key, static (state) =>
var documents = await _query._session._store.ProduceAsync(key, static (key, state) =>
{
var logger = state.Query._session._store.Configuration.Logger;

Expand Down
6 changes: 3 additions & 3 deletions src/YesSql.Core/Session.cs
Original file line number Diff line number Diff line change
Expand Up @@ -406,11 +406,11 @@ private async Task<Document> GetDocumentByIdAsync(long id, string collection)
var documentTable = Store.Configuration.TableNameConvention.GetDocumentTable(collection);

var command = "select * from " + _dialect.QuoteForTableName(_tablePrefix + documentTable, Store.Configuration.Schema) + " where " + _dialect.QuoteForColumnName("Id") + " = @Id";
var key = new WorkerQueryKey(nameof(GetDocumentByIdAsync), new[] { id });
var key = new WorkerQueryKey(nameof(GetDocumentByIdAsync), id);

try
{
var result = await _store.ProduceAsync(key, (state) =>
var result = await _store.ProduceAsync(key, (key, state) =>
{
var logger = state.Store.Configuration.Logger;

Expand Down Expand Up @@ -506,7 +506,7 @@ public async Task<IEnumerable<T>> GetAsync<T>(long[] ids, string collection = nu
var key = new WorkerQueryKey(nameof(GetAsync), ids);
try
{
var documents = await _store.ProduceAsync(key, static (state) =>
var documents = await _store.ProduceAsync(key, static (key, state) =>
{
var logger = state.Store.Configuration.Logger;

Expand Down
52 changes: 9 additions & 43 deletions src/YesSql.Core/Store.cs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ public class Store : IStore

internal readonly ConcurrentDictionary<long, QueryState> CompiledQueries = new();

private readonly WorkDispatcher<WorkerQueryKey, object> _dispatcher = new();

internal const int SmallBufferSize = 128;
internal const int MediumBufferSize = 512;
internal const int LargeBufferSize = 1024;
Expand Down Expand Up @@ -289,55 +291,19 @@ public IStore RegisterScopedIndexes(IEnumerable<Type> indexProviders)
/// <param name="key">A key identifying the running work.</param>
/// <param name="work">A function containing the logic to execute.</param>
/// <returns>The result of the work.</returns>
internal async Task<T> ProduceAsync<T, TState>(WorkerQueryKey key, Func<TState, Task<T>> work, TState state)
internal Task<T> ProduceAsync<T, TState>(WorkerQueryKey key, Func<WorkerQueryKey, TState, Task<T>> work, TState state)
{
if (!Configuration.QueryGatingEnabled)
{
return await work(state);
return work(key, state);
}

object content = null;

while (content == null)
{
// Is there any query already processing the ?
if (!Workers.TryGetValue(key, out var result))
{
// Multiple threads can potentially reach this point which is fine
// c.f. https://blogs.msdn.microsoft.com/seteplia/2018/10/01/the-danger-of-taskcompletionsourcet-class/
var tcs = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);

Workers.TryAdd(key, tcs.Task);

try
{
// The current worker is processed
content = await work(state);
}
catch
{
// An exception occurred in the main worker, we broadcast the null value
content = null;
throw;
}
finally
{
// Remove the worker task before setting the result.
// If the result is null, other threads would potentially
// acquire it otherwise.
Workers.TryRemove(key, out _);
return ProduceAwaitedAsync(key, work, state);
}

// Notify all other awaiters to return the result
tcs.TrySetResult(content);
}
}
else
{
// Another worker is already running, wait for it to finish and reuse the results.
// This value can be null if the worker failed, in this case the loop will run again.
content = await result;
}
}
internal async Task<T> ProduceAwaitedAsync<T, TState>(WorkerQueryKey key, Func<WorkerQueryKey, TState, Task<T>> work, TState state)
{
var content = await _dispatcher.ScheduleAsync(key, state, async (key, state) => await work(key, state));

return (T)content;
}
Expand Down

0 comments on commit 99b3995

Please sign in to comment.