---
file_format: mystnb
kernelspec:
  name: python3
---

(network_reconstruction)=

# Network Reconstruction

Network reconstruction is the process of inferring the structure of a network from time
series data. This is a crucial step in understanding the relationships between different
nodes in a complex system.
Put simply, network reconstruction applies a connectivity measure to each pair of time
series.
For $n$ time series, this yields two $n \times n$ matrices: one of $p$-values and one of
lags.
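
The pairwise idea can be sketched in plain NumPy. The block below is a minimal
illustration, not `delaynet`'s implementation: `lagged_corr_pvalue` and
`pairwise_reconstruct` are hypothetical helpers, and a lagged Pearson correlation with a
Fisher z approximation stands in for a real connectivity measure.

```python
import math

import numpy as np


def lagged_corr_pvalue(x, y, max_lag):
    """Toy measure (assumption): smallest p-value of corr(x[t], y[t + lag]) over lags."""
    best_p, best_lag = 1.0, 1
    for lag in range(1, max_lag + 1):
        r = np.corrcoef(x[:-lag], y[lag:])[0, 1]
        n = len(x) - lag
        # Fisher z-transform; clip |r| to keep atanh finite
        z = math.atanh(min(abs(r), 0.999999)) * math.sqrt(n - 3)
        p = math.erfc(z / math.sqrt(2))  # two-sided normal tail probability
        if p < best_p:
            best_p, best_lag = p, lag
    return best_p, best_lag


def pairwise_reconstruct(data, max_lag):
    """Apply the toy measure to every ordered pair of columns in `data`."""
    n_nodes = data.shape[1]
    p_values = np.ones((n_nodes, n_nodes))
    lags = np.zeros((n_nodes, n_nodes), dtype=int)
    for i in range(n_nodes):
        for j in range(n_nodes):
            if i != j:
                p_values[i, j], lags[i, j] = lagged_corr_pvalue(
                    data[:, i], data[:, j], max_lag
                )
    return p_values, lags


rng = np.random.default_rng(0)
x = rng.standard_normal(500)
y = np.roll(x, 3) + 0.1 * rng.standard_normal(500)  # y follows x by 3 samples
noise = rng.standard_normal(500)                    # unrelated series
data = np.column_stack([x, y, noise])               # shape: (time, nodes)

p_values, lags = pairwise_reconstruct(data, max_lag=5)
print(p_values.shape, lags.shape)
print(f"x -> y: p = {p_values[0, 1]:.3g}, lag = {lags[0, 1]}")
```

Entry `[i, j]` of each matrix describes the directed pair from series `i` to series `j`,
so the matrices are generally not symmetric.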

## The `reconstruct_network()` Function

The `delaynet` package provides the
{func}`~delaynet.network_reconstruction.reconstruct_network` function to generate
$p$-value matrices by applying connectivity measures to pairs of time series.
As described in the {ref}`connectivity_sec` section, connectivity measures return
$p$-values that indicate the strength of connections between time series.
In the resulting matrix, lower $p$-values therefore indicate stronger connections.

## Complete Network Reconstruction Example

This comprehensive example demonstrates the entire workflow of network reconstruction,
from data generation to network visualization and validation. We'll use synthetic data
with a known ground truth to evaluate the reconstruction performance.

### Data Generation and Preprocessing

First, we generate synthetic time series data using a delayed causal network (DCN)
model. This approach creates time series with known causal relationships, allowing us to
validate our reconstruction results.

```{code-cell}
:tags: [hide-input]
import networkx as nx
import numpy as np
from matplotlib import pyplot as plt

import delaynet as dn
from numpy.random import default_rng

adjacency_matrix, weight_matrix, time_series = dn.preparation.gen_delayed_causal_network(
    ts_len=1000,                # Length of time series
    n_nodes=8,                  # Number of nodes
    l_dens=0.5,                 # Density of the adjacency matrix
    wm_min_max=(0.5, 1.5),      # Min and max of the weight matrix
    rng=default_rng(1612757)
)

print(f"Generated time series: {time_series.shape}")
print(f"Adjacency matrix: {adjacency_matrix.shape}")
print(f"Weight matrix: {weight_matrix.shape}")
```

The generated data contains 8 time series of length 1000, along with the true adjacency
and weight matrices that define the underlying network structure. Before proceeding with
network reconstruction, we apply detrending to remove trends and make the time series
stationary.

Note that we use `axis=1` in the detrending function because we have multiple time
series arranged as rows, and we want to detrend each time series individually across its
temporal dimension.

```{code-cell}
time_series = dn.detrend(time_series, "delta", axis=1)
print(f"Detrended time series: {time_series.shape}")
```
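
To illustrate the `axis=1` convention on its own, the following sketch removes the mean
of each row of a small array. Mean removal is only a stand-in chosen for this
illustration; it is not claimed to be what the `"delta"` method does.

```python
import numpy as np

# Two short "time series" stored as rows: shape (n_series, n_time_points)
rows = np.array([[1.0, 2.0, 3.0, 4.0],
                 [10.0, 20.0, 30.0, 40.0]])

# axis=1 runs along the time dimension of each row;
# keepdims=True keeps the result broadcastable against `rows`
centered = rows - rows.mean(axis=1, keepdims=True)

print(centered)
print(centered.mean(axis=1))  # each row now has zero mean
```

With `axis=0`, the operation would instead mix values from different time series at the
same time point, which is not what we want here.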

### Network Reconstruction

Now we apply the {func}`~delaynet.network_reconstruction.reconstruct_network` function
to infer the network structure from the preprocessed time series data. We use Granger
causality as our connectivity measure, which is particularly well-suited for detecting
directed causal relationships in time series data. Because our time series are stored as
rows while the function is called with time along the first axis, we pass the transposed
array `time_series.T`.

```{code-cell}
# Reconstruct with Granger causality
weights, lags = dn.reconstruct_network(
    time_series.T, connectivity_measure="gc", lag_steps=10
)
```

### Visualizing the Results

The reconstruction function returns two matrices: a $p$-value matrix indicating the
strength of connections (lower values = stronger connections) and a lag matrix showing
the optimal time delays. We visualize these results as a heatmap with the $p$-values as
colors and the optimal lags as text annotations.

```{code-cell}
:tags: [hide-input]
:width: 200px
fig, ax = plt.subplots(figsize=(8, 6))
im = ax.imshow(weights)

# Add lag numbers as text annotations
for i in range(weights.shape[0]):
    for j in range(weights.shape[1]):
        ax.text(
            j, i, str(int(lags[i, j])), ha="center", va="center", color="w"
        )
plt.colorbar(im)
plt.show()
```

### Network Pruning

To convert the continuous $p$-value matrix into a binary adjacency matrix, we apply
statistical thresholding. We use a significance level (α) of 0.02, meaning we only
consider connections with $p$-values below this threshold as statistically significant.
This pruning step is crucial for removing weak or spurious connections and focusing on
the most reliable network edges.

```{code-cell}
# Prune the p-value matrix: keep only entries below the significance level
threshold = 0.02  # Adjust threshold as needed
pruned_weights = (weights < threshold).astype(int)
# Remove self-loops by setting the diagonal to 0
np.fill_diagonal(pruned_weights, 0)
```
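
When choosing the threshold, it can help to keep the number of tested pairs in mind. The
arithmetic below is a rough sketch that assumes well-calibrated $p$-values (uniformly
distributed when no connection exists). The Bonferroni-style threshold is shown only as
one possible, more conservative choice; it is not something this example applies.

```python
n_nodes = 8
alpha = 0.02

# Every ordered pair of distinct nodes is tested once
n_tests = n_nodes * (n_nodes - 1)

# Expected number of spurious edges if no true connection existed
expected_false_positives = alpha * n_tests

# A more conservative per-test threshold (Bonferroni correction)
bonferroni_alpha = alpha / n_tests

print(f"Tested pairs: {n_tests}")
print(f"Expected false positives at alpha={alpha}: {expected_false_positives:.2f}")
print(f"Bonferroni-corrected threshold: {bonferroni_alpha:.2e}")
```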

### Validation Against Ground Truth

In real-world applications, the true network structure is typically unknown. However,
since we used synthetic data, we can compare our reconstructed network against the
ground truth adjacency matrix. This validation step helps us assess the accuracy of our
reconstruction method.

The visualization shows three matrices: the original (ground truth) adjacency matrix,
our reconstructed and pruned adjacency matrix, and their difference. In the difference
plot, blue regions indicate edges that exist only in the original network (false
negatives), while red regions show edges that were inferred but don't exist in the true
network (false positives).

```{code-cell}
:tags: [hide-input]
# Plot the ground-truth, pruned, and difference matrices side by side
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5), dpi=300)

im1 = ax1.imshow(adjacency_matrix, cmap="gray_r")
ax1.set_title("Original Adjacency")
ax1.figure.colorbar(im1, ax=ax1)

im2 = ax2.imshow(pruned_weights, cmap="gray_r")
ax2.set_title("Pruned Adjacency")
ax2.figure.colorbar(im2, ax=ax2)

# Diff
diff = adjacency_matrix - pruned_weights
im3 = ax3.imshow(diff, cmap="RdBu", vmin=-1, vmax=1)
ax3.set_title("Difference (blue=original, red=pruned)")
ax3.figure.colorbar(im3, ax=ax3)

plt.tight_layout()
plt.show()
```
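
The visual comparison can be complemented with a few numbers. The sketch below counts
true positives, false positives, and false negatives on two small hand-written matrices
and derives precision and recall from them. The matrices `truth` and `reconstructed` are
made up for illustration; in the workflow above, `adjacency_matrix` and `pruned_weights`
would take their place.

```python
import numpy as np

# Hypothetical ground truth and reconstruction for a 3-node network
truth = np.array([[0, 1, 0],
                  [0, 0, 1],
                  [1, 0, 0]])
reconstructed = np.array([[0, 1, 1],
                          [0, 0, 0],
                          [1, 0, 0]])

# Ignore the diagonal, since self-loops are not part of the comparison
off_diagonal = ~np.eye(truth.shape[0], dtype=bool)
t = truth[off_diagonal].astype(bool)
r = reconstructed[off_diagonal].astype(bool)

true_positives = int(np.sum(t & r))    # edges found in both
false_positives = int(np.sum(~t & r))  # inferred, but not in the truth
false_negatives = int(np.sum(t & ~r))  # in the truth, but missed

precision = true_positives / (true_positives + false_positives)
recall = true_positives / (true_positives + false_negatives)

print(f"TP={true_positives}, FP={false_positives}, FN={false_negatives}")
print(f"precision={precision:.2f}, recall={recall:.2f}")
```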

### Network Graph Visualization

To better understand the reconstruction results, we create a network graph that visually
distinguishes between different types of edges. This representation makes it easier to
identify correctly reconstructed connections, missed connections, and false discoveries.

```{code-cell}
G = nx.from_numpy_array(adjacency_matrix, create_using=nx.DiGraph)
G_pruned = nx.from_numpy_array(pruned_weights, create_using=nx.DiGraph)
G_union = nx.compose(G, G_pruned)
```

```{code-cell}
:tags: [hide-input]
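# Color each edge by how the reconstruction relates to the ground truth
edge_colors = []
for edge in G_union.edges:
    if edge in G.edges and edge in G_pruned.edges:
        edge_colors.append("black")  # correctly reconstructed
    elif edge in G_pruned.edges:
        edge_colors.append("#D55E00")  # false positive: only in the reconstruction
    else:
        edge_colors.append("#0173B2")  # false negative: only in the ground truth
# Plot the union graph with the edge colors defined above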
fig, ax = plt.subplots()
nx.draw(
    G_union,
    pos=nx.kamada_kawai_layout(G_union),
    with_labels=True,
    node_color="lightblue",
    node_size=500,
    edge_color=edge_colors,
    arrowsize=20,
    arrowstyle="->",
    # curved arrows
    connectionstyle="arc3,rad=0.2",
)
plt.title("Original and Reconstructed Network")
plt.show()
```

This example reconstructed a network from synthetic time series data with known ground
truth. It covered the complete workflow, from data generation through validation, and
provides a framework for evaluating reconstruction accuracy in controlled scenarios.

## API Reference and Usage Examples

### Basic Usage

```{code-cell}
import numpy as np
from delaynet.network_reconstruction import reconstruct_network

# Generate sample data: 100 time points, 5 nodes
np.random.seed(42)  # For reproducible results
data = np.random.randn(100, 5)

# Reconstruct network using linear correlation
weights, lags = reconstruct_network(data, "linear_correlation", lag_steps=5)

print(f"P-value matrix shape: {weights.shape}")
print(f"Lag matrix shape: {lags.shape}")
print(f"P-value matrix:\n{weights}")
print(f"Lag matrix:\n{lags}")
```

### Understanding the Output

The function returns two matrices:

1. **$p$-value Matrix**: Contains $p$-values representing the strength of connections
   between nodes. Lower $p$-values indicate stronger connections.
2. **Lag Matrix**: Contains the optimal time lags at which the strongest connections
   were found.

```{code-cell} ipython3
# Analyze the results
print("Strong connections (p < 0.05):")
strong_connections = np.where(weights < 0.05)
for i, j in zip(strong_connections[0], strong_connections[1]):
    if i != j:
        print(
            f"Node {i} -> Node {j}: "
            f"p-value = {weights[i, j]:.4f}, lag = {lags[i, j]}"
        )
```

### Using Different Connectivity Measures

You can use any of the available connectivity measures:

```{code-cell}
# Using transfer entropy
weights_te, lags_te = reconstruct_network(
    data, "transfer_entropy", approach="ksg", lag_steps=3
)

# Using mutual information
weights_mi, lags_mi = reconstruct_network(
    data, "mutual_information", approach="ksg", lag_steps=3
)

print(f"Transfer entropy p-values:\n{weights_te}")
print(f"Mutual information p-values:\n{weights_mi}")
```

### Custom Connectivity Measures

You can also provide your own connectivity measure as a callable:

```{code-cell}
def custom_correlation_metric(ts1, ts2, lag_steps=None):
    """Toy connectivity measure based on correlation (for demonstration only)."""
    # A real measure should return a genuine p-value and the optimal lag.
    # Here the absolute correlation stands in for the p-value and the lag is fixed at 1.
    correlation = abs(np.corrcoef(ts1, ts2)[0, 1])
    return correlation, 1

# Use custom metric
weights_custom, lags_custom = reconstruct_network(data, custom_correlation_metric, lag_steps=5)
print(f"Custom metric p-values:\n{weights_custom}")
```
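
For comparison, here is a hedged sketch of a custom measure that returns an actual
$p$-value, obtained from a max-statistic permutation test on lagged correlations. It
follows the `(ts1, ts2, lag_steps=None)` signature used above, but several points are
assumptions: `permutation_measure` and `lagged_abs_corr` are hypothetical names,
`lag_steps` is treated as a single maximum lag, and shuffling assumes exchangeable
samples, which strongly autocorrelated series violate. The block calls the function
directly rather than through `reconstruct_network`.

```python
import numpy as np


def lagged_abs_corr(ts1, ts2, max_lag):
    """Absolute correlation between ts1[t] and ts2[t + lag] for lag = 1..max_lag."""
    return np.array(
        [abs(np.corrcoef(ts1[:-lag], ts2[lag:])[0, 1]) for lag in range(1, max_lag + 1)]
    )


def permutation_measure(ts1, ts2, lag_steps=None, n_perm=200, seed=0):
    """Return (p_value, best_lag) from a max-statistic permutation test (sketch)."""
    max_lag = 1 if lag_steps is None else int(lag_steps)  # assumption: int max lag
    rng = np.random.default_rng(seed)

    observed = lagged_abs_corr(ts1, ts2, max_lag)
    best_lag = int(np.argmax(observed)) + 1

    # Null distribution: shuffle the source and take the maximum over all lags,
    # which accounts for having searched several lags
    null_max = np.array(
        [lagged_abs_corr(rng.permutation(ts1), ts2, max_lag).max() for _ in range(n_perm)]
    )
    p_value = (1 + np.sum(null_max >= observed.max())) / (n_perm + 1)
    return float(p_value), best_lag


rng = np.random.default_rng(1)
source = rng.standard_normal(300)
target = np.roll(source, 2) + 0.5 * rng.standard_normal(300)  # delayed by 2 samples

p_value, best_lag = permutation_measure(source, target, lag_steps=4)
print(f"p-value = {p_value:.4f}, lag = {best_lag}")
```

Note that the smallest attainable $p$-value is `1 / (n_perm + 1)`, so `n_perm` limits how
strict a threshold such a measure can support.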

## Function Reference

```{eval-rst}
.. autofunction:: delaynet.network_reconstruction.reconstruct_network
```

## Next Steps

After reconstructing a network, the next step is typically to analyze its properties and
extract meaningful insights. The {ref}`network_analysis` section provides tools for
pruning networks based on statistical significance and calculating various network
metrics such as centrality measures, link density, and global efficiency.
