103 lines
3.4 KiB
Python
103 lines
3.4 KiB
Python
import serial
|
|
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
from serial.tools import list_ports
|
|
import time
|
|
import sys
|
|
|
|
# --- Serial link settings ---
DEFAULT_PORT = "COM20"  # port name tried first
BAUD_RATE = 115_200
TIMEOUT = 5  # seconds

# Side length, in pixels, of one square MNIST digit image
DIGIT_SIZE = 28
|
|
|
|
def find_arduino_port():
    """Return the device name of the first port described as an Arduino.

    The ports reported by ``list_ports.comports()`` are inspected in order,
    and the first one whose description contains the text ``'Arduino'`` wins.

    Returns:
        The matching port's ``device`` attribute, or ``None`` when no port
        description contains ``'Arduino'``.
    """
    arduino_devices = (
        candidate.device
        for candidate in list_ports.comports()
        if 'Arduino' in candidate.description
    )
    # next() with a default gives "first match or None" without an explicit loop.
    return next(arduino_devices, None)
|
|
|
|
def read_serial_data(port, baud_rate, timeout):
    """Read one digit image's worth of pixel values from a serial port.

    The port is opened with the given settings and read line by line. Each
    non-empty line is expected to hold one ASCII integer; values in the range
    0-255 are collected until ``DIGIT_SIZE * DIGIT_SIZE`` of them have arrived.
    Lines that cannot be decoded or parsed, and values outside 0-255, are
    reported and skipped.

    Args:
        port: Name of the serial device to open (passed to ``serial.Serial``).
        baud_rate: Baud rate for the connection.
        timeout: Seconds to wait, used both as the ``readline`` timeout and as
            the longest gap allowed since the last accepted value before the
            read is abandoned.

    Returns:
        A NumPy array of the accepted values in arrival order. It may hold
        fewer than ``DIGIT_SIZE * DIGIT_SIZE`` entries if the read timed out,
        and it is empty if a ``serial.SerialException`` was raised.
    """
    expected_count = DIGIT_SIZE * DIGIT_SIZE
    try:
        with serial.Serial(port, baud_rate, timeout=timeout) as ser:
            print(f"Connected to {port}")
            data = []
            # A monotonic clock is used for the elapsed-time check so that
            # system clock adjustments cannot trigger or mask a timeout.
            last_accepted = time.monotonic()
            while len(data) < expected_count:
                if time.monotonic() - last_accepted > timeout:
                    # Report the timeout actually in effect, not a fixed "5".
                    print(f"Timeout: No data received for {timeout} seconds.")
                    break

                raw = ser.readline()
                try:
                    line = raw.decode('ascii').strip()
                    # Blank lines (including read timeouts) carry no value.
                    value = int(line) if line else None
                except (UnicodeDecodeError, ValueError) as e:
                    print(f"Received invalid data: {e}. Skipping.")
                    continue

                if value is None:
                    continue
                if 0 <= value <= 255:  # Ensure value is in valid range
                    data.append(value)
                    last_accepted = time.monotonic()  # Reset the timeout
                else:
                    print(f"Warning: Received out-of-range value: {value}. Skipping.")
            return np.array(data)
    except serial.SerialException as e:
        print(f"Error opening serial port {port}: {e}")
        return np.array([])
|
|
|
|
def plot_digit(data):
    """Show pixel values as a DIGIT_SIZE x DIGIT_SIZE grayscale image.

    The input is flattened, then padded with zeros or truncated so that it
    holds exactly ``DIGIT_SIZE * DIGIT_SIZE`` values. A warning is printed in
    either case. The image is drawn with matplotlib, each pixel's value is
    written over it in red, and the figure is displayed with ``plt.show()``.

    Args:
        data: Sequence or array of pixel values. Any shape is accepted; the
            values are used in flattened (row-major) order.
    """
    expected_count = DIGIT_SIZE * DIGIT_SIZE

    # Flatten up front so lists and already-2-D arrays work, and so the size
    # checks below count individual values.
    data = np.asarray(data).ravel()

    if data.size < expected_count:
        print(f"Warning: Received only {data.size} data points. Padding with zeros.")
        data = np.pad(data, (0, expected_count - data.size), 'constant')
    elif data.size > expected_count:
        # Too much data would make reshape() raise; mirror the padding case
        # by warning and keeping only the first full image.
        print(f"Warning: Received {data.size} data points. "
              f"Ignoring the last {data.size - expected_count}.")
        data = data[:expected_count]

    digit = data.reshape(DIGIT_SIZE, DIGIT_SIZE)

    # Create a larger figure (doubled size)
    plt.figure(figsize=(10, 10))
    plt.imshow(digit, cmap='gray', interpolation='nearest')
    plt.title("Decoded Digit", fontsize=24)
    plt.axis('off')

    # Add text annotations for each pixel value (x is the column, y the row)
    for i in range(DIGIT_SIZE):
        for j in range(DIGIT_SIZE):
            value = digit[i, j]
            plt.text(j, i, f'{value}', ha='center', va='center', color='red', fontsize=6)

    plt.tight_layout()
    plt.show()
|
|
|
|
def main():
    """Read one digit from a serial port and plot it.

    ``DEFAULT_PORT`` is used when it is among the available serial ports;
    otherwise ``find_arduino_port()`` is asked for a port. The function prints
    a message and returns early when no serial ports exist or no usable port
    is found. Data is then read with ``read_serial_data`` and, if any values
    arrived, displayed with ``plot_digit``.
    """
    # Enumerate the ports once so both checks below see the same snapshot.
    available_ports = list(list_ports.comports())
    if not available_ports:
        print("No serial ports found. Please check your connections.")
        return

    port = DEFAULT_PORT
    if DEFAULT_PORT not in {p.device for p in available_ports}:
        print(f"Warning: {DEFAULT_PORT} not found. Attempting to find Arduino port.")
        port = find_arduino_port()
        if not port:
            print("Arduino not found. Please specify the COM port manually.")
            return

    print("Reading data from serial port...")
    data = read_serial_data(port, BAUD_RATE, TIMEOUT)

    # len() rather than truthiness: ``data`` is a NumPy array.
    if len(data) > 0:
        print(f"Received {len(data)} data points. Plotting digit...")
        plot_digit(data)
    else:
        print("No data received. Please check the Arduino connection and code.")
|
|
|
|
if __name__ == "__main__":
    # Exit status reported to the caller: 0 normally (including a user
    # interrupt, as before), 1 when an unexpected error occurred.
    exit_code = 0
    try:
        main()
    except KeyboardInterrupt:
        print("\nScript terminated by user.")
    except Exception as e:
        print(f"An unexpected error occurred: {e}")
        exit_code = 1
    finally:
        print("Script execution completed.")
    # Exit outside ``finally`` so an exception still propagating (for example
    # a SystemExit raised inside main) is not replaced by a success status.
    sys.exit(exit_code)
|