From 596d244beb56426d0d98415bd452f0321e233044 Mon Sep 17 00:00:00 2001 From: Chris Patterson Date: Wed, 3 Apr 2024 15:17:13 -0500 Subject: [PATCH] Fixed #5015 - States and Events from state machine base classes are now properly initialized --- .../MassTransitStateMachine.cs | 20 ++- .../SagaStateMachineTests/BaseClass_Specs.cs | 123 ++++++++++++++++++ 2 files changed, 141 insertions(+), 2 deletions(-) create mode 100644 tests/MassTransit.Tests/SagaStateMachineTests/BaseClass_Specs.cs diff --git a/src/MassTransit/SagaStateMachine/MassTransitStateMachine.cs b/src/MassTransit/SagaStateMachine/MassTransitStateMachine.cs index 8b175afad7e..c5d97e548fc 100644 --- a/src/MassTransit/SagaStateMachine/MassTransitStateMachine.cs +++ b/src/MassTransit/SagaStateMachine/MassTransitStateMachine.cs @@ -1978,8 +1978,7 @@ IEnumerable GetStateMachineProperties() bool TryGetBackingField(PropertyInfo property, out FieldInfo backingField) { - _backingFields ??= GetType() - .GetFields(BindingFlags.NonPublic | BindingFlags.Instance) + _backingFields ??= GetBackingFields(GetType()) .Where(field => field.Attributes.HasFlag(FieldAttributes.Private) && field.Attributes.HasFlag(FieldAttributes.InitOnly) && @@ -1997,6 +1996,23 @@ bool TryGetBackingField(PropertyInfo property, out FieldInfo backingField) return backingField != null; } + static IEnumerable GetBackingFields(Type type) + { + while (true) + { + foreach (var fieldInfo in type.GetFields(BindingFlags.NonPublic | BindingFlags.Instance)) + yield return fieldInfo; + + if (type.BaseType == null) + break; + + if (type.BaseType.IsGenericType && type.BaseType.GetGenericTypeDefinition() == typeof(MassTransitStateMachine<>)) + break; + + type = type.BaseType; + } + } + void InitializeState(MassTransitStateMachine stateMachine, PropertyInfo property, StateMachineState state) { if (property.CanWrite) diff --git a/tests/MassTransit.Tests/SagaStateMachineTests/BaseClass_Specs.cs b/tests/MassTransit.Tests/SagaStateMachineTests/BaseClass_Specs.cs new file mode 100644 index 00000000000..fcd10695a7e --- /dev/null +++ b/tests/MassTransit.Tests/SagaStateMachineTests/BaseClass_Specs.cs @@ -0,0 +1,123 @@ +namespace MassTransit.Tests.SagaStateMachineTests +{ + using System.Threading.Tasks; + using BaseStateMachineTestSubjects; + using MassTransit.Testing; + using Microsoft.Extensions.DependencyInjection; + using NUnit.Framework; + + + [TestFixture] + public class Using_a_base_state_machine + { + [Test] + public async Task Should_initialize_all_states_and_events() + { + await using var provider = new ServiceCollection() + .AddMassTransitTestHarness(x => + { + x.AddSagaStateMachine(); + }) + .BuildServiceProvider(true); + + var harness = await provider.StartTestHarness(); + + var id = NewId.NextGuid(); + + await harness.Bus.Publish(new HappyEvent(id)); + + Assert.That(await harness.Consumed.Any()); + + await harness.Bus.Publish(new EndItAllEvent(id)); + + Assert.That(await harness.Consumed.Any()); + } + } + + + namespace BaseStateMachineTestSubjects + { + using System; + + + public class HappyEvent + { + public HappyEvent(Guid correlationId) + { + CorrelationId = correlationId; + } + + public Guid CorrelationId { get; set; } + } + + + public class GoLuckyEvent + { + public GoLuckyEvent(Guid correlationId) + { + CorrelationId = correlationId; + } + + public Guid CorrelationId { get; set; } + } + + + public class EndItAllEvent + { + public EndItAllEvent(Guid correlationId) + { + CorrelationId = correlationId; + } + + public Guid CorrelationId { get; set; } + } + + + public class CommonStateMachine : + MassTransitStateMachine + where T : class, SagaStateMachineInstance + { + // + // ReSharper disable UnassignedGetOnlyAutoProperty + public State Happy { get; } + public State GoLucky { get; } + + public Event OnHappy { get; } + public Event OnGoLucky { get; } + } + + + public class HappyGoLuckyState : + SagaStateMachineInstance + { + public string CurrentState { get; set; } + public Guid CorrelationId { get; set; } + } + + + public class HappyGoLuckyStateMachine : + CommonStateMachine + { + public HappyGoLuckyStateMachine() + { + InstanceState(x => x.CurrentState); + + Initially( + When(OnHappy) + .TransitionTo(Happy), + When(OnGoLucky) + .TransitionTo(GoLucky)); + + During(Happy, GoLucky, + When(OnEndItAll) + .TransitionTo(Finished)); + + SetCompletedWhenFinalized(); + } + + public State Finished { get; } + + public Event OnEndItAll { get; } + } + } +}