# LLMind/gen_stimulat.py
#
# Convert an MNIST digit image into an electrical-stimulation pattern and plot it.

import gzip
import struct

import matplotlib.pyplot as plt
import numpy as np
def read_images(filename):
    """Read a gzip-compressed IDX3 (MNIST-style) image file.

    The 16-byte big-endian header (magic number, image count, rows, columns)
    is parsed and validated instead of being skipped blindly, so the result
    is shaped from what the file declares rather than a hard-coded 28x28.

    Args:
        filename: Path to the ``.gz`` image file.

    Returns:
        A read-only ``uint8`` array of shape ``(count, rows, cols)``.

    Raises:
        ValueError: If the header is truncated, the magic number is not 2051
            (unsigned-byte, 3-D IDX), or the file holds fewer pixels than the
            header declares.
    """
    with gzip.open(filename, 'rb') as f:
        header = f.read(16)
        if len(header) < 16:
            raise ValueError(f"{filename}: truncated IDX3 header")
        # IDX headers are big-endian 32-bit unsigned integers.
        magic, count, rows, cols = struct.unpack('>IIII', header)
        if magic != 2051:
            raise ValueError(
                f"{filename}: bad magic number {magic} (expected 2051 for IDX3 images)"
            )
        data = f.read()
    expected = count * rows * cols
    if len(data) < expected:
        raise ValueError(
            f"{filename}: expected {expected} pixel bytes, found {len(data)}"
        )
    # `count=` ignores any trailing bytes beyond what the header declares.
    return np.frombuffer(data, dtype=np.uint8, count=expected).reshape(count, rows, cols)
def pixel_to_stimulation(pixel_value, max_stimulation=5.0):
    """Linearly map a raw 0-255 pixel intensity onto ``[0, max_stimulation]``.

    The input is expected on the raw 8-bit scale (the division by 255 happens
    here), so do not pre-normalize it. Works on a scalar or on anything that
    supports ``/`` and ``*`` element-wise, such as a NumPy array.
    """
    fraction_of_full_scale = pixel_value / 255.0
    return fraction_of_full_scale * max_stimulation
def read_labels(filename):
    """Read a gzip-compressed IDX1 (MNIST-style) label file.

    The 8-byte big-endian header (magic number, label count) is parsed and
    validated instead of being skipped blindly.

    Args:
        filename: Path to the ``.gz`` label file.

    Returns:
        A read-only 1-D ``uint8`` array with one entry per label.

    Raises:
        ValueError: If the header is truncated, the magic number is not 2049
            (unsigned-byte, 1-D IDX), or the file holds fewer labels than the
            header declares.
    """
    with gzip.open(filename, 'rb') as f:
        header = f.read(8)
        if len(header) < 8:
            raise ValueError(f"{filename}: truncated IDX1 header")
        # IDX headers are big-endian 32-bit unsigned integers.
        magic, count = struct.unpack('>II', header)
        if magic != 2049:
            raise ValueError(
                f"{filename}: bad magic number {magic} (expected 2049 for IDX1 labels)"
            )
        data = f.read()
    if len(data) < count:
        raise ValueError(f"{filename}: expected {count} labels, found {len(data)}")
    # `count=` ignores any trailing bytes beyond what the header declares.
    return np.frombuffer(data, dtype=np.uint8, count=count)
# Upper bound of the stimulation scale; also used to pin the plot's colour range.
MAX_STIMULATION = 5.0

# Load the MNIST training images and labels.
train_images = read_images('dataset/train-images-idx3-ubyte.gz')
train_labels = read_labels('dataset/train-labels-idx1-ubyte.gz')

# Find all indices of the digit '9' in the training labels.
indices_of_nines = np.flatnonzero(train_labels == 9)
if indices_of_nines.size == 0:
    raise SystemExit("No images of the digit 9 found in the training labels.")
# Use the first '9' image for demonstration.
image_of_nine = train_images[indices_of_nines[0]]

# Convert raw 0-255 pixel intensities straight to stimulation levels.
# pixel_to_stimulation already divides by 255, so the image must NOT be
# normalized first: doing so shrank the range to [0, MAX_STIMULATION / 255].
# The function is plain NumPy arithmetic, so it handles the whole array at
# once; np.vectorize (a Python-level per-pixel loop) is unnecessary.
stimulation_pattern = pixel_to_stimulation(image_of_nine, max_stimulation=MAX_STIMULATION)

# Visualize the stimulation pattern on a fixed scale, so brightness reflects
# the actual stimulation level rather than being auto-rescaled.
plt.figure(figsize=(5, 5))
plt.imshow(
    stimulation_pattern,
    cmap='gray',
    interpolation='nearest',
    vmin=0.0,
    vmax=MAX_STIMULATION,
)
plt.colorbar(label='Stimulation level')
plt.title("Electrical Stimulation Pattern for '9'")
plt.axis('off')
plt.show()