Skip to content

Commit

Permalink
Added examples for tensorflow types in Datatypes and IO section (#1739)
Browse files Browse the repository at this point in the history
* Added examples for tensorflow types in Datatypes and IO section

Signed-off-by: sumana sree <[email protected]>

* Fixed linting errors

Signed-off-by: sumana sree <[email protected]>

* updated tensorflow_type.py file to avoid linting errors

Signed-off-by: sumana sree <[email protected]>

* Apply lint corrections using pre-commit hooks

Signed-off-by: sumana sree <[email protected]>

* added required comments

Signed-off-by: sumana sree <[email protected]>

* fixed error on importing tensorflow

Signed-off-by: sumana sree <[email protected]>

---------

Signed-off-by: sumana sree <[email protected]>
  • Loading branch information
sumana-2705 authored Oct 17, 2024
1 parent c8fa39e commit f12e916
Showing 1 changed file with 56 additions and 0 deletions.
56 changes: 56 additions & 0 deletions examples/data_types_and_io/data_types_and_io/tensorflow_type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Import necessary libraries and modules

from flytekit import task, workflow
from flytekit.types.directory import TFRecordsDirectory
from flytekit.types.file import TFRecordFile

custom_image = ImageSpec(
packages=["tensorflow", "tensorflow-datasets", "flytekitplugins-kftensorflow"],
registry="ghcr.io/flyteorg",
)

if custom_image.is_container():
import tensorflow as tf

# TensorFlow Model
@task
def train_model() -> tf.keras.Model:
model = tf.keras.Sequential(
[tf.keras.layers.Dense(128, activation="relu"), tf.keras.layers.Dense(10, activation="softmax")]
)
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])
return model

@task
def evaluate_model(model: tf.keras.Model, x: tf.Tensor, y: tf.Tensor) -> float:
loss, accuracy = model.evaluate(x, y)
return accuracy

@workflow
def training_workflow(x: tf.Tensor, y: tf.Tensor) -> float:
model = train_model()
return evaluate_model(model=model, x=x, y=y)

# TFRecord Files
@task
def process_tfrecord(file: TFRecordFile) -> int:
count = 0
for record in tf.data.TFRecordDataset(file):
count += 1
return count

@workflow
def tfrecord_workflow(file: TFRecordFile) -> int:
return process_tfrecord(file=file)

# TFRecord Directories
@task
def process_tfrecords_dir(dir: TFRecordsDirectory) -> int:
count = 0
for record in tf.data.TFRecordDataset(dir.path):
count += 1
return count

@workflow
def tfrecords_dir_workflow(dir: TFRecordsDirectory) -> int:
return process_tfrecords_dir(dir=dir)

0 comments on commit f12e916

Please sign in to comment.