import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle, ConnectionPatch

# Create one figure with two side-by-side axes: ax1 (left) receives the dense
# diagram and ax2 (right) receives the sparse/convolutional diagram.
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(14, 6))
ax1, ax2 = axes
fig.suptitle(
    'Dense vs. Sparse Interactions in Neural Networks',
    fontsize=16,
    fontweight='bold',
)

# --- Plot 1: Dense Connections ---
# Left panel: a grid of input pixels, a column of hidden units to its right,
# and one faint line for every (pixel, hidden unit) pair.
input_size = 4
dense_input = np.zeros((input_size, input_size))  # placeholder array; not read by the drawing code in this section

# Styling shared by every artist of the same kind
pixel_box_style = dict(fill=True, edgecolor='black', facecolor='lightgray', lw=1)
neuron_box_style = dict(fill=True, edgecolor='blue', facecolor='lightblue', lw=1)

# Input grid: walk the cells in row-major order (row = y, col = x)
for cell in range(input_size * input_size):
    row, col = divmod(cell, input_size)
    ax1.add_patch(Rectangle((col, row), 1, 1, **pixel_box_style))
    ax1.text(col + 0.5, row + 0.5, f'x{row},{col}', ha='center', va='center', fontsize=8)

# Hidden layer column, placed two units to the right of the grid
hidden_neurons_x = input_size + 2
num_hidden_neurons = 5  # Using 5 for clarity, but imagine 1000!

# Centres of all input pixels, listed in the order the lines are drawn
pixel_centers = [
    (px + 0.5, py + 0.5)
    for px in range(input_size)
    for py in range(input_size)
]

# y of the lowest hidden unit; negative when there are more units than rows,
# so the column straddles the grid vertically
lowest_neuron_y = (input_size - num_hidden_neurons) / 2

for unit in range(num_hidden_neurons):
    y_pos = lowest_neuron_y + unit
    ax1.add_patch(Rectangle((hidden_neurons_x, y_pos), 0.8, 0.8, **neuron_box_style))
    ax1.text(
        hidden_neurons_x + 0.4, y_pos + 0.4, f'h{unit}',
        ha='center', va='center', fontsize=10, weight='bold',
    )

    # Every pixel connects to this unit; the lines are thin and mostly
    # transparent so the sheer number of them is what stands out.
    # Each line ends on the left edge of the unit's box, at mid-height.
    for center_x, center_y in pixel_centers:
        ax1.plot(
            [center_x, hidden_neurons_x],
            [center_y, y_pos + 0.4],
            'gray', linestyle='-', alpha=0.2, lw=0.5,
        )

# Frame the panel with a margin around the drawing and hide the axes
ax1.set_xlim(-1, hidden_neurons_x + 2)
ax1.set_ylim(-1, input_size + 1)
ax1.set_title('Dense Connections\n(Overwhelming & Global)', fontsize=14, color='red')
ax1.set_aspect('equal')
ax1.axis('off')

# --- Plot 2: Sparse (Convolutional) Connections ---
# Right panel: the same input grid as the dense panel, but only the pixels
# under one small kernel window are wired to a single output unit.
conv_input = np.zeros((input_size, input_size))  # placeholder array; not read by the drawing code in this section
for i in range(input_size):
    for j in range(input_size):
        # Rectangle takes (x, y): column j runs along x, row i along y
        ax2.add_patch(Rectangle((j, i), 1, 1, fill=True, edgecolor='black', facecolor='lightgray', lw=1))
        ax2.text(j + 0.5, i + 0.5, f'x{i},{j}', ha='center', va='center', fontsize=8)

# Draw the convolutional kernel (receptive field)
kernel_size = 2
# (x, y) anchor of the kernel window. The y-axis is not inverted here, so this
# is the window's lower-left corner as drawn.
kernel_pos = (1, 1)
kernel_x, kernel_y = kernel_pos
# Highlight the local patch the kernel is looking at
ax2.add_patch(Rectangle(kernel_pos, kernel_size, kernel_size, fill=False, edgecolor='red', linestyle='-', lw=3, alpha=1))

# Draw the single hidden neuron (output activation) on the right
neuron_size = 0.8
hidden_neuron_x = input_size + 2
# Vertically centre the neuron box on the kernel window. This must use the
# kernel's y coordinate (not its x coordinate) and account for both heights,
# so it stays centred for any kernel_pos / kernel_size.
hidden_neuron_y = kernel_y + (kernel_size - neuron_size) / 2
neuron_center_x = hidden_neuron_x + neuron_size / 2
neuron_center_y = hidden_neuron_y + neuron_size / 2
ax2.add_patch(Rectangle((hidden_neuron_x, hidden_neuron_y), neuron_size, neuron_size, fill=True, edgecolor='blue', facecolor='lightblue', lw=1))
ax2.text(neuron_center_x, neuron_center_y, 'h', ha='center', va='center', fontsize=10, weight='bold')

# Draw connections ONLY from the local receptive field to the neuron
for i in range(kernel_size):
    for j in range(kernel_size):
        # Centre of the pixel at row i, column j inside the kernel window
        pixel_x = kernel_x + j + 0.5
        pixel_y = kernel_y + i + 0.5
        # Strong, clear line ending on the left edge of the neuron box, at mid-height
        ax2.plot([pixel_x, hidden_neuron_x], [pixel_y, neuron_center_y], color='red', linestyle='-', lw=2)
        # Label the weight at the midpoint of its line
        ax2.text((pixel_x + hidden_neuron_x) / 2, (pixel_y + neuron_center_y) / 2, f'w{i},{j}', ha='center', va='center', fontsize=8, color='darkred')

# Add an arrow to show the kernel sliding.
# NOTE: the arrow and its caption use fixed data coordinates; they do not
# follow kernel_pos or input_size.
ax2.arrow(3.5, 3.5, 1.5, 0, head_width=0.3, head_length=0.2, fc='green', ec='green', lw=2)
ax2.text(4.5, 3.8, 'Kernel Slides', ha='center', va='center', fontsize=10, color='green')

# Frame the panel with a margin around the drawing and hide the axes
ax2.set_xlim(-1, hidden_neuron_x + 2)
ax2.set_ylim(-1, input_size + 1)
ax2.set_title('Sparse (Convolutional) Connections\n(Efficient & Local)', fontsize=14, color='green')
ax2.set_aspect('equal')
ax2.axis('off')  # Turn off the axis