Commit dc4be85

First implementation of supervised contrastive learning (in Keras).
1 parent 88a4742 commit dc4be85

File tree

1 file changed: +115 -0 lines changed


supervised_contrastive.py

+115
@@ -0,0 +1,115 @@
import fire
import umap
import tensorflow as tf
import tensorflow_addons as tfa

from tensorflow import keras
import matplotlib.pyplot as plt

print(tf.__version__)


class SupervisedContrastiveLoss(keras.losses.Loss):
    def __init__(self, temperature=1.0, name=None):
        super(SupervisedContrastiveLoss, self).__init__(name=name)
        self.temperature = temperature

    def __call__(self, labels, feature_vectors, sample_weight=None):
        # Normalize the embeddings so the dot products below are cosine similarities.
        feature_vectors_normalized = tf.math.l2_normalize(feature_vectors, axis=1)

        # Pairwise similarity matrix of shape (batch, batch).
        dot_product = tf.matmul(
            feature_vectors_normalized, tf.transpose(feature_vectors_normalized)
        )

        # Sharpen/soften the similarities with the temperature.
        logits = tf.divide(dot_product, self.temperature)

        # npairs_loss treats samples that share a label as positives.
        return tfa.losses.npairs_loss(tf.squeeze(labels), logits)


class SupervisedContrastiveLearner:

    def __init__(self):
        self.encoder = None
        self.epochs = 1
        self.batch_size = 16
        self.num_classes = 10
        self.input_shape = (32, 32, 3)
        self.embedding_dim = 128
        self.temperature = 0.05
        self.dropout = 0.2
        self.lr = 0.01

        self.encoder_path = "./models/supervised_contrastive_encoder"

        self.train_data = None
        self.test_data = None

    def load_data(self):
        self.train_data, self.test_data = keras.datasets.cifar10.load_data()
        print(f"Data loaded. Train shape: {self.train_data[0].shape}, "
              f"Test shape: {self.test_data[0].shape}")

    def create_encoder(self):
        # EfficientNetB0 backbone with global average pooling as the feature extractor.
        spine = keras.applications.EfficientNetB0(
            include_top=False, weights=None, input_shape=self.input_shape, pooling="avg"
        )

        # Dense projection head mapping the pooled features to the embedding space.
        inputs = keras.Input(shape=self.input_shape)
        features = spine(inputs)
        outputs = keras.layers.Dense(self.embedding_dim, activation="relu")(features)
        model = keras.Model(inputs=inputs, outputs=outputs, name="supervised_contrastive_encoder")

        return model

    def train(self):
        # Load data
        self.load_data()

        # Create encoder
        encoder = self.create_encoder()
        encoder.summary()

        # Compile encoder
        encoder.compile(
            optimizer=keras.optimizers.Adam(self.lr),
            loss=SupervisedContrastiveLoss(self.temperature),
        )

        # Train encoder
        x_train, y_train = self.train_data[0], self.train_data[1]
        # keras.backend.clear_session()
        encoder.fit(x=x_train, y=y_train, batch_size=self.batch_size, epochs=self.epochs)

        # Save model
        encoder.save(self.encoder_path)

    def visualize_embeddings(self):
        # Load data
        self.load_data()

        # Load model
        encoder = keras.models.load_model(self.encoder_path, compile=False)
        encoder.compile(
            optimizer=keras.optimizers.Adam(self.lr),
            loss=SupervisedContrastiveLoss(self.temperature),
        )

        # Compute embeddings
        x, y = self.test_data[0], self.test_data[1]
        embeddings = encoder.predict(x)
        print(f"Encoder embedding shape: {embeddings.shape}")

        # UMAP
        reducer = umap.UMAP()
        umap_embeddings = reducer.fit_transform(embeddings)
        print(f"UMAP embedding shape: {umap_embeddings.shape}")

        # Flatten the (N, 1) CIFAR-10 label array so it can be used as the colour values.
        plt.scatter(umap_embeddings[:, 0], umap_embeddings[:, 1], c=y.ravel())
        plt.title("UMAP for CIFAR-10")
        plt.show()


if __name__ == "__main__":
    fire.Fire(SupervisedContrastiveLearner)
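
A quick sanity check (not part of this commit) for the loss class above: the sketch below calls SupervisedContrastiveLoss directly on a toy batch, assuming the definitions from supervised_contrastive.py are available; the batch size, label values, and embedding width are made up purely for illustration.

# Toy batch: integer labels shaped (batch, 1) like CIFAR-10, plus random stand-in embeddings.
loss_fn = SupervisedContrastiveLoss(temperature=0.05)
labels = tf.constant([[0], [1], [0], [1]])
embeddings = tf.random.normal((4, 128))
# The npairs loss over the temperature-scaled cosine similarities comes back as a single scalar.
print(float(loss_fn(labels, embeddings)))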

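For reference, a sketch (not part of the commit) of how the fire.Fire entry point is typically driven: passing the class to Fire exposes its public methods as sub-commands, e.g. "python supervised_contrastive.py train" followed by "python supervised_contrastive.py visualize_embeddings". The same flow can be run programmatically.

# Programmatic equivalent of the two CLI sub-commands above.
learner = SupervisedContrastiveLearner()
learner.train()                  # trains the encoder and saves it to ./models/supervised_contrastive_encoder
learner.visualize_embeddings()   # reloads the saved encoder and plots a UMAP projection of the test-set embeddings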