Merge pull request #11 from zauberzeug/improved-progress-info
send progress for downloading data and running detections
denniswittich authored Jan 31, 2024
2 parents 8be94bf + 412dd01 commit fffd8ba
Showing 5 changed files with 57 additions and 33 deletions.
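
For context, a minimal sketch of how the newly exposed progress could be consumed on the caller side. The poll loop and its name are assumptions for illustration only, not part of this commit; it relies on TrainerNode.progress, which (as shown further down) now forwards TrainerLogic.general_progress.

import asyncio

async def poll_progress(trainer_node) -> None:
    # Hypothetical helper: periodically read the combined 0.0..1.0 progress and log it.
    # trainer_node.progress is None whenever no download/training/detection phase is active.
    while True:
        progress = trainer_node.progress
        if progress is not None:
            print(f'progress: {progress:.0%}')
        await asyncio.sleep(1.0)
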
14 changes: 11 additions & 3 deletions learning_loop_node/data_exchanger.py
@@ -32,6 +32,7 @@ class DataExchanger():
def __init__(self, context: Optional[Context], loop_communicator: LoopCommunicator):
self.context = context
self.loop_communicator = loop_communicator
self.progress = 0.0

def set_context(self, context: Context):
self.context = context
@@ -96,10 +97,13 @@ def jepeg_check_info(self):

async def _download_images_data(self, organization: str, project: str, image_ids: List[str], chunk_size: int = 100) -> List[Dict]:
logging.info('fetching annotations and other image data')
num_image_ids = len(image_ids)
self.jepeg_check_info()
images_data = []
starttime = time.time()
for i in tqdm(range(0, len(image_ids), chunk_size), position=0, leave=True):
progress_factor = 0.5 / num_image_ids # 50% of progress is for downloading data
for i in tqdm(range(0, num_image_ids, chunk_size), position=0, leave=True):
self.progress = i * progress_factor
chunk_ids = image_ids[i:i+chunk_size]
response = await self.loop_communicator.get(f'/{organization}/projects/{project}/images?ids={",".join(chunk_ids)}')
if response.status_code != 200:
@@ -116,13 +120,17 @@ async def _download_images_data(self, organization: str, project: str, image_ids
return images_data

async def _download_images(self, paths: List[str], image_ids: List[str], image_folder: str, chunk_size: int = 10) -> None:
if len(image_ids) == 0:
num_image_ids = len(image_ids)
if num_image_ids == 0:
logging.debug('got empty list. No images were downloaded')
return
logging.info('fetching image files')
starttime = time.time()
os.makedirs(image_folder, exist_ok=True)
for i in tqdm(range(0, len(image_ids), chunk_size), position=0, leave=True):

progress_factor = 0.5 / num_image_ids # second 50% of progress is for downloading images
for i in tqdm(range(0, num_image_ids, chunk_size), position=0, leave=True):
self.progress = 0.5 + i * progress_factor
chunk_paths = paths[i:i+chunk_size]
chunk_ids = image_ids[i:i+chunk_size]
tasks = []
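
To make the split above explicit: the first half of the 0.0..1.0 range is spent on downloading annotation data, the second half on downloading the image files. Below is a standalone sketch of the same arithmetic; the function name and numbers are made up for illustration and are not part of this diff.

def download_progress(phase: str, chunk_start: int, num_image_ids: int) -> float:
    # Each phase owns half of the range, scaled by the start index of the current chunk.
    factor = 0.5 / num_image_ids
    if phase == 'data':     # annotations and other image data
        return chunk_start * factor
    if phase == 'images':   # image files
        return 0.5 + chunk_start * factor
    raise ValueError(f'unknown phase: {phase}')

# e.g. 500 of 1000 images already fetched during the image download phase:
assert download_progress('images', 500, 1000) == 0.75
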
5 changes: 2 additions & 3 deletions learning_loop_node/trainer/tests/testing_trainer_logic.py
@@ -2,8 +2,7 @@
import time
from typing import Dict, List, Optional, Union

from learning_loop_node.data_classes import (BasicModel, Context, Detections,
ModelInformation, PretrainedModel)
from learning_loop_node.data_classes import BasicModel, Context, Detections, ModelInformation, PretrainedModel
from learning_loop_node.trainer.trainer_logic import TrainerLogic


@@ -17,7 +16,7 @@ def __init__(self, can_resume: bool = False) -> None:
self.error_msg: Optional[str] = None

@property
def progress(self) -> float:
def training_progress(self) -> float:
return 1.0

@property
55 changes: 38 additions & 17 deletions learning_loop_node/trainer/trainer_logic.py
@@ -17,10 +17,8 @@
from fastapi.encoders import jsonable_encoder
from tqdm import tqdm

from ..data_classes import (BasicModel, Category, Context, Detections, Errors,
Hyperparameter, ModelInformation, PretrainedModel,
Training, TrainingData, TrainingError,
TrainingState)
from ..data_classes import (BasicModel, Category, Context, Detections, Errors, Hyperparameter, ModelInformation,
PretrainedModel, Training, TrainingData, TrainingError, TrainingState)
from ..helpers.misc import create_image_folder
from ..node import Node
from . import training_syncronizer
@@ -50,6 +48,7 @@ def __init__(self, model_format: str) -> None:
self.start_training_task: Optional[Coroutine] = None
self.errors = Errors()
self.shutdown_event: asyncio.Event = asyncio.Event()
self.detection_progress = 0.0

self._training: Optional[Training] = None
self._active_training_io: Optional[ActiveTrainingIO] = None
@@ -145,21 +144,21 @@ async def _run(self) -> None:
tstate = self.training.training_state
logging.info(f'STATE LOOP: {tstate}')
await asyncio.sleep(0.6) # Note: Required for pytests!
if tstate == TrainingState.Initialized:
if tstate == TrainingState.Initialized: # -> DataDownloading -> DataDownloaded
await self.prepare()
elif tstate == TrainingState.DataDownloaded:
elif tstate == TrainingState.DataDownloaded: # -> TrainModelDownloading -> TrainModelDownloaded
await self.download_model()
elif tstate == TrainingState.TrainModelDownloaded:
elif tstate == TrainingState.TrainModelDownloaded: # -> TrainingRunning -> TrainingFinished
await self.train()
elif tstate == TrainingState.TrainingFinished:
elif tstate == TrainingState.TrainingFinished: # -> ConfusionMatrixSyncing -> ConfusionMatrixSynced
await self.ensure_confusion_matrix_synced()
elif tstate == TrainingState.ConfusionMatrixSynced:
elif tstate == TrainingState.ConfusionMatrixSynced: # -> TrainModelUploading -> TrainModelUploaded
await self.upload_model()
elif tstate == TrainingState.TrainModelUploaded:
elif tstate == TrainingState.TrainModelUploaded: # -> Detecting -> Detected
await self.do_detections()
elif tstate == TrainingState.Detected:
elif tstate == TrainingState.Detected: # -> DetectionUploading -> ReadyForCleanup
await self.upload_detections()
elif tstate == TrainingState.ReadyForCleanup:
elif tstate == TrainingState.ReadyForCleanup: # -> RESTART or TrainingFinished
await self.clear_training()
self.may_restart()

@@ -424,19 +423,23 @@ async def _do_detections(self) -> None:
image_folder = create_image_folder(project_folder)
self.node.data_exchanger.set_context(context)
image_ids = []
for state in ['inbox', 'annotate', 'review', 'complete']:
for state, p in zip(['inbox', 'annotate', 'review', 'complete'], [0.1, 0.2, 0.3, 0.4]):
self.detection_progress = p
logging.info(f'fetching image ids of {state}')
new_ids = await self.node.data_exchanger.fetch_image_ids(query_params=f'state={state}')
image_ids += new_ids
logging.info(f'downloading {len(new_ids)} images')
await self.node.data_exchanger.download_images(new_ids, image_folder)

images = await asyncio.get_event_loop().run_in_executor(None, TrainerLogic.images_for_ids, image_ids, image_folder)
logging.info(f'running detections on {len(images)} images')
num_images = len(images)
logging.info(f'running detections on {num_images} images')
batch_size = 200
idx = 0
if not images:
self.active_training_io.save_detections([], idx)
for i in tqdm(range(0, len(images), batch_size), position=0, leave=True):
for i in tqdm(range(0, num_images, batch_size), position=0, leave=True):
self.detection_progress = 0.5 + (i/num_images)*0.5
batch_images = images[i:i+batch_size]
batch_detections = await self._detect(model_information, batch_images, tmp_folder)
self.active_training_io.save_detections(batch_detections, idx)
@@ -512,7 +515,7 @@ async def clear_training(self):

async def stop(self) -> None:
"""If executor is running, stop it. Else cancel training task."""
if not self._training:
if not self.is_initialized:
return
if self._executor and self._executor.is_process_running():
self.executor.stop()
@@ -539,12 +542,30 @@ def may_restart(self) -> None:
logging.info('restarting')
assert self._node is not None
self._node.restart()
else:
logging.info('not restarting')

@property
def general_progress(self) -> Optional[float]:
"""Represents the progress for different states."""
if not self.is_initialized:
return None

t_state = self.training.training_state
if t_state == TrainingState.DataDownloading:
return self.node.data_exchanger.progress
if t_state == TrainingState.TrainingRunning:
return self.training_progress
if t_state == TrainingState.Detecting:
return self.detection_progress

return None
# ---------------------------------------- ABSTRACT METHODS ----------------------------------------

@property
@abstractmethod
def progress(self) -> float:
def training_progress(self) -> Optional[float]:
"""Represents the training progress."""
raise NotImplementedError

@property
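
For the detection phase added above, the reported values can be read as follows: roughly 0.1 to 0.4 while image ids are fetched state by state, then 0.5 to 1.0 while detections run in batches of 200. A minimal sketch of the batch part of that arithmetic; the function name and numbers are made up for illustration.

def detection_progress(batch_start: int, num_images: int) -> float:
    # The second half of the range is spent on batched detection, mirroring _do_detections above.
    return 0.5 + (batch_start / num_images) * 0.5

# e.g. 600 of 1000 images already handed to the detector:
assert detection_progress(600, 1000) == 0.8
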
4 changes: 2 additions & 2 deletions learning_loop_node/trainer/trainer_node.py
@@ -30,8 +30,8 @@ def __init__(self, name: str, trainer_logic: TrainerLogic, uuid: Optional[str] =

@property
def progress(self) -> Union[float, None]:
return self.trainer_logic.progress if (self.trainer_logic is not None and
hasattr(self.trainer_logic, 'progress')) else None
return self.trainer_logic.general_progress if (self.trainer_logic is not None and
hasattr(self.trainer_logic, 'general_progress')) else None

@property
def training_uptime(self) -> Union[float, None]:
12 changes: 4 additions & 8 deletions mock_trainer/app_code/mock_trainer_logic.py
@@ -4,13 +4,9 @@
import time
from typing import Dict, List, Optional, Union

from learning_loop_node.data_classes import (BasicModel, BoxDetection,
CategoryType,
ClassificationDetection,
Detections, ErrorConfiguration,
ModelInformation, Point,
PointDetection, PretrainedModel,
SegmentationDetection, Shape)
from learning_loop_node.data_classes import (BasicModel, BoxDetection, CategoryType, ClassificationDetection,
Detections, ErrorConfiguration, ModelInformation, Point, PointDetection,
PretrainedModel, SegmentationDetection, Shape)
from learning_loop_node.trainer.trainer_logic import TrainerLogic

from . import progress_simulator
@@ -111,7 +107,7 @@ def provided_pretrained_models(self) -> List[PretrainedModel]:
PretrainedModel(name='large', label='Large', description='a large model')]

@property
def progress(self) -> float:
def training_progress(self) -> float:
print(f'prog. is {self.current_iteration} / {self.max_iterations} = {self.current_iteration / self.max_iterations}')
return self.current_iteration / self.max_iterations

