Source code for delaynet.connectivity
"""Module to provide a unified interface for all connectivity metrics."""
from collections.abc import Callable
from numpy import ndarray
from .connectivities import (
__all_connectivity_metrics_names__,
__all_connectivity_metrics_names_simple__,
)
from .decorators import connectivity as connectivity_decorator
Metric = str | Callable[[ndarray, ndarray, ...], tuple[float, int]]
[docs]
def connectivity(
ts1: ndarray,
ts2: ndarray,
/,
metric: Metric,
*args,
lag_steps: int | list[int] | None = None,
**kwargs,
) -> tuple[float, int]:
"""
Calculate connectivity between two time series using a given metric.
Keyword arguments are forwarded to the metric function.
The metrics can be either a string or a function, implementing a connectivity
metric.
Find the metric string specifier using :func:`show_connectivity_metrics`.
(Find all in submodule :mod:`delaynet.connectivities`, names are stored in
:attr:`delaynet.connectivities.__all_connectivity_metrics__`)
If a `callable` is given, it should take two time series as input and return a
`tuple` of `float` and `int`.
:param ts1: First time series. Positional only.
:type ts1: numpy.ndarray
:param ts2: Second time series. Positional only.
:type ts2: numpy.ndarray
:param metric: Metric to use.
:type metric: str or Callable
:param args: Positional arguments forwarded to the connectivity function, see
documentation.
:type args: list
:param lag_steps: The number of lag steps to consider. Required.
Can be integer for [1, ..., num], or a list of integers.
:type lag_steps: int | list[int] | None
:param kwargs: Keyword arguments forwarded to the connectivity function, see
documentation.
:return: Connectivity value and lag.
:rtype: tuple of float and int
:raises ValueError: If the metric is unknown. Given as string.
:raises ValueError: If the metric returns an invalid value. Given a Callable.
:raises ValueError: If the metric is neither a string nor a Callable.
"""
kwargs["lag_steps"] = lag_steps
if isinstance(metric, str):
metric = metric.lower()
if metric not in __all_connectivity_metrics_names__:
raise ValueError(f"Unknown metric: {metric}")
return __all_connectivity_metrics_names__[metric](ts1, ts2, **kwargs)
if not callable(metric):
raise ValueError(
f"Invalid connectivity metric: {metric}. Must be string or callable."
)
# connectivity metric is a callable,
# add decorator to assure correct kwargs, type and shape
return connectivity_decorator()(metric)(ts1, ts2, *args, **kwargs)
[docs]
def show_connectivity_metrics():
"""Pretty print all available connectivity metrics."""
print("Available connectivity metrics:")
for metric, aliases in __all_connectivity_metrics_names_simple__.items():
print(f"\nMetric: {metric}")
print("Aliases:")
for alias in aliases:
print(f" - {alias}")
print("\n")