Created March 31, 2022 05:23
-
-
Save nathan815/426c0d5dabc467f6a0a12dc900218df0 to your computer and use it in GitHub Desktop.
Kotlin state machine
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package me.nathanjohnson | |
import android.util.Log | |
import kotlin.reflect.KClass | |
/**
 * One edge of a [StateMachine].
 *
 * When an event whose runtime class equals [event] is triggered while the machine is in state
 * [from], the machine may move to state [to].
 *
 * @property event Class of the event that fires this transition. It is compared for equality
 *   against `event::class` of the triggered event, so subclasses do not match.
 * @property from State the machine must currently be in for this transition to apply.
 * @property to State the machine moves to when this transition is taken.
 * @property guard Optional check evaluated before the transition. Returning `false` prevents
 *   this transition from being taken.
 * @property action Optional side effect run when this transition is taken. A non-null return
 *   value replaces the resource; `null` keeps the resource that was passed in.
 */
data class StateTransition<StateT, ResourceT>(
    val event: KClass<*>,
    val from: StateT,
    val to: StateT,
    val guard: ((event: StateMachine.Event, resource: ResourceT) -> Boolean)? = null,
    val action: ((event: StateMachine.Event, resource: ResourceT) -> ResourceT?)? = null,
)
/**
 * Represents a finite state machine (FSM).
 *
 * The only intended way to move from one state to another is by triggering an event,
 * which in turn causes a transition from the current state to another.
 *
 * @param StateT Type of state being maintained
 * @param ResourceT Type of resource whose state is being maintained
 * @param initialState State the machine starts in
 * @param setup Configuration block invoked once during construction with the machine as its
 *   receiver, e.g. to call [addTransitions] and [onTransition]
 */
class StateMachine<StateT, ResourceT>(
    initialState: StateT,
    setup: (StateMachine<StateT, ResourceT>.() -> Unit)
) {
    /** Base type for all events accepted by [trigger]. */
    open class Event

    /** Outcome of a call to [trigger]. */
    enum class TriggerResult {
        /** No registered transition matches the event in the current state. */
        NoTransition,

        /** At least one transition matched, but every matching transition's guard returned false. */
        StoppedByGuard,

        /** A transition was taken and [currentState] was updated. */
        StateUpdated,
    }

    /**
     * Result of [trigger].
     *
     * [resource] is null unless [result] is [TriggerResult.StateUpdated].
     */
    data class TriggerOutput<ResourceT>(val result: TriggerResult, val resource: ResourceT? = null)

    // Registration order matters: trigger() considers matching transitions in this order.
    private var transitions: List<StateTransition<StateT, ResourceT>> = listOf()

    private var onTransitionCallback: (
        transition: StateTransition<StateT, ResourceT>,
        resource: ResourceT
    ) -> ResourceT = { _, res -> res }

    // NOTE(review): the setter is public, so callers can bypass transitions entirely.
    var currentState: StateT = initialState

    init {
        this.setup()
    }

    /**
     * Function passed to this is called whenever a transition occurs.
     * Should return resource object with any needed modifications.
     *
     * Replaces any previously registered callback. It receives the resource returned by the
     * transition's action, or the resource passed to [trigger] when there is no action or the
     * action returned null.
     */
    fun onTransition(
        func: (transition: StateTransition<StateT, ResourceT>, resource: ResourceT) -> ResourceT = { _, res -> res }
    ) {
        onTransitionCallback = func
    }

    /** Appends [ts] to the registered transitions, after any already registered. */
    fun addTransitions(vararg ts: StateTransition<StateT, ResourceT>) {
        transitions = transitions + ts.toList()
    }

    /**
     * Trigger an event in the state machine.
     *
     * Candidate transitions are those registered for exactly `event::class` (subclasses do not
     * match) whose [StateTransition.from] equals [currentState]. Candidates are checked in
     * registration order. The first one whose guard is absent or returns true is taken, so
     * several guarded transitions may share the same event and source state. If the guard of
     * every candidate returns false, the transition is stopped.
     *
     * When a transition is taken, [currentState] is updated first. The transition's action runs
     * next, and finally the [onTransition] callback produces the returned resource.
     *
     * @param event An event to trigger
     * @param resource Resource passed to guards, to the action, and to the [onTransition] callback
     * @return [TriggerResult.NoTransition] or [TriggerResult.StoppedByGuard] with a null resource,
     *   or [TriggerResult.StateUpdated] with the resource returned by the [onTransition] callback
     */
    fun trigger(event: Event, resource: ResourceT): TriggerOutput<ResourceT> {
        Log.i(TAG, "Trigger event $event - Current state is $currentState")
        val candidates = transitions.filter { it.event == event::class && it.from == currentState }
        if (candidates.isEmpty()) {
            Log.w(TAG, "No transition exists for $event in current state $currentState")
            return TriggerOutput(TriggerResult.NoTransition)
        }
        // Fall through to the next candidate when a guard rejects, so every transition sharing
        // this event and source state is reachable (not only the first one registered).
        val transition = candidates.firstOrNull { candidate ->
            val allowed = candidate.guard?.invoke(event, resource) != false
            if (!allowed) {
                Log.i(
                    TAG,
                    "Transition stopped by guard. Transition: $candidate, Event: $event, Resource: $resource"
                )
            }
            allowed
        } ?: return TriggerOutput(TriggerResult.StoppedByGuard)

        currentState = transition.to
        val resourceFromAction = transition.action?.invoke(event, resource)
        // A null action result means "keep the incoming resource".
        val finalResource = onTransitionCallback(transition, resourceFromAction ?: resource)
        return TriggerOutput(
            result = TriggerResult.StateUpdated,
            resource = finalResource
        )
    }

    companion object {
        private val TAG = StateMachine::class.java.simpleName
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package com.justlight.sunflower.util | |
import org.junit.Assert.assertEquals | |
import org.junit.Assert.assertNull | |
import org.spekframework.spek2.Spek | |
import org.spekframework.spek2.style.specification.describe | |
/** States used by the state machine tests. */
enum class TestState {
    State1,
    State2,
    State3,
    State4,
    ;

    companion object {
        /** State every test machine starts in. */
        val INITIAL: TestState = State1
    }
}
/**
 * Resource whose state is tracked by the test machines.
 *
 * @property state State recorded on the resource; the tests keep it in sync through `onTransition`.
 * @property number Arbitrary payload that actions may overwrite; defaults to 0.
 */
data class TestResource(
    val state: TestState,
    val number: Int = 0,
)
/**
 * Events understood by the state machine under test.
 *
 * Each event is its own type, so an event can carry data (see [EventWithNumber]). Guards and
 * actions can read that data because the triggered event is passed to them. Payload-free events
 * are modelled as `object`s. `StateMachine.trigger` accepts a `StateMachine.Event`, which is a
 * class, so enum constants could not be passed to it directly.
 */
object TestEvents {
    /** Common supertype for every event in these tests. */
    sealed class BaseEvent : StateMachine.Event()

    /** Event without a payload. */
    object BasicEvent : BaseEvent()

    /** A second payload-free event, distinct from [BasicEvent]. */
    object AnotherBasicEvent : BaseEvent()

    /** Event carrying an integer that guards and actions can inspect. */
    class EventWithNumber(val someNumber: Int) : BaseEvent()
}
/**
 * Specification for `StateMachine.trigger`.
 *
 * Every `it` block builds its own state machine. No machine state leaks between tests, so they
 * pass regardless of execution order or when run in isolation.
 */
class StateMachineTest : Spek({
    describe("trigger") {
        describe("no matching transition") {
            it("does nothing") {
                val resource = TestResource(state = TestState.INITIAL)
                val stateMachine = StateMachine<TestState, TestResource>(TestState.INITIAL) {
                    onTransition { transition, res -> res.copy(state = transition.to) }
                    addTransitions(
                        // Only leaves State2, but the machine starts in State1.
                        StateTransition(
                            event = TestEvents.BasicEvent::class,
                            from = TestState.State2,
                            to = TestState.State3,
                        ),
                    )
                }
                val output = stateMachine.trigger(TestEvents.BasicEvent, resource)
                assertEquals(TestState.INITIAL, stateMachine.currentState)
                assertEquals(StateMachine.TriggerResult.NoTransition, output.result)
                assertNull(output.resource)
            }
        }
        describe("transition with no action or guard") {
            it("changes state") {
                val resource = TestResource(state = TestState.INITIAL)
                val stateMachine = StateMachine<TestState, TestResource>(TestState.INITIAL) {
                    onTransition { transition, res -> res.copy(state = transition.to) }
                    addTransitions(
                        StateTransition(
                            event = TestEvents.BasicEvent::class,
                            from = TestState.State1,
                            to = TestState.State2,
                        ),
                    )
                }
                val output = stateMachine.trigger(TestEvents.BasicEvent, resource)
                assertEquals(StateMachine.TriggerResult.StateUpdated, output.result)
                assertEquals(TestState.State2, stateMachine.currentState)
                assertEquals(resource.copy(state = TestState.State2), output.resource)
            }
        }
        describe("transition with action") {
            it("changes state, executes action, and returns the output in TriggerOutput object") {
                val resource = TestResource(state = TestState.INITIAL, number = 0)
                val stateMachine = StateMachine<TestState, TestResource>(TestState.INITIAL) {
                    onTransition { transition, res -> res.copy(state = transition.to) }
                    addTransitions(
                        StateTransition(
                            event = TestEvents.BasicEvent::class,
                            from = TestState.State1,
                            to = TestState.State2,
                        ),
                        StateTransition(
                            event = TestEvents.EventWithNumber::class,
                            from = TestState.State2,
                            to = TestState.State3,
                            action = { e: StateMachine.Event, r: TestResource ->
                                r.copy(number = (e as TestEvents.EventWithNumber).someNumber)
                            }
                        ),
                        StateTransition(
                            event = TestEvents.AnotherBasicEvent::class,
                            from = TestState.State3,
                            to = TestState.State4,
                            action = { _, _ -> null }
                        ),
                        StateTransition(
                            event = TestEvents.BasicEvent::class,
                            from = TestState.State4,
                            to = TestState.State1
                        )
                    )
                }

                // 1 to 2
                val output0 = stateMachine.trigger(TestEvents.BasicEvent, resource)
                assertEquals(StateMachine.TriggerResult.StateUpdated, output0.result)
                assertEquals(TestState.State2, stateMachine.currentState)
                assertEquals(resource.copy(state = TestState.State2), output0.resource)

                // 2 to 3: the action copies the event's number into the resource
                val output1 =
                    stateMachine.trigger(TestEvents.EventWithNumber(someNumber = 4815), resource)
                assertEquals(StateMachine.TriggerResult.StateUpdated, output1.result)
                assertEquals(TestState.State3, stateMachine.currentState)
                assertEquals(
                    "action should run and return modified resource",
                    resource.copy(state = TestState.State3, number = 4815),
                    output1.resource
                )

                // 3 to 4: the action returns null, so onTransition receives the incoming resource
                val output2 = stateMachine.trigger(TestEvents.AnotherBasicEvent, resource)
                assertEquals(StateMachine.TriggerResult.StateUpdated, output2.result)
                assertEquals(TestState.State4, stateMachine.currentState)
                assertEquals(
                    resource.copy(state = TestState.State4),
                    output2.resource
                )

                // 4 to 1
                val output3 = stateMachine.trigger(TestEvents.BasicEvent, resource)
                assertEquals(StateMachine.TriggerResult.StateUpdated, output3.result)
                assertEquals(TestState.State1, stateMachine.currentState)
                assertEquals(
                    resource.copy(state = TestState.State1),
                    output3.resource
                )
            }
        }
        describe("transition with a guard and action") {
            // Builds a fresh machine per test, so the tests do not depend on each other's order.
            // onAction is invoked whenever the guarded State2 -> State3 action runs, which lets
            // tests observe whether the action executed.
            fun createStateMachine(onAction: () -> Unit = {}) =
                StateMachine<TestState, TestResource>(TestState.INITIAL) {
                    onTransition { transition, res -> res.copy(state = transition.to) }
                    addTransitions(
                        StateTransition(
                            event = TestEvents.BasicEvent::class,
                            from = TestState.State1,
                            to = TestState.State2,
                        ),
                        StateTransition(
                            event = TestEvents.EventWithNumber::class,
                            from = TestState.State2,
                            to = TestState.State3,
                            action = { e: StateMachine.Event, r: TestResource ->
                                onAction()
                                r.copy(number = (e as TestEvents.EventWithNumber).someNumber)
                            },
                            guard = { e: StateMachine.Event, _: TestResource ->
                                (e as TestEvents.EventWithNumber).someNumber > 1
                            },
                        ),
                        StateTransition(
                            event = TestEvents.AnotherBasicEvent::class,
                            from = TestState.State3,
                            to = TestState.State4,
                            action = { _, _ -> null }
                        ),
                        StateTransition(
                            event = TestEvents.BasicEvent::class,
                            from = TestState.State4,
                            to = TestState.State1
                        )
                    )
                }

            describe("guard returns false") {
                it("stops transition and action is not executed") {
                    var actionRuns = 0
                    val stateMachine = createStateMachine { actionRuns++ }
                    val resource = TestResource(state = TestState.INITIAL)
                    stateMachine.trigger(TestEvents.BasicEvent, resource)
                    assertEquals(TestState.State2, stateMachine.currentState)
                    val output1 =
                        stateMachine.trigger(TestEvents.EventWithNumber(someNumber = -1), resource)
                    assertEquals(StateMachine.TriggerResult.StoppedByGuard, output1.result)
                    assertEquals(TestState.State2, stateMachine.currentState)
                    assertNull(output1.resource)
                    assertEquals(0, actionRuns)
                }
            }
            describe("guard returns true") {
                it("changes state and executes action") {
                    var actionRuns = 0
                    val stateMachine = createStateMachine { actionRuns++ }
                    val resource = TestResource(state = TestState.INITIAL)
                    stateMachine.trigger(TestEvents.BasicEvent, resource)
                    assertEquals(TestState.State2, stateMachine.currentState)
                    val output1 =
                        stateMachine.trigger(TestEvents.EventWithNumber(someNumber = 5), resource)
                    assertEquals(StateMachine.TriggerResult.StateUpdated, output1.result)
                    assertEquals(TestState.State3, stateMachine.currentState)
                    assertEquals(
                        "action should run and return modified resource",
                        resource.copy(state = TestState.State3, number = 5),
                        output1.resource
                    )
                    assertEquals(1, actionRuns)
                }
            }
        }
    }
})
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.