Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

change: Drop the feature to publish from an AutoML Model #663

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 0 additions & 33 deletions firebase_admin/ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,9 +386,6 @@ def _init_model_source(data):
gcs_tflite_uri = data.pop('gcsTfliteUri', None)
if gcs_tflite_uri:
return TFLiteGCSModelSource(gcs_tflite_uri=gcs_tflite_uri)
auto_ml_model = data.pop('automlModel', None)
if auto_ml_model:
return TFLiteAutoMlSource(auto_ml_model=auto_ml_model)
return None

@property
Expand Down Expand Up @@ -603,36 +600,6 @@ def as_dict(self, for_upload=False):
return {'gcsTfliteUri': self._gcs_tflite_uri}


class TFLiteAutoMlSource(TFLiteModelSource):
"""TFLite model source representing a tflite model created with AutoML."""

def __init__(self, auto_ml_model, app=None):
self._app = app
self.auto_ml_model = auto_ml_model

def __eq__(self, other):
if isinstance(other, self.__class__):
return self.auto_ml_model == other.auto_ml_model
return False

def __ne__(self, other):
return not self.__eq__(other)

@property
def auto_ml_model(self):
"""Resource name of the model, created by the AutoML API or Cloud console."""
return self._auto_ml_model

@auto_ml_model.setter
def auto_ml_model(self, auto_ml_model):
self._auto_ml_model = _validate_auto_ml_model(auto_ml_model)

def as_dict(self, for_upload=False):
"""Returns a serializable representation of the object."""
# Upload is irrelevant for auto_ml models
return {'automlModel': self._auto_ml_model}


class ListModelsPage:
"""Represents a page of models in a Firebase project.

Expand Down
62 changes: 0 additions & 62 deletions integration/test_ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@

import pytest

import firebase_admin
from firebase_admin import exceptions
from firebase_admin import ml
from tests import testutils
Expand All @@ -35,12 +34,6 @@
except ImportError:
_TF_ENABLED = False

try:
from google.cloud import automl_v1
_AUTOML_ENABLED = True
except ImportError:
_AUTOML_ENABLED = False

def _random_identifier(prefix):
#pylint: disable=unused-variable
suffix = ''.join([random.choice(string.ascii_letters + string.digits) for n in range(8)])
Expand Down Expand Up @@ -159,14 +152,6 @@ def check_tflite_gcs_format(model, validation_error=None):
assert model.model_hash is not None


def check_tflite_automl_format(model):
assert model.validation_error is None
assert model.published is False
assert model.model_format.model_source.auto_ml_model.startswith('projects/')
# Automl models don't have validation errors since they are references
# to valid automl models.


@pytest.mark.parametrize('firebase_model', [NAME_AND_TAGS_ARGS], indirect=True)
def test_create_simple_model(firebase_model):
check_model(firebase_model, NAME_AND_TAGS_ARGS)
Expand Down Expand Up @@ -388,50 +373,3 @@ def test_from_saved_model(saved_model_dir):
assert created_model.validation_error is None
finally:
_clean_up_model(created_model)


# Test AutoML functionality if AutoML is enabled.
#'pip install google-cloud-automl' in the environment if you want _AUTOML_ENABLED = True
# You will also need a predefined AutoML model named 'admin_sdk_integ_test1' to run the
# successful test. (Test is skipped otherwise)

@pytest.fixture
def automl_model():
assert _AUTOML_ENABLED

# It takes > 20 minutes to train a model, so we expect a predefined AutoMl
# model named 'admin_sdk_integ_test1' to exist in the project, or we skip
# the test.
automl_client = automl_v1.AutoMlClient()
project_id = firebase_admin.get_app().project_id
parent = automl_client.location_path(project_id, 'us-central1')
models = automl_client.list_models(parent, filter_="display_name=admin_sdk_integ_test1")
# Expecting exactly one. (Ok to use last one if somehow more than 1)
automl_ref = None
for model in models:
automl_ref = model.name

# Skip if no pre-defined model. (It takes min > 20 minutes to train a model)
if automl_ref is None:
pytest.skip("No pre-existing AutoML model found. Skipping test")

source = ml.TFLiteAutoMlSource(automl_ref)
tflite_format = ml.TFLiteFormat(model_source=source)
ml_model = ml.Model(
display_name=_random_identifier('TestModel_automl_'),
tags=['test_automl'],
model_format=tflite_format)
model = ml.create_model(model=ml_model)
yield model
_clean_up_model(model)

@pytest.mark.skipif(not _AUTOML_ENABLED, reason='AutoML is required for this test.')
def test_automl_model(automl_model):
# This test looks for a predefined automl model with display_name = 'admin_sdk_integ_test1'
automl_model.wait_for_unlocked()

check_model(automl_model, {
'display_name': automl_model.display_name,
'tags': ['test_automl'],
})
check_tflite_automl_format(automl_model)
63 changes: 0 additions & 63 deletions tests/test_ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,18 +122,6 @@
}
TFLITE_FORMAT_2 = ml.TFLiteFormat.from_dict(TFLITE_FORMAT_JSON_2)

AUTOML_MODEL_NAME = 'projects/111111111111/locations/us-central1/models/ICN7683346839371803263'
AUTOML_MODEL_SOURCE = ml.TFLiteAutoMlSource(AUTOML_MODEL_NAME)
TFLITE_FORMAT_JSON_3 = {
'automlModel': AUTOML_MODEL_NAME,
'sizeBytes': '3456789'
}
TFLITE_FORMAT_3 = ml.TFLiteFormat.from_dict(TFLITE_FORMAT_JSON_3)

AUTOML_MODEL_NAME_2 = 'projects/2222222222/locations/us-central1/models/ICN2222222222222222222'
AUTOML_MODEL_NAME_JSON_2 = {'automlModel': AUTOML_MODEL_NAME_2}
AUTOML_MODEL_SOURCE_2 = ml.TFLiteAutoMlSource(AUTOML_MODEL_NAME_2)

CREATED_UPDATED_MODEL_JSON_1 = {
'name': MODEL_NAME_1,
'displayName': DISPLAY_NAME_1,
Expand Down Expand Up @@ -417,14 +405,6 @@ def test_model_keyword_based_creation_and_setters(self):
'tfliteModel': TFLITE_FORMAT_JSON_2
}

model.model_format = TFLITE_FORMAT_3
assert model.as_dict() == {
'displayName': DISPLAY_NAME_2,
'tags': TAGS_2,
'tfliteModel': TFLITE_FORMAT_JSON_3
}


def test_gcs_tflite_model_format_source_creation(self):
model_source = ml.TFLiteGCSModelSource(gcs_tflite_uri=GCS_TFLITE_URI)
model_format = ml.TFLiteFormat(model_source=model_source)
Expand All @@ -436,17 +416,6 @@ def test_gcs_tflite_model_format_source_creation(self):
}
}

def test_auto_ml_tflite_model_format_source_creation(self):
model_source = ml.TFLiteAutoMlSource(auto_ml_model=AUTOML_MODEL_NAME)
model_format = ml.TFLiteFormat(model_source=model_source)
model = ml.Model(display_name=DISPLAY_NAME_1, model_format=model_format)
assert model.as_dict() == {
'displayName': DISPLAY_NAME_1,
'tfliteModel': {
'automlModel': AUTOML_MODEL_NAME
}
}

def test_source_creation_from_tflite_file(self):
model_source = ml.TFLiteGCSModelSource.from_tflite_model_file(
"my_model.tflite", "my_bucket")
Expand All @@ -460,13 +429,6 @@ def test_gcs_tflite_model_source_setters(self):
assert model_source.gcs_tflite_uri == GCS_TFLITE_URI_2
assert model_source.as_dict() == GCS_TFLITE_URI_JSON_2

def test_auto_ml_tflite_model_source_setters(self):
model_source = ml.TFLiteAutoMlSource(AUTOML_MODEL_NAME)
model_source.auto_ml_model = AUTOML_MODEL_NAME_2
assert model_source.auto_ml_model == AUTOML_MODEL_NAME_2
assert model_source.as_dict() == AUTOML_MODEL_NAME_JSON_2


def test_model_format_setters(self):
model_format = ml.TFLiteFormat(model_source=GCS_TFLITE_MODEL_SOURCE)
model_format.model_source = GCS_TFLITE_MODEL_SOURCE_2
Expand All @@ -477,14 +439,6 @@ def test_model_format_setters(self):
}
}

model_format.model_source = AUTOML_MODEL_SOURCE
assert model_format.model_source == AUTOML_MODEL_SOURCE
assert model_format.as_dict() == {
'tfliteModel': {
'automlModel': AUTOML_MODEL_NAME
}
}

def test_model_as_dict_for_upload(self):
model_source = ml.TFLiteGCSModelSource(gcs_tflite_uri=GCS_TFLITE_URI)
model_format = ml.TFLiteFormat(model_source=model_source)
Expand Down Expand Up @@ -570,23 +524,6 @@ def test_gcs_tflite_source_validation_errors(self, uri, exc_type):
ml.TFLiteGCSModelSource(gcs_tflite_uri=uri)
check_error(excinfo, exc_type)

@pytest.mark.parametrize('auto_ml_model, exc_type', [
(123, TypeError),
('abc', ValueError),
('/projects/123456/locations/us-central1/models/noLeadingSlash', ValueError),
('projects/123546/models/ICN123456', ValueError),
('projects//locations/us-central1/models/ICN123456', ValueError),
('projects/123456/locations//models/ICN123456', ValueError),
('projects/123456/locations/us-central1/models/', ValueError),
('projects/ABC/locations/us-central1/models/ICN123456', ValueError),
('projects/123456/locations/us-central1/models/@#$%^&', ValueError),
('projects/123456/locations/us-cent/ral1/models/ICN123456', ValueError),
])
def test_auto_ml_tflite_source_validation_errors(self, auto_ml_model, exc_type):
with pytest.raises(exc_type) as excinfo:
ml.TFLiteAutoMlSource(auto_ml_model=auto_ml_model)
check_error(excinfo, exc_type)

def test_wait_for_unlocked_not_locked(self):
model = ml.Model(display_name="not_locked")
model.wait_for_unlocked()
Expand Down