-
-
Save ColeMurray/bf25aad332b3074c2776f4b0112f3947 to your computer and use it in GitHub Desktop.
Load the MNIST train, validation, and test sets from CSV. Inputs are reshaped into N x 28 x 28 x 1 arrays and labels are one-hot encoded.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np

# Side length, in pixels, of a square MNIST image; each CSV row carries
# IMAGE_SIZE * IMAGE_SIZE pixel values after the label column.
IMAGE_SIZE = 28
def load_train_data(data_path, validation_size=500):
    """
    Load the MNIST training CSV and split off a validation set.

    Each row in the CSV is formatted (label, input): column 0 is the digit
    label and the remaining IMAGE_SIZE * IMAGE_SIZE columns are the flattened
    pixel values. The last ``validation_size`` rows become the validation set
    and the preceding rows the training set; row order is preserved (no
    shuffling).

    :param data_path: path (or any source accepted by ``np.genfromtxt``) of
        the comma-separated file.
    :param validation_size: number of trailing rows to hold out for
        validation. Must satisfy ``0 <= validation_size <= number of rows``.
    :return: ``(x_train, x_val, y_train, y_val)`` -- the ``x`` arrays are 4D
        float32 tensors of shape ``(N, IMAGE_SIZE, IMAGE_SIZE, 1)``; the ``y``
        arrays are 2D float32 one-hot label matrices of shape ``(N, 10)``.
    :raises ValueError: if ``validation_size`` is out of range, or if a row
        does not hold IMAGE_SIZE * IMAGE_SIZE pixel values (from ``reshape``).
    """
    train_data = np.genfromtxt(data_path, delimiter=',', dtype=np.float32)
    # genfromtxt returns a 1D array for a single-row file; force 2D so the
    # column slicing below behaves the same for any row count.
    train_data = np.atleast_2d(train_data)

    num_rows = len(train_data)
    # Out-of-range sizes would turn into negative slice indices below and
    # silently produce wrongly sized splits, so reject them up front.
    if not 0 <= validation_size <= num_rows:
        raise ValueError(
            "validation_size must be between 0 and %d, got %r"
            % (num_rows, validation_size)
        )

    # Reformat the flat pixel columns into (N, H, W, 1) images.
    images = train_data[:, 1:].reshape(num_rows, IMAGE_SIZE, IMAGE_SIZE, 1)
    # One-hot encode via broadcasting: (10,) == (N, 1) -> (N, 10) booleans.
    labels = (np.arange(10) == train_data[:, 0, None]).astype(np.float32)

    # Hold out the trailing rows as the validation set.
    split = num_rows - validation_size
    x_train, x_val = images[:split], images[split:]
    y_train, y_val = labels[:split], labels[split:]
    return x_train, x_val, y_train, y_val
def load_test_data(data_path):
    """
    Load the MNIST test CSV.

    Each row in the CSV is formatted (label, input): column 0 is the digit
    label and the remaining IMAGE_SIZE * IMAGE_SIZE columns are the flattened
    pixel values.

    :param data_path: path (or any source accepted by ``np.genfromtxt``) of
        the comma-separated file.
    :return: ``(x_test, y_test)`` -- ``x_test`` is a 4D float32 tensor of
        shape ``(N, IMAGE_SIZE, IMAGE_SIZE, 1)`` and ``y_test`` is a 2D float32
        one-hot label matrix of shape ``(N, 10)``.
    :raises ValueError: if a row does not hold IMAGE_SIZE * IMAGE_SIZE pixel
        values (from ``reshape``).
    """
    test_data = np.genfromtxt(data_path, delimiter=',', dtype=np.float32)
    # genfromtxt returns a 1D array for a single-row file; force 2D so the
    # column slicing below behaves the same for any row count.
    test_data = np.atleast_2d(test_data)

    # Reformat the flat pixel columns into (N, H, W, 1) images.
    x_test = test_data[:, 1:].reshape(len(test_data), IMAGE_SIZE, IMAGE_SIZE, 1)
    # One-hot encode via broadcasting: (10,) == (N, 1) -> (N, 10) booleans.
    # The comparison already yields a new array, so no defensive copy is needed.
    y_test = (np.arange(10) == test_data[:, 0, None]).astype(np.float32)
    return x_test, y_test
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment