# 24 lines, 765 B, Python
import gzip
import struct

import numpy as np


# Magic number of an IDX1 file holding unsigned bytes: 0x00000801
# (type code 0x08 = unsigned byte, 0x01 = one dimension).
_IDX1_UBYTE_MAGIC = 0x00000801


def read_labels(filename):
    """Read a gzip-compressed IDX1 label file (e.g. MNIST) into a 1-D array.

    After decompression the file starts with an 8-byte big-endian header:
    a 4-byte magic number followed by a 4-byte label count. One unsigned
    byte per label follows the header.

    Args:
        filename: Path to the gzip-compressed label file.

    Returns:
        A 1-D ``numpy.ndarray`` of dtype ``uint8`` with one entry per label.

    Raises:
        ValueError: If the header is shorter than 8 bytes, the magic number
            is not the IDX1 unsigned-byte magic (2049), or the number of
            label bytes does not match the count declared in the header.
    """
    with gzip.open(filename, 'rb') as f:
        header = f.read(8)
        payload = f.read()

    if len(header) < 8:
        raise ValueError(
            f"{filename!r}: truncated IDX header ({len(header)} of 8 bytes)"
        )
    # Both header fields are big-endian unsigned 32-bit integers.
    magic, count = struct.unpack('>II', header)
    # Reject other IDX files (e.g. image files) instead of silently
    # reinterpreting their bytes as labels.
    if magic != _IDX1_UBYTE_MAGIC:
        raise ValueError(
            f"{filename!r}: bad magic number {magic} "
            f"(expected {_IDX1_UBYTE_MAGIC} for an IDX1 label file)"
        )
    # A mismatch means the file is truncated or has trailing data.
    if len(payload) != count:
        raise ValueError(
            f"{filename!r}: header declares {count} labels "
            f"but {len(payload)} bytes follow"
        )
    return np.frombuffer(payload, dtype=np.uint8)


# Load the training-set labels from the gzip-compressed IDX1 file.
train_labels = read_labels('dataset/train-labels-idx1-ubyte.gz')

# 0-based positions of every label equal to the digit 9.
# np.flatnonzero is the 1-D shorthand for np.where(mask)[0].
indices_of_nines = np.flatnonzero(train_labels == 9)

# How many '9' labels the training set contains.
num_nines = indices_of_nines.size

# Report the total and the first ten positions.
print(f"Number of '9's in the training set: {num_nines}")
print(f"First 10 indices of '9': {indices_of_nines[:10]}")