Skip to content

Commit

Permalink
split variable base_model_uuid_or_name into model_variant and `base_model_uuid`
Browse files Browse the repository at this point in the history

refactoring
  • Loading branch information
denniswittich committed Nov 20, 2024
1 parent 22c2411 commit 824dd02
Show file tree
Hide file tree
Showing 6 changed files with 40 additions and 37 deletions.
20 changes: 11 additions & 9 deletions learning_loop_node/data_classes/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,17 +74,18 @@ class Training():
categories: List[Category]
hyperparameters: dict

# model uuid to download (to continue training) | is not a uuid when training from scratch (blank or pt-name from provided_pretrained_models->name)
base_model_uuid_or_name: str

training_number: int
training_state: str
model_variant: str # from `provided_pretrained_models->name`

start_time: float = field(default_factory=time.time)

model_uuid_for_detecting: Optional[str] = None # NOTE: this is set later after the model has been uploaded
image_data: Optional[List[dict]] = None # NOTE: this is set later after the data has been downloaded
skipped_image_count: Optional[int] = None # NOTE: this is set later after the data has been downloaded
base_model_uuid: Optional[str] = None # model uuid to continue training (is loaded from loop)

# NOTE: these are set later: image_data and skipped_image_count after the data has been downloaded, model_uuid_for_detecting after the model has been uploaded
image_data: Optional[List[dict]] = None
skipped_image_count: Optional[int] = None
model_uuid_for_detecting: Optional[str] = None # Model uuid to load from the loop after training and upload

@property
def training_folder_path(self) -> Path:
Expand All @@ -98,8 +99,8 @@ def generate_training(cls, project_folder: str, context: Context, data: Dict[str
raise ValueError('categories missing or not a list')
if 'training_number' not in data or not isinstance(data['training_number'], int):
raise ValueError('training_number missing or not an int')
if 'id' not in data or not isinstance(data['id'], str):
raise ValueError('id missing or not a str')
if 'model_variant' not in data or not isinstance(data['model_variant'], str):
raise ValueError('model_variant missing or not a str')

training_uuid = str(uuid4())

Expand All @@ -112,7 +113,8 @@ def generate_training(cls, project_folder: str, context: Context, data: Dict[str
categories=Category.from_list(data['categories']),
hyperparameters=data['hyperparameters'],
training_number=data['training_number'],
base_model_uuid_or_name=data['id'],
base_model_uuid=data.get('base_model_uuid', None),
model_variant=data['model_variant'],
training_state=TrainerState.Initialized.value
)

Expand Down
5 changes: 3 additions & 2 deletions learning_loop_node/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,13 +98,14 @@ async def lifespan(self, app: FastAPI): # pylint: disable=unused-argument
pass

async def _on_startup(self):
self.log.info('received "startup" lifecycle-event')
self.log.info('received "startup" lifecycle-event - connecting to loop')
try:
await self.reconnect_to_loop()
except Exception:
self.log.warning('Could not establish sio connection to loop during startup')
self.log.info('done')
self.log.info('successfully connected to loop - calling on_startup')
await self.on_startup()
self.log.info('successfully finished on_startup')

async def _on_shutdown(self):
self.log.info('received "shutdown" lifecycle-event')
Expand Down
32 changes: 16 additions & 16 deletions learning_loop_node/tests/trainer/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,14 @@ async def test_initialized_trainer_node():
node = TrainerNode(name='test', trainer_logic=trainer, uuid='NOD30000-0000-0000-0000-000000000000')
trainer._node = node
trainer._init_new_training(context=Context(organization='zauberzeug', project='demo'),
details={'categories': [],
'id': '00000000-0000-0000-0000-000000000012', # version 1.2 of demo project
'training_number': 0,
'hyperparameters': {
'resolution': 800,
'flip_rl': False,
'flip_ud': False}
})
training_config={'categories': [],
'id': '00000000-0000-0000-0000-000000000012', # version 1.2 of demo project
'training_number': 0,
'hyperparameters': {
'resolution': 800,
'flip_rl': False,
'flip_ud': False}
})
await node._on_startup()
yield node
await node._on_shutdown()
Expand All @@ -52,14 +52,14 @@ async def test_initialized_trainer():
await node._on_startup()
trainer._node = node
trainer._init_new_training(context=Context(organization='zauberzeug', project='demo'),
details={'categories': [],
'id': '00000000-0000-0000-0000-000000000012', # version 1.2 of demo project
'training_number': 0,
'hyperparameters': {
'resolution': 800,
'flip_rl': False,
'flip_ud': False}
})
training_config={'categories': [],
'id': '00000000-0000-0000-0000-000000000012', # version 1.2 of demo project
'training_number': 0,
'hyperparameters': {
'resolution': 800,
'flip_rl': False,
'flip_ud': False}
})
yield trainer
try:
await node._on_shutdown()
Expand Down
3 changes: 1 addition & 2 deletions learning_loop_node/trainer/trainer_logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,7 @@ async def _start_training(self):
if self._can_resume():
self.start_training_task = self._resume()
else:
base_model_uuid_or_name = self.training.base_model_uuid_or_name
if not is_valid_uuid4(base_model_uuid_or_name):
if not self.training.base_model_uuid or not is_valid_uuid4(self.training.base_model_uuid):
self.start_training_task = self._start_training_from_scratch()
else:
self.start_training_task = self._start_training_from_base_model()
Expand Down
15 changes: 8 additions & 7 deletions learning_loop_node/trainer/trainer_logic_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,25 +198,26 @@ def _init_from_last_training(self) -> None:
self._active_training_io = ActiveTrainingIO(
self._training.training_folder, self.node.loop_communicator, self._training.context)

async def begin_training(self, organization: str, project: str, details: Dict) -> None:
async def begin_training(self, organization: str, project: str, training_config: Dict) -> None:
"""Called on `begin_training` event from the Learning Loop.
"""
self._init_new_training(Context(organization=organization, project=project), details)
self._init_new_training(Context(organization=organization, project=project), training_config)
self._begin_training_task()

def _begin_training_task(self) -> None:
# NOTE: Task object is used to potentially cancel the task
self.training_task = asyncio.get_event_loop().create_task(self._run())

def _init_new_training(self, context: Context, details: Dict) -> None:
def _init_new_training(self, context: Context, training_config: Dict) -> None:
"""Called on `begin_training` event from the Learning Loop.
Note that details needs the entries 'categories' and 'training_number',
Note that training_config needs the entries 'categories', 'model_variant' and 'training_number',
but also the hyperparameter entries.
'base_model_uuid' is optional; it is only provided when the training is continued from an existing model.
"""
project_folder = create_project_folder(context)
if not self._environment_vars.keep_old_trainings:
delete_all_training_folders(project_folder)
self._training = Training.generate_training(project_folder, context, details)
self._training = Training.generate_training(project_folder, context, training_config)

self._active_training_io = ActiveTrainingIO(
self._training.training_folder, self.node.loop_communicator, context)
Expand Down Expand Up @@ -331,7 +332,7 @@ async def _download_model(self) -> None:
"""If training is continued, the model is downloaded from the Learning Loop to the training_folder.
The downloaded model.json file is renamed to base_model.json because a new model.json will be created during training.
"""
base_model_uuid = self.training.base_model_uuid_or_name
base_model_uuid = self.training.base_model_uuid

# TODO this checks if we continue a training -> make more explicit
if not base_model_uuid or not is_valid_uuid4(base_model_uuid):
Expand Down Expand Up @@ -490,7 +491,7 @@ async def _do_detections(self) -> None:

@abstractmethod
def _get_new_best_training_state(self) -> Optional[TrainingStateData]:
"""Is called frequently by `_sync_confusion_matrix` to check if a new "best" model is available.
"""Is called frequently by `_sync_training` during training to check if a new "best" model is available.
Returns None if no new model could be found. Otherwise TrainingStateData(confusion_matrix, meta_information).
`confusion_matrix` contains a dict of all classes:
- The classes must be identified by their uuid, not their name.
Expand Down
2 changes: 1 addition & 1 deletion mock_trainer/app_code/tests/test_detections.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ async def test_all(setup_test_project1, glc: LoopCommunicator):
# await asyncio.sleep(100)

trainer._node = node
trainer._init_new_training(context=context, details=details)
trainer._init_new_training(context=context, training_config=details)
trainer.training.model_uuid_for_detecting = latest_model_id

await trainer._do_detections()
Expand Down

0 comments on commit 824dd02

Please sign in to comment.