add MNIST dataset, now without tensorflow dep

This commit is contained in:
unknown 2024-10-17 18:30:12 +03:00
parent fdd9772173
commit 06d7bf8b5f
6 changed files with 47 additions and 32 deletions

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@ -1,40 +1,47 @@
import gzip
import numpy as np import numpy as np
from tensorflow.keras.datasets import mnist
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
# Load the MNIST dataset # Function to read the images file (3D)
(x_train, y_train), (x_test, y_test) = mnist.load_data() def read_images(filename):
with gzip.open(filename, 'rb') as f:
# Skip the first 16 bytes (magic number, number of images, height, and width)
f.read(16)
# Read the images (remaining data)
return np.frombuffer(f.read(), dtype=np.uint8).reshape(-1, 28, 28)
# Find the first '9' in the dataset # Function to convert a pixel value to an electrical stimulation level
index_of_nine = np.where(y_train == 9)[0][0] def pixel_to_stimulation(pixel_value, max_stimulation=5.0):
return pixel_value / 255.0 * max_stimulation
# Get the 28x28 pixel image of the number 9 # Function to read the labels file (1D)
image_of_nine = x_train[index_of_nine] def read_labels(filename):
with gzip.open(filename, 'rb') as f:
# Skip the first 8 bytes (magic number and label count)
f.read(8)
# Read the labels (remaining data)
return np.frombuffer(f.read(), dtype=np.uint8)
# Load train images and labels
train_images = read_images('dataset/train-images-idx3-ubyte.gz')
train_labels = read_labels('dataset/train-labels-idx1-ubyte.gz')
# Find all indices of the digit '9' in the training labels
indices_of_nines = np.where(train_labels == 9)[0]
# Get the first '9' image for demonstration
image_of_nine = train_images[indices_of_nines[0]]
# Normalize the pixel values to [0, 1] range # Normalize the pixel values to [0, 1] range
normalized_image = image_of_nine / 255.0 normalized_image = image_of_nine / 255.0
# Function to convert a pixel value to an electrical stimulation level # Convert the normalized image to an electrical stimulation pattern
def pixel_to_stimulation(pixel_value, max_stimulation=5.0):
"""
Converts a normalized pixel value (0.0 - 1.0) to an electrical stimulation (e.g., voltage).
Assumes max_stimulation is the maximum output voltage/current, e.g., 5V.
"""
return pixel_value * max_stimulation
# Apply the conversion to the entire image
stimulation_pattern = np.vectorize(pixel_to_stimulation)(normalized_image) stimulation_pattern = np.vectorize(pixel_to_stimulation)(normalized_image)
# Now `stimulation_pattern` is a 28x28 array of electrical stimulation levels (e.g., voltages) # Visualize the stimulation pattern
print("Electrical stimulation pattern for '9':")
print(stimulation_pattern)
# Create a plot to visualize the stimulation pattern
plt.figure(figsize=(5, 5)) plt.figure(figsize=(5, 5))
plt.imshow(stimulation_pattern, cmap='gray', interpolation='nearest') plt.imshow(stimulation_pattern, cmap='gray', interpolation='nearest')
plt.title("Electrical Stimulation Pattern for '9'") plt.title("Electrical Stimulation Pattern for '9'")
plt.axis('off') # Hide the axis for a cleaner look plt.axis('off')
plt.show() plt.show()
# Optionally, you could output this to an external device or system that applies electrical stimulation.

View File

@ -1,15 +1,23 @@
# Load the MNIST dataset import gzip
from tensorflow.keras.datasets import mnist
import numpy as np import numpy as np
# Function to read the labels file (1D)
def read_labels(filename):
with gzip.open(filename, 'rb') as f:
# Skip the first 8 bytes (magic number and label count)
f.read(8)
# Read the labels (remaining data)
return np.frombuffer(f.read(), dtype=np.uint8)
(x_train, y_train), (x_test, y_test) = mnist.load_data() # Load train labels
train_labels = read_labels('dataset/train-labels-idx1-ubyte.gz')
# Find all indices of the digit '9' in the training data # Find all indices of the digit '9' in the training labels
indices_of_nines = np.where(y_train == 9)[0] indices_of_nines = np.where(train_labels == 9)[0]
# Count how many '9's are there in the training set # Count how many '9's are there in the training set
num_nines = len(indices_of_nines) num_nines = len(indices_of_nines)
# Display the count # Output the count and the first 10 indices
num_nines print(f"Number of '9's in the training set: {num_nines}")
print(f"First 10 indices of '9': {indices_of_nines[:10]}")