This commit is contained in:
unknown 2024-10-17 19:24:48 +03:00
parent 06d7bf8b5f
commit 96d23c1b59
3 changed files with 36 additions and 25 deletions

View File

@ -0,0 +1,26 @@
#include <avr/pgmspace.h>
#include "stimulation_pattern.h"
const int stimulationPin = 5;
const int delayBetweenStimulations = 10; // milliseconds
void setup() {
pinMode(stimulationPin, OUTPUT);
Serial.begin(9600);
}
void loop() {
for (uint16_t i = 0; i < PATTERN_SIZE; i++) {
uint8_t stimulationValue = pgm_read_byte(&stimulation_pattern[i]);
analogWrite(stimulationPin, stimulationValue);
// Print the current stimulation value for debugging
Serial.print("Applying stimulation: ");
Serial.println(stimulationValue);
delay(delayBetweenStimulations);
}
// Add a longer delay at the end of each complete pattern
delay(1000);
}

View File

@ -0,0 +1,2 @@
const PROGMEM uint8_t stimulation_pattern[] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 55, 148, 209, 253, 253, 113, 87, 148, 55, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 87, 232, 252, 253, 189, 209, 252, 252, 253, 168, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 57, 242, 252, 190, 65, 5, 12, 182, 252, 253, 116, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 96, 252, 252, 183, 14, 0, 0, 92, 252, 252, 225, 21, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 132, 253, 252, 145, 14, 0, 0, 0, 215, 252, 252, 79, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 126, 253, 246, 176, 9, 0, 0, 8, 78, 245, 253, 129, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16, 232, 252, 176, 0, 0, 0, 36, 201, 252, 252, 169, 11, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 22, 252, 252, 30, 22, 119, 197, 241, 253, 252, 250, 77, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16, 231, 252, 253, 252, 252, 252, 226, 227, 252, 231, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 55, 234, 253, 217, 137, 42, 24, 192, 252, 143, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 61, 255, 253, 109, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 71, 253, 252, 21, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 253, 252, 21, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 71, 253, 252, 21, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 106, 253, 252, 21, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 45, 255, 253, 21, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 218, 252, 56, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 96, 252, 189, 42, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 14, 184, 252, 169, 11, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 14, 146, 252, 42, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
const uint16_t PATTERN_SIZE = 784;

View File

@ -1,47 +1,30 @@
import gzip import gzip
import numpy as np import numpy as np
import matplotlib.pyplot as plt
# Function to read the images file (3D)
def read_images(filename): def read_images(filename):
with gzip.open(filename, 'rb') as f: with gzip.open(filename, 'rb') as f:
# Skip the first 16 bytes (magic number, number of images, height, and width)
f.read(16) f.read(16)
# Read the images (remaining data)
return np.frombuffer(f.read(), dtype=np.uint8).reshape(-1, 28, 28) return np.frombuffer(f.read(), dtype=np.uint8).reshape(-1, 28, 28)
# Function to convert a pixel value to an electrical stimulation level
def pixel_to_stimulation(pixel_value, max_stimulation=5.0): def pixel_to_stimulation(pixel_value, max_stimulation=5.0):
return pixel_value / 255.0 * max_stimulation return int(pixel_value / 255.0 * max_stimulation * 51) # Convert to 0-255 range
# Function to read the labels file (1D)
def read_labels(filename): def read_labels(filename):
with gzip.open(filename, 'rb') as f: with gzip.open(filename, 'rb') as f:
# Skip the first 8 bytes (magic number and label count)
f.read(8) f.read(8)
# Read the labels (remaining data)
return np.frombuffer(f.read(), dtype=np.uint8) return np.frombuffer(f.read(), dtype=np.uint8)
# Load train images and labels
train_images = read_images('dataset/train-images-idx3-ubyte.gz') train_images = read_images('dataset/train-images-idx3-ubyte.gz')
train_labels = read_labels('dataset/train-labels-idx1-ubyte.gz') train_labels = read_labels('dataset/train-labels-idx1-ubyte.gz')
# Find all indices of the digit '9' in the training labels
indices_of_nines = np.where(train_labels == 9)[0] indices_of_nines = np.where(train_labels == 9)[0]
# Get the first '9' image for demonstration
image_of_nine = train_images[indices_of_nines[0]] image_of_nine = train_images[indices_of_nines[0]]
# Normalize the pixel values to [0, 1] range stimulation_pattern = np.vectorize(pixel_to_stimulation)(image_of_nine)
normalized_image = image_of_nine / 255.0
# Convert the normalized image to an electrical stimulation pattern # Output the stimulation pattern as a compact byte array
stimulation_pattern = np.vectorize(pixel_to_stimulation)(normalized_image) with open('stimulation_pattern.h', 'w') as f:
f.write("const PROGMEM uint8_t stimulation_pattern[] = {")
# Visualize the stimulation pattern f.write(", ".join(map(str, stimulation_pattern.flatten())))
plt.figure(figsize=(5, 5)) f.write("};\n")
plt.imshow(stimulation_pattern, cmap='gray', interpolation='nearest') f.write(f"const uint16_t PATTERN_SIZE = {stimulation_pattern.size};\n")
plt.title("Electrical Stimulation Pattern for '9'")
plt.axis('off')
plt.show()