Skip to content

Commit e24933c

Browse files
authored
Add files via upload
1 parent 0bb1d6f commit e24933c

File tree

1 file changed

+171
-11
lines changed

1 file changed

+171
-11
lines changed

torcu/model.ipynb

+171-11
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,9 @@
22
"cells": [
33
{
44
"cell_type": "code",
5-
"execution_count": 2,
5+
"execution_count": 31,
66
"metadata": {},
7-
"outputs": [
8-
{
9-
"name": "stderr",
10-
"output_type": "stream",
11-
"text": [
12-
"Using TensorFlow backend.\n"
13-
]
14-
}
15-
],
7+
"outputs": [],
168
"source": [
179
"import os\n",
1810
"import matplotlib\n",
@@ -26,7 +18,9 @@
2618
"import shutil\n",
2719
"from keras.preprocessing.image import ImageDataGenerator\n",
2820
"from keras.models import Sequential\n",
29-
"from keras.metrics import top_k_categorical_accuracy"
21+
"from keras.metrics import top_k_categorical_accuracy\n",
22+
"from sklearn.metrics import classification_report, confusion_matrix\n",
23+
"from itertools import product"
3024
]
3125
},
3226
{
@@ -348,6 +342,172 @@
348342
"print('Test Loss:', score[0])\n",
349343
"print('Test accuracy:', score[1])"
350344
]
345+
},
346+
{
347+
"cell_type": "code",
348+
"execution_count": 24,
349+
"metadata": {
350+
"collapsed": true
351+
},
352+
"outputs": [],
353+
"source": [
354+
"def plot_confusion_matrix(cm, classes,\n",
355+
" normalize=False,\n",
356+
" title='Confusion matrix',\n",
357+
" cmap=plt.cm.Blues):\n",
358+
" \"\"\"\n",
359+
" This function prints and plots the confusion matrix.\n",
360+
" Normalization can be applied by setting `normalize=True`.\n",
361+
" \"\"\"\n",
362+
" if normalize:\n",
363+
" cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]\n",
364+
" print(\"Normalized confusion matrix\")\n",
365+
" else:\n",
366+
" print('Confusion matrix, without normalization')\n",
367+
"\n",
368+
" print(cm)\n",
369+
"\n",
370+
" plt.figure(figsize=(40,40)) \n",
371+
" plt.imshow(cm, interpolation='nearest', cmap=cmap)\n",
372+
" plt.title(title)\n",
373+
" plt.colorbar()\n",
374+
" tick_marks = np.arange(len(classes))\n",
375+
" plt.xticks(tick_marks, classes, rotation=90)\n",
376+
" plt.yticks(tick_marks, classes)\n",
377+
"\n",
378+
" fmt = '.2f' if normalize else 'd'\n",
379+
" thresh = cm.max() / 2.\n",
380+
" for i, j in product(range(cm.shape[0]), range(cm.shape[1])):\n",
381+
" plt.text(j, i, format(cm[i, j], fmt),\n",
382+
" horizontalalignment=\"center\",\n",
383+
" color=\"white\" if cm[i, j] > thresh else \"black\")\n",
384+
"\n",
385+
" plt.tight_layout()\n",
386+
" plt.ylabel('True label')\n",
387+
" plt.xlabel('Predicted label')"
388+
]
389+
},
390+
{
391+
"cell_type": "code",
392+
"execution_count": 26,
393+
"metadata": {},
394+
"outputs": [
395+
{
396+
"name": "stdout",
397+
"output_type": "stream",
398+
"text": [
399+
"_________________________________________________________________\n",
400+
"Layer (type) Output Shape Param # \n",
401+
"=================================================================\n",
402+
"conv2d_6 (Conv2D) (None, 62, 62, 32) 896 \n",
403+
"_________________________________________________________________\n",
404+
"conv2d_7 (Conv2D) (None, 60, 60, 32) 9248 \n",
405+
"_________________________________________________________________\n",
406+
"conv2d_8 (Conv2D) (None, 58, 58, 32) 9248 \n",
407+
"_________________________________________________________________\n",
408+
"max_pooling2d_3 (MaxPooling2 (None, 29, 29, 32) 0 \n",
409+
"_________________________________________________________________\n",
410+
"dropout_4 (Dropout) (None, 29, 29, 32) 0 \n",
411+
"_________________________________________________________________\n",
412+
"conv2d_9 (Conv2D) (None, 27, 27, 32) 9248 \n",
413+
"_________________________________________________________________\n",
414+
"conv2d_10 (Conv2D) (None, 25, 25, 32) 9248 \n",
415+
"_________________________________________________________________\n",
416+
"max_pooling2d_4 (MaxPooling2 (None, 12, 12, 32) 0 \n",
417+
"_________________________________________________________________\n",
418+
"dropout_5 (Dropout) (None, 12, 12, 32) 0 \n",
419+
"_________________________________________________________________\n",
420+
"flatten_2 (Flatten) (None, 4608) 0 \n",
421+
"_________________________________________________________________\n",
422+
"dense_3 (Dense) (None, 256) 1179904 \n",
423+
"_________________________________________________________________\n",
424+
"dropout_6 (Dropout) (None, 256) 0 \n",
425+
"_________________________________________________________________\n",
426+
"dense_4 (Dense) (None, 62) 15934 \n",
427+
"=================================================================\n",
428+
"Total params: 1,233,726\n",
429+
"Trainable params: 1,233,726\n",
430+
"Non-trainable params: 0\n",
431+
"_________________________________________________________________\n"
432+
]
433+
},
434+
{
435+
"data": {
436+
"text/plain": [
437+
"True"
438+
]
439+
},
440+
"execution_count": 26,
441+
"metadata": {},
442+
"output_type": "execute_result"
443+
}
444+
],
445+
"source": [
446+
"x.summary()\n",
447+
"x.get_config()\n",
448+
"x.layers[0].get_config()\n",
449+
"x.layers[0].input_shape\t\t\t\n",
450+
"x.layers[0].output_shape\t\t\t\n",
451+
"x.layers[0].get_weights()\n",
452+
"np.shape(x.layers[0].get_weights()[0])\n",
453+
"x.layers[0].trainable"
454+
]
455+
},
456+
{
457+
"cell_type": "code",
458+
"execution_count": 27,
459+
"metadata": {
460+
"collapsed": true
461+
},
462+
"outputs": [],
463+
"source": [
464+
"labels = [\"Uneven road\", \"Speed bump\", \"Slippery road\", \"Dangerous curve to the left\", \"Dangerous curve to the right\", \"Double dangerous curve to the left\", \"Double dangerous curve to the right\", \"Presence of children\", \"Bicycle crossing\", \"Cattle crossing\", \"Road works ahead\", \"Traffic signals ahead\", \"Guarded railroad crossing\", \"Indefinite danger\", \"Road narrows\", \"Road narrows from the left\", \"Road narrows from the right\", \"Priority at the next intersection\", \"Intersection where the priority from the right is applicable\", \"Yield right of way\", \"Yield to oncoming traffic\", \"Stop\", \"No entry for all drivers\", \"No bicycles allowed\", \"Maximum weights allowed (including load)\", \"No cargo vehicles allowed\", \"Maximum width allowed\", \"Maximum height allowed\", \"No traffic allowed in both directions\", \"No left turn\", \"No right turn\", \"No passing to the left vehicles having more than 2 wheels and horse drawn vehicles\", \"Maximum speed limit\", \"Mandatory way for pedestrians and bicycles\", \"Mandatory direction (straight on)\", \"Mandatory direction (to the right or to the left)\", \"Mandatory directions(straight on and to the right)\", \"Mandatory traffic circle\", \"Mandatory bicycle path\", \"Path shared between pedestrians, bicycles and mopeds class A\", \"No parking\", \"No waiting or parking\", \"No parking from the 1 st to the 15th of the month\", \"No parking from the 16th till the end of the month\", \"Priority over oncoming traffic\", \"Parking allowed\", \"Additional parking sign for handicap only\", \"Parking exclusively for motorcars\", \"Parking exclusively for trucks\", \"Parking exclusively for buses/coaches\", \"Parking on sidewalk or verge mandatory\", \"Beginning of a residential area\", \"End of a residential area\", \"One way traffic\", \"Dead end\", \"End of road works\", \"Pedestrian crosswalk\", \"Bicycles and mopeds crossing\", \"Parking ahead\", \"Speed bump\", \"End of priority road\", \"Priority road\"]"
465+
]
466+
},
467+
{
468+
"cell_type": "code",
469+
"execution_count": 32,
470+
"metadata": {},
471+
"outputs": [
472+
{
473+
"name": "stdout",
474+
"output_type": "stream",
475+
"text": [
476+
"Normalized confusion matrix\n",
477+
"[[ 0. 0.5 0. ..., 0. 0. 0. ]\n",
478+
" [ 0. 0.51851852 0. ..., 0. 0. 0. ]\n",
479+
" [ 0. 0.42857143 0. ..., 0. 0. 0. ]\n",
480+
" ..., \n",
481+
" [ 0. 0. 0. ..., 0. 0. 0.05882353]\n",
482+
" [ 0. 0. 0. ..., 0. 0. 1. ]\n",
483+
" [ 0. 0. 0. ..., 0. 0. 1. ]]\n"
484+
]
485+
}
486+
],
487+
"source": [
488+
"test_set.reset()\n",
489+
"probabilities = x.predict_generator(test_set)\n",
490+
"y_pred = np.argmax(probabilities , axis=-1)\n",
491+
"\n",
492+
"cnf_matrix = confusion_matrix(test_set.classes, y_pred)\n",
493+
"plot_confusion_matrix(cnf_matrix, labels, True)"
494+
]
495+
},
496+
{
497+
"cell_type": "code",
498+
"execution_count": null,
499+
"metadata": {},
500+
"outputs": [],
501+
"source": []
502+
},
503+
{
504+
"cell_type": "code",
505+
"execution_count": null,
506+
"metadata": {
507+
"collapsed": true
508+
},
509+
"outputs": [],
510+
"source": []
351511
}
352512
],
353513
"metadata": {

0 commit comments

Comments
 (0)