Skip to content

Commit c4bd2a9

Browse files
authored
Add pydot package (#956)
This is used by keras plot_model() function. Included a test to prevent regression. Improved the existing tf test with extra assertions. The issue was raised by a user here: https://www.kaggle.com/c/jane-street-market-prediction/discussion/214494#1184233
1 parent ddc0c00 commit c4bd2a9

File tree

2 files changed

+11
-3
lines changed

2 files changed

+11
-3
lines changed

Dockerfile

+1
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ RUN apt-get install -y libfreetype6-dev && \
8080
pip install xgboost && \
8181
# Pinned to match GPU version. Update version together.
8282
pip install lightgbm==3.1.1 && \
83+
pip install pydot && \
8384
pip install keras && \
8485
pip install keras-tuner && \
8586
pip install flake8 && \

tests/test_tensorflow.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import unittest
2+
import os.path
23

34
import numpy as np
45
import tensorflow as tf
@@ -37,8 +38,13 @@ def test_tf_keras(self):
3738
metrics=['accuracy'])
3839

3940
model.fit(x_train, y_train, epochs=1)
40-
model.evaluate(x_test, y_test)
41-
41+
42+
result = model.evaluate(x_test, y_test)
43+
self.assertEqual(2, len(result))
44+
45+
# exercices pydot path.
46+
tf.keras.utils.plot_model(model, to_file="tf_plot_model.png")
47+
self.assertTrue(os.path.isfile("tf_plot_model.png"))
4248

4349
def test_lstm(self):
4450
x_train = np.random.random((100, 28, 28))
@@ -58,7 +64,8 @@ def test_lstm(self):
5864
metrics=['accuracy'])
5965

6066
model.fit(x_train, y_train, epochs=1)
61-
model.evaluate(x_test, y_test)
67+
result = model.evaluate(x_test, y_test)
68+
self.assertEqual(2, len(result))
6269

6370
@gpu_test
6471
def test_gpu(self):

0 commit comments

Comments
 (0)