Source code for delaynet.utils.lag_steps
"""Utility functions to handle lag steps."""
from typing import ParamSpec, Protocol
import numpy as np
from .logging import logging
[docs]
def assure_lag_list(lag_steps: int | list[int]) -> list[int]:
"""Ensure that ``lag_steps`` is a list of lags.
If ``lag_steps`` is an integer, it will be
converted to a list containing integers from 1 to ``lag_steps``.
If ``lag_steps`` is already a list, it will be checked to ensure that
all elements are integers.
:param lag_steps: An integer >= 1 or a list of integers.
:type lag_steps: int | list[int]
:return: A list of integers
:type: list[int]
:raises ValueError: If ``lag_steps`` is not an integer >= 1 or a list of integers.
"""
if isinstance(lag_steps, int):
if lag_steps <= 0:
raise ValueError("`lag_steps` must be a positive integer or list of such.")
lag_steps = list(range(1, lag_steps + 1))
elif not isinstance(lag_steps, list) or (
isinstance(lag_steps, list) and not all(isinstance(x, int) for x in lag_steps)
):
raise ValueError("`lag_steps` must be an integer or a list of integers.")
if any(x <= 0 for x in lag_steps):
logging.warning(
"Some elements in `lag_steps` are non-positive. "
"This might produce unscientific results."
)
return lag_steps
P = ParamSpec("P")
[docs]
class Connectivity(Protocol[P]):
"""A protocol for connectivity metrics.
Connectivity metrics are rigidly typed in their first three parameters:
the time series and one lag step. The rest are optional keyword arguments.
"""
def __call__(
self,
ts1: np.ndarray,
ts2: np.ndarray,
lag: int,
*args: P.args,
**kwargs: P.kwargs,
) -> float: ...
[docs]
def find_optimal_lag(
metric_func: Connectivity,
ts1,
ts2,
lag_steps: list,
op=min,
**kwargs,
):
"""Find the optimal value and lag for a given metric function.
The optimal value and lag are determined by applying a given operation
to a list of values obtained by applying ``metric_func``
for each lag in ``lag_steps``.
The operation can be `min`, `max`, or any other operation that takes a list of
values and returns a single value.
If ``metric_func`` returns a *p*-value, the operator should be the minimum
(default optional parameter).
:param metric_func: Function to compute the metric for a given lag step.
Accepts time series ``ts1`` and ``ts2``, a lag step ``lag``,
and any additional keyword arguments.
:type metric_func: Connectivity
:param ts1: First time series.
:type ts1: numpy.ndarray
:param ts2: Second time series.
:type ts2: numpy.ndarray
:param lag_steps: Time lags to consider.
Needs to be a list of integers.
:type lag_steps: list
:param op: Operator to find the optimal lag step
(e.g., default :py:func:`min` or :py:func:`max`).
:type op: Callable
:param kwargs: Additional keyword arguments to pass to the metric function.
:return: Optimal metric value and corresponding lag step.
:rtype: tuple[float, int]
"""
all_values = [metric_func(ts1, ts2, lag_step, **kwargs) for lag_step in lag_steps]
idx_optimal = op(range(len(all_values)), key=all_values.__getitem__)
return all_values[idx_optimal], lag_steps[idx_optimal]