{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#Original Author: Jonathan Hudson\n",
    "#CPSC 501 F22"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "DhZTTSlltFh2"
   },
   "source": [
    "Imports that are needed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "KsdVGfVCnQ4J"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "huG1gu0KtG_a"
   },
   "source": [
    "Determine the arguments\n",
    "A weird way to do a notebook, but it lets the code match the non-notebook code"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "ukCbw8ydnP3k"
   },
   "outputs": [],
   "source": [
    "sys.argv = [\"\", input(\"Dataset:\"), input(\"Model:\"), input(\"Image:\"), input(\"Class index:\")]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "6TD50Q9e_0ZE"
   },
   "outputs": [],
   "source": [
    "def check_args():\n",
    "    \"\"\"Validate sys.argv and return the class names for the chosen dataset.\n",
    "\n",
    "    Expected arguments (sys.argv[1:]):\n",
    "        dataset     -- 'MNIST' or 'notMNIST'\n",
    "        model       -- path to an existing '.h5' model file\n",
    "        image       -- path to an existing 28x28 greyscale '.png' image\n",
    "        class index -- the true class of the image, an integer 0-9\n",
    "\n",
    "    Returns:\n",
    "        list: the 10 class names (ints 0-9 for MNIST, letters A-J for notMNIST).\n",
    "\n",
    "    Prints a message and calls sys.exit (raising SystemExit) on any invalid argument.\n",
    "    \"\"\"\n",
    "    if len(sys.argv) != 5:\n",
    "        print(\"Usage: python predict.py <MNIST|notMNIST> <model.h5> <image.png> <class index 0-9>\")\n",
    "        sys.exit(1)\n",
    "    if sys.argv[1] == \"MNIST\":\n",
    "        print(\"--Dataset MNIST--\")\n",
    "        class_names = list(range(10))\n",
    "    elif sys.argv[1] == \"notMNIST\":\n",
    "        print(\"--Dataset notMNIST--\")\n",
    "        class_names = [\"A\",\"B\",\"C\",\"D\",\"E\",\"F\",\"G\",\"H\",\"I\",\"J\"]\n",
    "    else:\n",
    "        print(f\"Choose MNIST or notMNIST, not {sys.argv[1]}\")\n",
    "        sys.exit(2)\n",
    "    if sys.argv[2][-3:] != \".h5\":\n",
    "        print(f\"{sys.argv[2]} is not a h5 extension\")\n",
    "        sys.exit(3)\n",
    "    if not os.path.isfile(sys.argv[2]):\n",
    "        print(f\"{sys.argv[2]} does not exist\")\n",
    "        sys.exit(3)\n",
    "    if sys.argv[3][-4:] != \".png\":\n",
    "        print(f\"{sys.argv[3]} is not a png extension\")\n",
    "        sys.exit(3)\n",
    "    # Check up front so a missing file gives a clear message instead of a FileNotFoundError traceback\n",
    "    if not os.path.isfile(sys.argv[3]):\n",
    "        print(f\"{sys.argv[3]} does not exist\")\n",
    "        sys.exit(3)\n",
    "    img = plt.imread(sys.argv[3])\n",
    "    if len(img.shape) != 2:\n",
    "        print(\"Image is not grey scale!\")\n",
    "        sys.exit(4)\n",
    "    if img.shape != (28,28):\n",
    "        print(\"Image is not 28 by 28!\")\n",
    "        sys.exit(4)\n",
    "    # isdecimal() rather than isdigit(): superscript digits pass isdigit() but int() rejects them\n",
    "    if not sys.argv[4].isdecimal() or not 0 <= int(sys.argv[4]) <= 9:\n",
    "        print(f\"{sys.argv[4]} is not an integer (0-9)\")\n",
    "        sys.exit(3)\n",
    "    return class_names"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "GFqxpnv4tjR6"
   },
   "source": [
    "Completed functions to plot for you"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "E75I4-RVeme5"
   },
   "outputs": [],
   "source": [
    "def plot(class_names, prediction, true_label, predicted_label, img):\n",
    "    \"\"\"Show the input image beside a bar chart of the per-class scores.\n",
    "\n",
    "    Args:\n",
    "        class_names: class names, indexed by label.\n",
    "        prediction: one score per class; plotted on a [0, 1] axis and shown as a\n",
    "            percentage, so it is expected to hold probabilities.\n",
    "        true_label: index of the correct class (its bar is drawn blue).\n",
    "        predicted_label: index of the predicted class (its bar is drawn red unless\n",
    "            it equals true_label).\n",
    "        img: the 2-D image to display.\n",
    "    \"\"\"\n",
    "    plt.figure(figsize=(6,3))\n",
    "    plt.subplot(1,2,1)\n",
    "    plt.grid(False)\n",
    "    plt.xticks([])\n",
    "    plt.yticks([])\n",
    "    plt.imshow(img, cmap=plt.cm.binary)\n",
    "    # Caption is blue for a correct prediction, red otherwise\n",
    "    color = 'blue' if predicted_label == true_label else 'red'\n",
    "    # Show the score of the label that is actually reported\n",
    "    plt.xlabel(\"{} {:2.0f}% ({})\".format(class_names[predicted_label], 100*prediction[predicted_label], class_names[true_label]), color=color)\n",
    "    plt.subplot(1,2,2)\n",
    "    plt.grid(False)\n",
    "    plt.xticks(range(len(class_names)))\n",
    "    plt.yticks([])\n",
    "    thisplot = plt.bar(class_names, prediction, color=\"#777777\")\n",
    "    plt.ylim([0, 1])\n",
    "    thisplot[predicted_label].set_color('red')\n",
    "    # Coloured last so a correct prediction ends up blue, not red\n",
    "    thisplot[true_label].set_color('blue')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "kL_b5WbYtonX"
   },
   "source": [
    "Load the model and image, predict the class of the image, and plot the result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "GChTgYKGtoFN"
   },
   "outputs": [],
   "source": [
    "def main():\n",
    "    \"\"\"Validate the arguments, load the model and image, then predict and plot.\"\"\"\n",
    "    class_names = check_args()\n",
    "    print(f\"--Load Model {sys.argv[2]}--\")\n",
    "    # Load the model in sys.argv[2]; compile=False because it is only used for inference\n",
    "    model = tf.keras.models.load_model(sys.argv[2], compile=False)\n",
    "    print(f\"--Load Image {sys.argv[3]}--\")\n",
    "    img = plt.imread(sys.argv[3])\n",
    "    # Rescale to [0, 1] if the pixel values are 0-255\n",
    "    if np.amax(img) > 1:\n",
    "        img = img/255\n",
    "    # Invert the image; presumably the input is dark-on-light while the model expects light-on-dark (verify)\n",
    "    img = 1 - img\n",
    "    print(f\"--Predict as Class {sys.argv[4]}--\")\n",
    "    predict(model, class_names, img, int(sys.argv[4]))\n",
    "\n",
    "def predict(model, class_names, img, true_label):\n",
    "    \"\"\"Classify a single image with `model` and plot the result.\n",
    "\n",
    "    Args:\n",
    "        model: a loaded Keras model, expected to output one score per class.\n",
    "        class_names: class names, indexed by label.\n",
    "        img: a 2-D (28x28) image, already scaled and inverted by main().\n",
    "        true_label: index of the correct class.\n",
    "    \"\"\"\n",
    "    # Add a batch dimension: (28, 28) -> (1, 28, 28)\n",
    "    batch = np.array([img])\n",
    "    # A convolutional model may expect a trailing channel axis: (1, 28, 28) -> (1, 28, 28, 1)\n",
    "    input_shape = getattr(model, \"input_shape\", None)\n",
    "    if isinstance(input_shape, tuple) and len(input_shape) == batch.ndim + 1:\n",
    "        batch = batch[..., np.newaxis]\n",
    "    prediction = np.asarray(model.predict(batch, verbose=0))[0]\n",
    "    # plot() expects probabilities; if the model outputs logits, apply softmax\n",
    "    if prediction.min() < 0 or not np.isclose(prediction.sum(), 1.0, atol=1e-3):\n",
    "        prediction = tf.nn.softmax(prediction).numpy()\n",
    "    predicted_label = int(np.argmax(prediction))\n",
    "    plot(class_names, prediction, true_label, predicted_label, img)\n",
    "    plt.show()\n",
    "\n",
    "main()"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "predict.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}