# -*- coding: utf-8 -*-
"""
statemachine
~~~~~~~~~~~~
Module contains a statemachine to evaluate and visualize statemachine
transitions.
:copyright: (c) 2022 by Jochen Gerhaeusser.
:license: BSD, see LICENSE for details
"""
from __future__ import annotations
from dataclasses import (
dataclass, field, fields, Field)
from itertools import (
chain, permutations, repeat)
from typing import (
Any, Iterator, MutableMapping, Optional, Sequence)
import plotly.express as px
import plotly.figure_factory as ff
import plotly.graph_objects as go
from .. import Trace
# module exports
__all__ = [
'State2D',
'State3D',
'Statemachine',
]
[docs]@dataclass(eq=False)
class Statemachine(MutableMapping):
""" Statemachine data class to :attr:`evaluate` the state transitions
between :attr:`states` of a statemachine.
A *state* of the statemachine is defined by a tuple pair, consisting of
a unique integer *number* of the state, and the *label* of the state.
The :attr:`states` of a statemachine must define a consecutive, ascending
interval of integer :attr:`numbers`.
"""
#: States of the statemachine
states: dict[int, str | State2D | State3D]
#: Signal trace with the state numbers
signal: Trace = field(default_factory=Trace)
#: List with the state labels of the statemachine
labels: list[str] = field(default_factory=list)
#: Matrix with the counted transitions between the states
matrix: list[list[int]] = field(default_factory=list)
[docs] @classmethod
def create(cls,
states: int | Sequence[str | State2D | State3D],
start: int = 0,
**kwargs: Any) -> Statemachine:
""" Creates either a statemachine with the number of *states* without
labeling them, or a statemachine with states from a sequence containing
the labels or representations of the *states* to create.
:param states: number of states or a sequence with the labels or
representations for the :attr:`states` to create
:type states: int | Sequence[str | State2D | State3D]
:param int start: start number for the :attr:`states` to create
Default is ``0``.
"""
if isinstance(states, int):
nodes = dict(enumerate(range(states), start))
elif isinstance(states, (list, tuple)):
nodes = dict(enumerate(states, start))
else:
nodes = dict()
return cls(states=nodes, **kwargs)
def __post_init__(self) -> None:
if not self.labels:
for key, value in self.items():
if isinstance(value, (State2D, State3D)):
self.labels.append(value.name)
elif value is None:
self.labels.append(str(key))
else:
self.labels.append(str(value))
if self.signal:
# computes the transition matrix
self.evaluate()
elif not self.matrix:
# initialize zeroed transition matrix
self.matrix = self.zeroed_matrix()
@property
def numbers(self) -> list[int]:
""" List with the state numbers of the statemachine."""
return list(self.states.keys())
def __setitem__(self, key: int, value: str | State2D | State3D):
self.states[key] = value
def __getitem__(self, key: int) -> str | State2D | State3D:
return self.states[key]
def __delitem__(self, key: int):
raise NotImplemented()
def __iter__(self) -> Iterator[int]:
return iter(self.states)
def __len__(self) -> int:
return len(self.states)
[docs] def fields(self) -> tuple[Field, ...]:
""" Returns a tuple describing the fields of the data class."""
return fields(self)
[docs] def zeroed_matrix(self) -> list[list[int]]:
""" Returns the zeroed transition matrix of the statemachine.
"""
return [[0] * len(self) for _ in range(len(self))]
[docs] def zeroed_counters(self) -> dict[tuple[int, int], int]:
""" Returns the dictionary with the zeroed state transition counters
of the statemachine."""
return dict((transition, 0) for transition in permutations(self.states, 2))
[docs] def evaluate(self, signal: Optional[Trace] = None) -> Statemachine:
""" Counts the state transitions between the :attr:`states` of the
statemachine in the :attr:`signal`.
The :attr:`~Trace.samples` of the :attr:`signal` must contain only
integer :attr:`numbers` defined by the :attr:`states` of the state
machine.
:param signal: optional a new :attr:`signal` to evaluate
"""
# assign new signal to evaluate
if isinstance(signal, Trace):
self.signal = signal
# compute transition deltas from state signal
deltas = self.signal.delta()
# create state transition counters for the states
counters = self.zeroed_counters()
# count transitions between states
for state, delta in zip(self.signal, deltas):
if delta:
counters[(state - delta, state)] += 1
# reset transition matrix
self.matrix = self.zeroed_matrix()
# offset correction of the states
offset = min(self.states)
# fill transition matrix with counted transitions
for key, value in counters.items():
i, j = key
self.matrix[i - offset][j - offset] = value
return self
[docs] def flatten(self) -> Iterator[int]:
""" Returns an iterator to flatten the state transition :attr:`matrix`
into a list.
"""
return chain.from_iterable(self.matrix)
[docs] def data(self) -> list[list[str | int]]:
""" Returns the data table with the counted state transitions
between the :attr:`states` of the statemachine.
"""
# table data
data = list()
# add column labels
data.append([''] + self.labels)
# add rows
for name, values in zip(self.labels, self.matrix):
data.append([name] + values)
return data
[docs] def table(self, **kwargs: Any) -> go.Figure:
""" Returns a table figure for the state transition :attr:`matrix`
of the statemachine."""
return ff.create_table(self.data(), index=True, **kwargs)
[docs] def plot(self, **kwargs: Any) -> go.Heatmap:
""" Returns a heatmap plot for the state transition :attr:`matrix`
of the statemachine."""
settings = dict(
z=self.matrix,
x=self.labels,
y=self.labels
)
plot = go.Heatmap(**settings)
return plot.update(**kwargs)
[docs] def heatmap(self, **kwargs: Any) -> go.Figure:
""" Returns a heatmap figure for the state transition :attr:`matrix`
of the statemachine."""
# default settings
settings = dict(
x=self.labels,
y=self.labels,
text_auto=True,
color_continuous_scale='Blues'
)
settings.update(kwargs)
return px.imshow(self.matrix, **settings)
[docs] def flowchart(self, **kwargs: Any) -> go.Sankey:
""" Returns a sankey flow-chart plot for the state transition
:attr:`matrix` of the statemachine."""
# number of states
count = len(self)
settings = dict(
node=dict(label=self.labels),
link=dict(
source=list(
chain.from_iterable([repeat(i, count) for i in range(count)])),
target=list(range(count)) * count,
value=list(self.flatten())
)
)
plot = go.Sankey(**settings)
return plot.update(**kwargs)
[docs]@dataclass(unsafe_hash=True)
class State2D:
""" State represented by a 2-dimensional node."""
#: Name of the state
name: str
#: X-coordinate of the node
x: int
#: Y-coordinate of the node
y: int
#: Color of the node
color: str = 'Black'
[docs]@dataclass(unsafe_hash=True)
class State3D:
""" State represented by a 3-dimensional node."""
#: Name of the state
name: str
#: X-coordinate of the state node
x: int
#: Y-coordinate of the state node
y: int
#: Z-coordinate of the state node
z: int
#: Color of the node
color: str = 'Black'