add MNIST dataset, now without tensorflow dep
This commit is contained in:
parent
fdd9772173
commit
06d7bf8b5f
BIN
dataset/t10k-images-idx3-ubyte.gz
Normal file
BIN
dataset/t10k-images-idx3-ubyte.gz
Normal file
Binary file not shown.
BIN
dataset/t10k-labels-idx1-ubyte.gz
Normal file
BIN
dataset/t10k-labels-idx1-ubyte.gz
Normal file
Binary file not shown.
BIN
dataset/train-images-idx3-ubyte.gz
Normal file
BIN
dataset/train-images-idx3-ubyte.gz
Normal file
Binary file not shown.
BIN
dataset/train-labels-idx1-ubyte.gz
Normal file
BIN
dataset/train-labels-idx1-ubyte.gz
Normal file
Binary file not shown.
|
@ -1,40 +1,47 @@
|
|||
import gzip
|
||||
import numpy as np
|
||||
from tensorflow.keras.datasets import mnist
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
# Load the MNIST dataset
|
||||
(x_train, y_train), (x_test, y_test) = mnist.load_data()
|
||||
# Function to read the images file (3D)
|
||||
def read_images(filename):
|
||||
with gzip.open(filename, 'rb') as f:
|
||||
# Skip the first 16 bytes (magic number, number of images, height, and width)
|
||||
f.read(16)
|
||||
# Read the images (remaining data)
|
||||
return np.frombuffer(f.read(), dtype=np.uint8).reshape(-1, 28, 28)
|
||||
|
||||
# Find the first '9' in the dataset
|
||||
index_of_nine = np.where(y_train == 9)[0][0]
|
||||
# Function to convert a pixel value to an electrical stimulation level
|
||||
def pixel_to_stimulation(pixel_value, max_stimulation=5.0):
|
||||
return pixel_value / 255.0 * max_stimulation
|
||||
|
||||
# Get the 28x28 pixel image of the number 9
|
||||
image_of_nine = x_train[index_of_nine]
|
||||
# Function to read the labels file (1D)
|
||||
def read_labels(filename):
|
||||
with gzip.open(filename, 'rb') as f:
|
||||
# Skip the first 8 bytes (magic number and label count)
|
||||
f.read(8)
|
||||
# Read the labels (remaining data)
|
||||
return np.frombuffer(f.read(), dtype=np.uint8)
|
||||
|
||||
|
||||
# Load train images and labels
|
||||
train_images = read_images('dataset/train-images-idx3-ubyte.gz')
|
||||
train_labels = read_labels('dataset/train-labels-idx1-ubyte.gz')
|
||||
|
||||
# Find all indices of the digit '9' in the training labels
|
||||
indices_of_nines = np.where(train_labels == 9)[0]
|
||||
|
||||
# Get the first '9' image for demonstration
|
||||
image_of_nine = train_images[indices_of_nines[0]]
|
||||
|
||||
# Normalize the pixel values to [0, 1] range
|
||||
normalized_image = image_of_nine / 255.0
|
||||
|
||||
# Function to convert a pixel value to an electrical stimulation level
|
||||
def pixel_to_stimulation(pixel_value, max_stimulation=5.0):
|
||||
"""
|
||||
Converts a normalized pixel value (0.0 - 1.0) to an electrical stimulation (e.g., voltage).
|
||||
Assumes max_stimulation is the maximum output voltage/current, e.g., 5V.
|
||||
"""
|
||||
return pixel_value * max_stimulation
|
||||
|
||||
# Apply the conversion to the entire image
|
||||
# Convert the normalized image to an electrical stimulation pattern
|
||||
stimulation_pattern = np.vectorize(pixel_to_stimulation)(normalized_image)
|
||||
|
||||
# Now `stimulation_pattern` is a 28x28 array of electrical stimulation levels (e.g., voltages)
|
||||
print("Electrical stimulation pattern for '9':")
|
||||
print(stimulation_pattern)
|
||||
|
||||
|
||||
# Create a plot to visualize the stimulation pattern
|
||||
# Visualize the stimulation pattern
|
||||
plt.figure(figsize=(5, 5))
|
||||
plt.imshow(stimulation_pattern, cmap='gray', interpolation='nearest')
|
||||
plt.title("Electrical Stimulation Pattern for '9'")
|
||||
plt.axis('off') # Hide the axis for a cleaner look
|
||||
plt.axis('off')
|
||||
plt.show()
|
||||
|
||||
# Optionally, you could output this to an external device or system that applies electrical stimulation.
|
||||
|
|
|
@ -1,15 +1,23 @@
|
|||
# Load the MNIST dataset
|
||||
from tensorflow.keras.datasets import mnist
|
||||
import gzip
|
||||
import numpy as np
|
||||
|
||||
# Function to read the labels file (1D)
|
||||
def read_labels(filename):
|
||||
with gzip.open(filename, 'rb') as f:
|
||||
# Skip the first 8 bytes (magic number and label count)
|
||||
f.read(8)
|
||||
# Read the labels (remaining data)
|
||||
return np.frombuffer(f.read(), dtype=np.uint8)
|
||||
|
||||
(x_train, y_train), (x_test, y_test) = mnist.load_data()
|
||||
# Load train labels
|
||||
train_labels = read_labels('dataset/train-labels-idx1-ubyte.gz')
|
||||
|
||||
# Find all indices of the digit '9' in the training data
|
||||
indices_of_nines = np.where(y_train == 9)[0]
|
||||
# Find all indices of the digit '9' in the training labels
|
||||
indices_of_nines = np.where(train_labels == 9)[0]
|
||||
|
||||
# Count how many '9's are there in the training set
|
||||
num_nines = len(indices_of_nines)
|
||||
|
||||
# Display the count
|
||||
num_nines
|
||||
# Output the count and the first 10 indices
|
||||
print(f"Number of '9's in the training set: {num_nines}")
|
||||
print(f"First 10 indices of '9': {indices_of_nines[:10]}")
|
||||
|
|
Loading…
Reference in New Issue
Block a user