You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
391 lines
12 KiB
391 lines
12 KiB
{ |
|
"cells": [ |
|
{ |
|
"cell_type": "code", |
|
"execution_count": null, |
|
"metadata": {}, |
|
"outputs": [], |
|
"source": [ |
|
"#https://www.tensorflow.org/tutorials/keras/classification\n", |
|
"\n", |
|
"from __future__ import absolute_import, division, print_function, unicode_literals\n", |
|
"\n", |
|
"# TensorFlow and tf.keras\n", |
|
"import tensorflow as tf\n", |
|
"from tensorflow import keras\n", |
|
"from tensorflow.keras import datasets, layers, models\n", |
|
"\n", |
|
"# Helper libraries\n", |
|
"import numpy as np\n", |
|
"import matplotlib.pyplot as plt\n", |
|
"\n", |
|
"import json" |
|
] |
|
}, |
|
{ |
|
"cell_type": "code", |
|
"execution_count": null, |
|
"metadata": {}, |
|
"outputs": [], |
|
"source": [ |
|
"# Input/output locations used by the rest of the notebook.\n", |
|
"# NOTE(review): these are absolute paths on one machine - adjust before running elsewhere.\n", |
|
"# DATA_FILE is read with json.load and must provide the \"imgs\" and \"labels\" keys.\n", |
|
"#DATA_FILE=\"/home/johanv/downloads/labelled.json\"\n", |
|
"DATA_FILE=\"/home/johanv/nextcloud/projects/boggle2.0/labelled.json\"\n", |
|
"#DATA_FILE=\"/home/johanv/nextcloud/projects/boggle2.0/labelled-20200124_155523.mp4.json\"\n", |
|
"\n", |
|
"# Disabled alternative: fetch the data set with curl instead of using a local copy.\n", |
|
"#import os\n", |
|
"#DATA_FILE = \"/labelled.json\"\n", |
|
"#DATA_URL = \"https://drive.confuzer.cloud/index.php/s/2fQ5KkGi3Bi2SLq/download\"\n", |
|
"#os.system(\"curl \" + DATA_URL + \" > \" + DATA_FILE)\n", |
|
"\n", |
|
"# File the trained model is written to by model.save and read back from by load_model.\n", |
|
"MODEL_SAVE_FILE=\"/home/johanv/nextcloud/projects/boggle2.0/model.h5\"\n" |
|
] |
|
}, |
|
{ |
|
"cell_type": "code", |
|
"execution_count": null, |
|
"metadata": {}, |
|
"outputs": [], |
|
"source": [ |
|
"print(tf.__version__)\n", |
|
"\n", |
|
"#fashion_mnist = keras.datasets.fashion_mnist\n", |
|
"#(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()\n", |
|
"#class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',\n", |
|
"#               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']\n", |
|
"\n", |
|
"# Load the labelled data set: a JSON object with \"imgs\" and \"labels\" entries.\n", |
|
"with open(DATA_FILE, 'r') as f:\n", |
|
"    data = json.load(f)\n", |
|
"\n", |
|
"# Side length of one (square) letter image; used for every reshape below.\n", |
|
"IMG_DIM = 30\n", |
|
"\n", |
|
"# Assumed to be parallel lists (images_in[k] is labelled labels_in[k]) - TODO confirm.\n", |
|
"images_in = data[\"imgs\"]\n", |
|
"labels_in = data[\"labels\"]\n", |
|
"\n", |
|
"\n", |
|
"def rotate_90(img):\n", |
|
"    \"\"\"Return `img` (a sequence of rows) rotated 90 degrees clockwise, as a list of row tuples.\"\"\"\n", |
|
"    #https://artemrudenko.wordpress.com/2014/08/28/python-rotate-2d-arraymatrix-90-degrees-one-liner/\n", |
|
"    return list(zip(*img[::-1]))\n", |
|
"\n", |
|
"\n", |
|
"# Augment the data with all four orientations (a rotated tile is still the same letter).\n", |
|
"# One full copy of the data set is appended per rotation, in the order 90, 180, 270, 360\n", |
|
"# degrees, so the labels simply repeat. images_in is rebound instead of edited in place,\n", |
|
"# which leaves the lists inside `data` untouched.\n", |
|
"images = []\n", |
|
"labels = []\n", |
|
"for rot in range(4):\n", |
|
"    images_in = [rotate_90(img) for img in images_in]\n", |
|
"    images.extend(images_in)\n", |
|
"    labels.extend(labels_in)" |
|
] |
|
}, |
|
{ |
|
"cell_type": "code", |
|
"execution_count": null, |
|
"metadata": {}, |
|
"outputs": [], |
|
"source": [ |
|
"# Stack the augmented samples as (sample, row, column, channel), the layout Conv2D expects.\n", |
|
"# uint8 assumes pixel values in 0..255 - TODO confirm against the labelling tool.\n", |
|
"images = np.array(images, dtype=np.uint8).reshape((-1, IMG_DIM, IMG_DIM, 1))\n", |
|
"all_labels = np.array(labels, dtype=np.uint8)\n", |
|
"\n", |
|
"# The augmentation cell appends one full copy of the data set per rotation, so samples\n", |
|
"# k, k + n_source, k + 2*n_source, ... are rotations of the same source picture.\n", |
|
"# Hold out whole source pictures (with all their rotations): taking the first 15% of the\n", |
|
"# augmented list instead leaves rotated copies of every test picture in the training set,\n", |
|
"# which inflates val_accuracy.\n", |
|
"n_source = len(labels_in)\n", |
|
"assert n_source > 0 and len(images) % n_source == 0, \"augmented data is not a whole number of copies\"\n", |
|
"assert len(all_labels) == len(images), \"images and labels are out of step\"\n", |
|
"source_of_sample = np.arange(len(images)) % n_source\n", |
|
"is_test = source_of_sample < int(0.15 * n_source)\n", |
|
"\n", |
|
"train_images = images[~is_test]\n", |
|
"train_labels = all_labels[~is_test]\n", |
|
"\n", |
|
"test_images = images[is_test]\n", |
|
"test_labels = all_labels[is_test]\n", |
|
"\n", |
|
"# Label k stands for class_names[k].\n", |
|
"class_names = \"ABCDEFGHIJKLMNOPQRSTUVWXYZ\"\n", |
|
"\n", |
|
"# Number of classes; sizes the output layer and the bar charts.\n", |
|
"L = len(class_names)\n", |
|
"\n", |
|
"print(train_images.shape)\n", |
|
"print(train_labels.shape)\n", |
|
"print(test_images.shape)\n", |
|
"print(test_labels.shape)\n", |
|
"\n", |
|
"# plt.figure()\n", |
|
"# plt.imshow(train_images[1])\n", |
|
"# plt.colorbar()\n", |
|
"# plt.grid(False)\n", |
|
"# plt.show()\n", |
|
"\n", |
|
"# Scale pixel values to floats (0..1 if the input really spans 0..255).\n", |
|
"train_images = train_images / 255.0\n", |
|
"test_images = test_images / 255.0\n" |
|
] |
|
}, |
|
{ |
|
"cell_type": "code", |
|
"execution_count": null, |
|
"metadata": {}, |
|
"outputs": [], |
|
"source": [ |
|
"# Show the first 25 training samples in a 5x5 grid, each captioned with its letter.\n", |
|
"plt.figure(figsize=(10, 10))\n", |
|
"for i in range(25):\n", |
|
"    print(\"{}/25\".format(i))\n", |
|
"    plt.subplot(5, 5, i + 1)\n", |
|
"    # Bare tiles: no grid and no tick marks.\n", |
|
"    plt.grid(False)\n", |
|
"    plt.xticks([])\n", |
|
"    plt.yticks([])\n", |
|
"    tile = train_images[i].reshape((IMG_DIM, IMG_DIM))  # drop the channel axis for imshow\n", |
|
"    plt.imshow(tile, cmap=plt.cm.binary)\n", |
|
"    plt.xlabel(class_names[train_labels[i]])\n", |
|
"plt.show()\n" |
|
] |
|
}, |
|
{ |
|
"cell_type": "code", |
|
"execution_count": null, |
|
"metadata": {}, |
|
"outputs": [], |
|
"source": [ |
|
"#model = keras.Sequential([\n", |
|
"#    keras.layers.Flatten(input_shape=(IMG_DIM,IMG_DIM)),\n", |
|
"#    keras.layers.Dense(128, activation=\"sigmoid\"),\n", |
|
"#    keras.layers.Dense(64, activation=\"sigmoid\"),\n", |
|
"#    keras.layers.Dense(L, activation=\"softmax\")\n", |
|
"#])\n", |
|
"\n", |
|
"# Small CNN: three 3x3 convolutions (two followed by 2x2 max-pooling), then a dense\n", |
|
"# head that ends in a softmax over the L letter classes.\n", |
|
"model = models.Sequential([\n", |
|
"    layers.Conv2D(64, (3, 3), activation='relu', input_shape=(IMG_DIM, IMG_DIM, 1)),\n", |
|
"    layers.MaxPooling2D((2, 2)),\n", |
|
"    layers.Conv2D(32, (3, 3), activation='relu'),\n", |
|
"    layers.MaxPooling2D((2, 2)),\n", |
|
"    layers.Conv2D(16, (3, 3), activation='relu'),\n", |
|
"    layers.Flatten(),\n", |
|
"    layers.Dense(100, activation='relu'),\n", |
|
"    # layers.Dense(52, activation='relu'),\n", |
|
"    layers.Dense(L, activation=\"softmax\"),\n", |
|
"])\n", |
|
"\n", |
|
"# summary() prints the table itself and returns None, so it is not wrapped in print().\n", |
|
"model.summary()" |
|
] |
|
}, |
|
{ |
|
"cell_type": "code", |
|
"execution_count": null, |
|
"metadata": {}, |
|
"outputs": [], |
|
"source": [ |
|
"#model.compile(\n", |
|
"#    loss='sparse_categorical_crossentropy',\n", |
|
"#    optimizer='adam',\n", |
|
"#    metrics=['accuracy']\n", |
|
"#)\n", |
|
"#\n", |
|
"#model.fit(train_images, train_labels, epochs=7)\n", |
|
"\n", |
|
"# The model's last layer already applies softmax, so the loss receives probabilities,\n", |
|
"# not logits. from_logits must therefore be False; with True the loss would apply a\n", |
|
"# second softmax on top of the probabilities and train against the wrong objective.\n", |
|
"model.compile(optimizer='adam',\n", |
|
"              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),\n", |
|
"              metrics=['accuracy'])\n", |
|
"\n", |
|
"# The held-out test split doubles as validation data; `history` is plotted in a later cell.\n", |
|
"history = model.fit(train_images, train_labels, epochs=8,\n", |
|
"                    validation_data=(test_images, test_labels))\n" |
|
] |
|
}, |
|
{ |
|
"cell_type": "code", |
|
"execution_count": null, |
|
"metadata": {}, |
|
"outputs": [], |
|
"source": [ |
|
"# Persist the trained model to MODEL_SAVE_FILE.\n", |
|
"model.save(MODEL_SAVE_FILE)" |
|
] |
|
}, |
|
{ |
|
"cell_type": "code", |
|
"execution_count": null, |
|
"metadata": {}, |
|
"outputs": [], |
|
"source": [ |
|
"# Read the model back from MODEL_SAVE_FILE; from here on `model` is the reloaded copy.\n", |
|
"model = tf.keras.models.load_model(MODEL_SAVE_FILE)" |
|
] |
|
}, |
|
{ |
|
"cell_type": "code", |
|
"execution_count": null, |
|
"metadata": {}, |
|
"outputs": [], |
|
"source": [ |
|
"print(history)\n", |
|
"\n", |
|
"# Training curves: accuracy on the training data and on the validation data, per epoch.\n", |
|
"for metric in ('accuracy', 'val_accuracy'):\n", |
|
"    plt.plot(history.history[metric], label=metric)\n", |
|
"plt.xlabel('Epoch')\n", |
|
"plt.ylabel('Accuracy')\n", |
|
"plt.ylim([0.5, 1])\n", |
|
"plt.legend(loc='lower right')\n", |
|
"\n", |
|
"# Score the model on the test split.\n", |
|
"test_loss, test_acc = model.evaluate(test_images, test_labels, verbose=2)\n", |
|
"\n", |
|
"print(\"\\n\\nacc: \", test_acc)\n", |
|
"\n", |
|
"# Per-class scores for every test sample; the plotting cells below index into this.\n", |
|
"predictions = model.predict(test_images)\n", |
|
"\n", |
|
"# Sanity check on the first test sample: raw scores, arg-max class, true class.\n", |
|
"print(\"predictions[0]: \", predictions[0])\n", |
|
"print(\"np.argmax(predictions[0]): \", np.argmax(predictions[0]))\n", |
|
"print(\"test_labels[0]: \", test_labels[0])" |
|
] |
|
}, |
|
{ |
|
"cell_type": "code", |
|
"execution_count": null, |
|
"metadata": {}, |
|
"outputs": [], |
|
"source": [ |
|
"def get_label(predictions_array):\n", |
|
"    \"\"\"Pick the class index for one prediction vector, with low-confidence overrides.\n", |
|
"\n", |
|
"    The arg-max class is returned unless its letter is listed in `overrides` and its\n", |
|
"    score, as a percentage, is below that letter's threshold; in that case the index\n", |
|
"    of the replacement letter is returned instead.\n", |
|
"    \"\"\"\n", |
|
"    # predicted letter -> (minimum percentage needed to trust it, replacement letter)\n", |
|
"    # NOTE(review): the thresholds look hand-tuned for one trained model - confirm.\n", |
|
"    overrides = {\n", |
|
"        'E': (99.4, 'Q'),\n", |
|
"        'T': (96, 'L'),\n", |
|
"        'U': (67, 'L'),\n", |
|
"        'R': (85, 'U'),\n", |
|
"    }\n", |
|
"    top_index = np.argmax(predictions_array)\n", |
|
"    rule = overrides.get(class_names[top_index])\n", |
|
"    if rule is not None and 100*np.max(predictions_array) < rule[0]:\n", |
|
"        return class_names.index(rule[1])\n", |
|
"    return top_index\n", |
|
"\n", |
|
"def plot_image(i, predictions_array, true_label, img):\n", |
|
"    \"\"\"Draw sample `i` with a caption comparing the predicted and the true letter.\n", |
|
"\n", |
|
"    predictions_array -- class scores for this one sample\n", |
|
"    true_label        -- labels of all samples; entry `i` is used\n", |
|
"    img               -- images of all samples; entry `i` is used\n", |
|
"\n", |
|
"    Caption format: predicted letter, top score as a percentage, true letter in\n", |
|
"    parentheses. The caption is blue when the prediction matches, red otherwise.\n", |
|
"    \"\"\"\n", |
|
"    actual = true_label[i]\n", |
|
"    picture = img[i]\n", |
|
"\n", |
|
"    # Bare tile: no grid and no tick marks.\n", |
|
"    plt.grid(False)\n", |
|
"    plt.xticks([])\n", |
|
"    plt.yticks([])\n", |
|
"    plt.imshow(picture.reshape((IMG_DIM, IMG_DIM)), cmap=plt.cm.binary)\n", |
|
"\n", |
|
"    guess = get_label(predictions_array)\n", |
|
"    color = 'blue' if guess == actual else 'red'\n", |
|
"\n", |
|
"    # The percentage is always the top raw score, even when get_label overrode the letter.\n", |
|
"    caption = \"{} {:2.2f}% ({})\".format(class_names[guess],\n", |
|
"                                         100*np.max(predictions_array),\n", |
|
"                                         class_names[actual])\n", |
|
"    plt.xlabel(caption, color=color)\n", |
|
"\n", |
|
"def plot_value_array(i, predictions_array, true_label):\n", |
|
"    \"\"\"Draw the class scores of one sample as a bar chart.\n", |
|
"\n", |
|
"    predictions_array -- class scores for this one sample (one bar per class)\n", |
|
"    true_label        -- labels of all samples; entry `i` is used\n", |
|
"\n", |
|
"    The predicted class (after get_label's overrides) is coloured red, then the true\n", |
|
"    class blue, so a correct prediction ends up as a single blue bar.\n", |
|
"    \"\"\"\n", |
|
"    actual = true_label[i]\n", |
|
"\n", |
|
"    plt.grid(False)\n", |
|
"    plt.xticks(range(L))\n", |
|
"    plt.yticks([])\n", |
|
"    bars = plt.bar(range(L), predictions_array, color=\"#777777\")\n", |
|
"    plt.ylim([0, 1])\n", |
|
"\n", |
|
"    guess = get_label(predictions_array)\n", |
|
"    bars[guess].set_color('red')\n", |
|
"    bars[actual].set_color('blue')" |
|
] |
|
}, |
|
{ |
|
"cell_type": "code", |
|
"execution_count": null, |
|
"metadata": {}, |
|
"outputs": [], |
|
"source": [ |
|
"# Inspect one test sample: the picture on the left, its class scores on the right.\n", |
|
"i = 0\n", |
|
"plt.figure(figsize=(6, 3))\n", |
|
"\n", |
|
"plt.subplot(121)\n", |
|
"plot_image(i, predictions[i], test_labels, test_images)\n", |
|
"\n", |
|
"plt.subplot(122)\n", |
|
"plot_value_array(i, predictions[i], test_labels)\n", |
|
"\n", |
|
"plt.show()" |
|
] |
|
}, |
|
{ |
|
"cell_type": "code", |
|
"execution_count": null, |
|
"metadata": {}, |
|
"outputs": [], |
|
"source": [ |
|
"# Same side-by-side view (picture left, class scores right) for another test sample.\n", |
|
"i = 12\n", |
|
"plt.figure(figsize=(6, 3))\n", |
|
"\n", |
|
"plt.subplot(121)\n", |
|
"plot_image(i, predictions[i], test_labels, test_images)\n", |
|
"\n", |
|
"plt.subplot(122)\n", |
|
"plot_value_array(i, predictions[i], test_labels)\n", |
|
"\n", |
|
"plt.show()" |
|
] |
|
}, |
|
{ |
|
"cell_type": "code", |
|
"execution_count": null, |
|
"metadata": {}, |
|
"outputs": [], |
|
"source": [ |
|
"# Plot the first X test images, their predicted labels, and the true labels.\n", |
|
"# Color correct predictions in blue and incorrect predictions in red.\n", |
|
"num_rows = 5\n", |
|
"num_cols = 3\n", |
|
"num_images = num_rows*num_cols  # samples per page\n", |
|
"\n", |
|
"# At most 30 pages, but never more than the test set can fill: indexing past the end\n", |
|
"# of predictions/test_labels would raise IndexError on a small test set.\n", |
|
"max_pages = 30\n", |
|
"num_samples = len(test_images)\n", |
|
"num_pages = min(max_pages, -(-num_samples // num_images))  # ceiling division\n", |
|
"\n", |
|
"for j in range(num_pages):\n", |
|
"    plt.figure(figsize=(2*2*num_cols, 2*num_rows))\n", |
|
"    for i in range(num_images):\n", |
|
"        sample = i + j*num_images\n", |
|
"        if sample >= num_samples:\n", |
|
"            break  # the last page may be only partly filled\n", |
|
"        # Each sample takes two grid slots: the picture, then its score bars.\n", |
|
"        plt.subplot(num_rows, 2*num_cols, 2*i+1)\n", |
|
"        plot_image(sample, predictions[sample], test_labels, test_images)\n", |
|
"        plt.subplot(num_rows, 2*num_cols, 2*i+2)\n", |
|
"        plot_value_array(sample, predictions[sample], test_labels)\n", |
|
"    plt.tight_layout()\n", |
|
"    plt.show()" |
|
] |
|
}, |
|
{ |
|
"cell_type": "code", |
|
"execution_count": null, |
|
"metadata": {}, |
|
"outputs": [], |
|
"source": [ |
|
"# Grab an image from the test dataset.\n", |
|
"img = test_images[1]\n", |
|
"print(img.shape)\n", |
|
"\n", |
|
"# predict() works on batches, so give the image a leading batch axis of length 1.\n", |
|
"img = img[np.newaxis, ...]\n", |
|
"print(img.shape)\n", |
|
"\n", |
|
"predictions_single = model.predict(img)\n", |
|
"print(predictions_single)\n", |
|
"\n", |
|
"# Bar chart of the scores for that image, with the class letters on the x axis.\n", |
|
"plot_value_array(1, predictions_single[0], test_labels)\n", |
|
"_ = plt.xticks(range(L), class_names, rotation=45)\n", |
|
"\n", |
|
"print(\"np.argmax(predictions_single[0]): \", np.argmax(predictions_single[0]))" |
|
] |
|
}, |
|
{ |
|
"cell_type": "code", |
|
"execution_count": null, |
|
"metadata": {}, |
|
"outputs": [], |
|
"source": [] |
|
}, |
|
{ |
|
"cell_type": "code", |
|
"execution_count": null, |
|
"metadata": {}, |
|
"outputs": [], |
|
"source": [] |
|
}, |
|
{ |
|
"cell_type": "code", |
|
"execution_count": null, |
|
"metadata": {}, |
|
"outputs": [], |
|
"source": [] |
|
} |
|
], |
|
"metadata": { |
|
"kernelspec": { |
|
"display_name": "Python 3", |
|
"language": "python", |
|
"name": "python3" |
|
}, |
|
"language_info": { |
|
"codemirror_mode": { |
|
"name": "ipython", |
|
"version": 3 |
|
}, |
|
"file_extension": ".py", |
|
"mimetype": "text/x-python", |
|
"name": "python", |
|
"nbconvert_exporter": "python", |
|
"pygments_lexer": "ipython3", |
|
"version": "3.8.5" |
|
} |
|
}, |
|
"nbformat": 4, |
|
"nbformat_minor": 2 |
|
}
|
|
|