Skip to content

Commit

Permalink
add sio eps to control detector versioning mode
Browse files Browse the repository at this point in the history
  • Loading branch information
denniswittich committed Dec 5, 2024
1 parent 6a04d2a commit cdf4435
Show file tree
Hide file tree
Showing 5 changed files with 138 additions and 120 deletions.
8 changes: 4 additions & 4 deletions learning_loop_node/data_classes/__init__.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
from .annotations import AnnotationData, AnnotationEventType, SegmentationAnnotation, ToolOutput, UserInput
from .detections import (BoxDetection, ClassificationDetection, Detections, Observation, Point, PointDetection,
SegmentationDetection, Shape)
from .general import (AnnotationNodeStatus, Category, CategoryType, Context, DetectionStatus, ErrorConfiguration,
ModelInformation, NodeState, NodeStatus)
from .general import (AboutResponse, AnnotationNodeStatus, Category, CategoryType, Context, DetectionStatus,
ErrorConfiguration, ModelInformation, ModelVersionResponse, NodeState, NodeStatus)
from .image_metadata import ImageMetadata
from .socket_response import SocketResponse
from .training import (Errors, PretrainedModel, TrainerState, Training, TrainingError, TrainingOut, TrainingStateData,
TrainingStatus)

__all__ = [
'AnnotationData', 'AnnotationEventType', 'SegmentationAnnotation', 'ToolOutput', 'UserInput',
'AboutResponse', 'AnnotationData', 'AnnotationEventType', 'SegmentationAnnotation', 'ToolOutput', 'UserInput',
'BoxDetection', 'ClassificationDetection', 'ImageMetadata', 'Observation', 'Point', 'PointDetection',
'SegmentationDetection', 'Shape', 'Detections',
'AnnotationNodeStatus', 'Category', 'CategoryType', 'Context', 'DetectionStatus', 'ErrorConfiguration',
'ModelInformation', 'NodeState', 'NodeStatus',
'ModelInformation', 'NodeState', 'NodeStatus', 'ModelVersionResponse',
'SocketResponse',
'Errors', 'PretrainedModel', 'TrainerState', 'Training',
'TrainingError', 'TrainingOut', 'TrainingStateData', 'TrainingStatus',
Expand Down
23 changes: 23 additions & 0 deletions learning_loop_node/data_classes/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,29 @@ def from_dict(data: Dict) -> 'ModelInformation':
return from_dict(ModelInformation, data=data)


@dataclass(**KWONLY_SLOTS)
class AboutResponse:
operation_mode: str = field(metadata={"description": "The operation mode of the detector node"})
state: Optional[str] = field(metadata={
"description": "The state of the detector node",
"example": "idle, online, detecting"})
model_info: Optional[ModelInformation] = field(metadata={
"description": "Information about the model of the detector node"})
target_model: Optional[str] = field(metadata={"description": "The target model of the detector node"})
version_control: str = field(metadata={
"description": "The version control mode of the detector node",
"example": "follow_loop, specific_version, pause"})


@dataclass(**KWONLY_SLOTS)
class ModelVersionResponse:
current_version: str = field(metadata={"description": "The version of the model currently used by the detector."})
target_version: str = field(metadata={"description": "The target model version set in the detector."})
loop_version: str = field(metadata={"description": "The target model version specified by the loop."})
local_versions: List[str] = field(metadata={"description": "The locally available versions of the model."})
version_control: str = field(metadata={"description": "The version control mode."})


@dataclass(**KWONLY_SLOTS)
class ErrorConfiguration():
begin_training: Optional[bool] = False
Expand Down
107 changes: 99 additions & 8 deletions learning_loop_node/detector/detector_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import subprocess
from dataclasses import asdict
from datetime import datetime
from enum import Enum
from threading import Thread
from typing import Dict, List, Optional, Union

Expand All @@ -14,7 +15,8 @@
from fastapi.encoders import jsonable_encoder
from socketio import AsyncClient

from ..data_classes import Category, Context, DetectionStatus, ImageMetadata, ModelInformation, Shape
from ..data_classes import (AboutResponse, Category, Context, DetectionStatus, ImageMetadata, ModelInformation,
ModelVersionResponse, Shape)
from ..data_classes.socket_response import SocketResponse
from ..data_exchanger import DataExchanger, DownloadError
from ..globals import GLOBALS
Expand Down Expand Up @@ -58,8 +60,8 @@ def __init__(self, name: str, detector: DetectorLogic, uuid: Optional[str] = Non
# FollowLoop: the detector node will follow the loop and update the model if necessary
# SpecificVersion: the detector node will update to a specific version, set via the /model_version endpoint
# Pause: the detector node will not update the model
self.version_control: rest_version_control.VersionMode = rest_version_control.VersionMode.Pause if os.environ.get(
'VERSION_CONTROL_DEFAULT', 'follow_loop').lower() == 'pause' else rest_version_control.VersionMode.FollowLoop
self.version_control: VersionMode = VersionMode.Pause if os.environ.get(
'VERSION_CONTROL_DEFAULT', 'follow_loop').lower() == 'pause' else VersionMode.FollowLoop
self.target_model: Optional[ModelInformation] = None
self.loop_deployment_target: Optional[ModelInformation] = None

Expand All @@ -75,6 +77,74 @@ def __init__(self, name: str, detector: DetectorLogic, uuid: Optional[str] = Non

self.setup_sio_server()

def get_about(self) -> AboutResponse:
return AboutResponse(
operation_mode=self.operation_mode.value,
state=self.status.state,
model_info=self.detector_logic._model_info, # pylint: disable=protected-access
target_model=self.target_model.version if self.target_model else None,
version_control=self.version_control.value
)

def get_model_version_response(self) -> ModelVersionResponse:
current_version = self.detector_logic._model_info.version if self.detector_logic._model_info is not None else 'None' # pylint: disable=protected-access
target_version = self.target_model.version if self.target_model is not None else 'None'
loop_version = self.loop_deployment_target.version if self.loop_deployment_target is not None else 'None'

local_versions: list[str] = []
models_path = os.path.join(GLOBALS.data_folder, 'models')
local_models = os.listdir(models_path) if os.path.exists(models_path) else []
for model in local_models:
if model.replace('.', '').isdigit():
local_versions.append(model)

return ModelVersionResponse(
current_version=current_version,
target_version=target_version,
loop_version=loop_version,
local_versions=local_versions,
version_control=self.version_control.value,
)

async def set_model_version_mode(self, version_control_mode: str) -> None:

if version_control_mode == 'follow_loop':
self.version_control = VersionMode.FollowLoop
elif version_control_mode == 'pause':
self.version_control = VersionMode.Pause
else:
self.version_control = VersionMode.SpecificVersion
if not version_control_mode or not version_control_mode.replace('.', '').isdigit():
raise Exception('Invalid version number')
target_version = version_control_mode

if self.target_model is not None and self.target_model.version == target_version:
return

# Fetch the model uuid by version from the loop
uri = f'/{self.organization}/projects/{self.project}/models'
response = await self.loop_communicator.get(uri)
if response.status_code != 200:
self.version_control = VersionMode.Pause
raise Exception('Failed to load models from learning loop')

models = response.json()['models']
models_with_target_version = [m for m in models if m['version'] == target_version]
if len(models_with_target_version) == 0:
self.version_control = VersionMode.Pause
raise Exception(f'No Model with version {target_version}')
if len(models_with_target_version) > 1:
self.version_control = VersionMode.Pause
raise Exception(f'Multiple models with version {target_version}')

model_id = models_with_target_version[0]['id']
model_host = models_with_target_version[0].get('host', 'unknown')

self.target_model = ModelInformation(organization=self.organization, project=self.project,
host=model_host, categories=[],
id=model_id,
version=target_version)

async def soft_reload(self) -> None:
# simulate init
self.organization = environment_reader.organization()
Expand All @@ -85,8 +155,8 @@ async def soft_reload(self) -> None:
Context(organization=self.organization, project=self.project),
self.loop_communicator)
self.relevance_filter = RelevanceFilter(self.outbox)
self.version_control = rest_version_control.VersionMode.Pause if os.environ.get(
'VERSION_CONTROL_DEFAULT', 'follow_loop').lower() == 'pause' else rest_version_control.VersionMode.FollowLoop
self.version_control = VersionMode.Pause if os.environ.get(
'VERSION_CONTROL_DEFAULT', 'follow_loop').lower() == 'pause' else VersionMode.FollowLoop
self.target_model = None
# self.setup_sio_server()

Expand Down Expand Up @@ -141,9 +211,8 @@ def setup_sio_server(self) -> None:
@self.sio.event
async def detect(sid, data: Dict) -> Dict:
try:
np_image = np.frombuffer(data['image'], np.uint8)
det = await self.get_detections(
raw_image=np_image,
raw_image=np.frombuffer(data['image'], np.uint8),
camera_id=data.get('camera-id', None) or data.get('mac', None),
tags=data.get('tags', []),
source=data.get('source', None),
Expand All @@ -165,6 +234,22 @@ async def info(sid) -> Union[str, Dict]:
return asdict(self.detector_logic.model_info)
return 'No model loaded'

@self.sio.event
async def about(sid) -> Dict:
return asdict(self.get_about())

@self.sio.event
async def get_model_version(sid) -> Dict:
return asdict(self.get_model_version_response())

@self.sio.event
async def set_model_version_mode(sid, data: str) -> Union[Dict, str]:
try:
await self.set_model_version_mode(data)
except Exception as e:
return {'error': str(e)}
return "OK"

@self.sio.event
async def upload(sid, data: Dict) -> Optional[Dict]:
'''upload an image with detections'''
Expand Down Expand Up @@ -313,7 +398,7 @@ async def sync_status_with_learning_loop(self) -> None:
id=deployment_target_model_id,
version=deployment_target_model_version)

if (self.version_control == rest_version_control.VersionMode.FollowLoop and
if (self.version_control == VersionMode.FollowLoop and
self.target_model != self.loop_deployment_target):
old_target_model_version = self.target_model.version if self.target_model else None
self.target_model = self.loop_deployment_target
Expand Down Expand Up @@ -422,3 +507,9 @@ def fix_shape_detections(detections: ImageMetadata):
points = ','.join([str(value) for p in seg_detection.shape.points for _,
value in asdict(p).items()])
seg_detection.shape = points


class VersionMode(str, Enum):
FollowLoop = 'follow_loop' # will follow the loop
SpecificVersion = 'specific_version' # will follow the specific version
Pause = 'pause' # will pause the updates
30 changes: 3 additions & 27 deletions learning_loop_node/detector/rest/about.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@

import sys
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING

from fastapi import APIRouter, Request

from ...data_classes import ModelInformation
from ...data_classes import AboutResponse

if TYPE_CHECKING:
from ..detector_node import DetectorNode
Expand All @@ -14,20 +13,6 @@
router = APIRouter()


@dataclass(**KWONLY_SLOTS)
class AboutResponse:
operation_mode: str = field(metadata={"description": "The operation mode of the detector node"})
state: Optional[str] = field(metadata={
"description": "The state of the detector node",
"example": "idle, online, detecting"})
model_info: Optional[ModelInformation] = field(metadata={
"description": "Information about the model of the detector node"})
target_model: Optional[str] = field(metadata={"description": "The target model of the detector node"})
version_control: str = field(metadata={
"description": "The version control mode of the detector node",
"example": "follow_loop, specific_version, pause"})


@router.get("/about", response_model=AboutResponse)
async def get_about(request: Request):
'''
Expand All @@ -38,13 +23,4 @@ async def get_about(request: Request):
curl http://hosturl/about
'''
app: 'DetectorNode' = request.app

response = AboutResponse(
operation_mode=app.operation_mode.value,
state=app.status.state,
model_info=app.detector_logic._model_info, # pylint: disable=protected-access
target_model=app.target_model.version if app.target_model is not None else None,
version_control=app.version_control.value
)

return response
return app.get_about()
90 changes: 9 additions & 81 deletions learning_loop_node/detector/rest/model_version_control.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,10 @@

import os
import sys
from dataclasses import dataclass, field
from enum import Enum
from typing import TYPE_CHECKING, List
from typing import TYPE_CHECKING

from fastapi import APIRouter, HTTPException, Request

from ...data_classes import ModelInformation
from ...globals import GLOBALS
from ...data_classes import ModelVersionResponse

if TYPE_CHECKING:
from ..detector_node import DetectorNode
Expand All @@ -17,22 +13,7 @@
router = APIRouter()


class VersionMode(str, Enum):
FollowLoop = 'follow_loop' # will follow the loop
SpecificVersion = 'specific_version' # will follow the specific version
Pause = 'pause' # will pause the updates


@dataclass(**KWONLY_SLOTS)
class ModelVersionResponse:
current_version: str = field(metadata={"description": "The version of the model currently used by the detector."})
target_version: str = field(metadata={"description": "The target model version set in the detector."})
loop_version: str = field(metadata={"description": "The target model version specified by the loop."})
local_versions: List[str] = field(metadata={"description": "The locally available versions of the model."})
version_control: str = field(metadata={"description": "The version control mode."})


@router.get("/model_version")
@router.get("/model_version", name='Get model version information', response_model=ModelVersionResponse)
async def get_version(request: Request):
'''
Get information about the model version control and the current model version.
Expand All @@ -41,31 +22,11 @@ async def get_version(request: Request):
curl http://localhost/model_version
'''
# pylint: disable=protected-access

app: 'DetectorNode' = request.app
return app.get_model_version_response()

current_version = app.detector_logic._model_info.version if app.detector_logic._model_info is not None else 'None'
target_version = app.target_model.version if app.target_model is not None else 'None'
loop_version = app.loop_deployment_target.version if app.loop_deployment_target is not None else 'None'

local_versions: list[str] = []
models_path = os.path.join(GLOBALS.data_folder, 'models')
local_models = os.listdir(models_path) if os.path.exists(models_path) else []
for model in local_models:
if model.replace('.', '').isdigit():
local_versions.append(model)

response = ModelVersionResponse(
current_version=current_version,
target_version=target_version,
loop_version=loop_version,
local_versions=local_versions,
version_control=app.version_control.value,
)
return response


@router.put("/model_version")
@router.put("/model_version", name='Set model version control mode')
async def put_version(request: Request):
'''
Set the model version control mode.
Expand All @@ -77,42 +38,9 @@ async def put_version(request: Request):
'''
app: 'DetectorNode' = request.app
content = str(await request.body(), 'utf-8')

if content == 'follow_loop':
app.version_control = VersionMode.FollowLoop
elif content == 'pause':
app.version_control = VersionMode.Pause
else:
app.version_control = VersionMode.SpecificVersion
if not content or not content.replace('.', '').isdigit():
raise HTTPException(400, 'Invalid version number')
target_version = content

if app.target_model is not None and app.target_model.version == target_version:
return "OK"

# Fetch the model uuid by version from the loop
uri = f'/{app.organization}/projects/{app.project}/models'
response = await app.loop_communicator.get(uri)
if response.status_code != 200:
app.version_control = VersionMode.Pause
raise HTTPException(500, 'Failed to load models from learning loop')

models = response.json()['models']
models_with_target_version = [m for m in models if m['version'] == target_version]
if len(models_with_target_version) == 0:
app.version_control = VersionMode.Pause
raise HTTPException(400, f'No Model with version {target_version}')
if len(models_with_target_version) > 1:
app.version_control = VersionMode.Pause
raise HTTPException(500, f'Multiple models with version {target_version}')

model_id = models_with_target_version[0]['id']
model_host = models_with_target_version[0].get('host', 'unknown')

app.target_model = ModelInformation(organization=app.organization, project=app.project,
host=model_host, categories=[],
id=model_id,
version=target_version)
try:
await app.set_model_version_mode(content)
except Exception as exc:
raise HTTPException(400, str(exc)) from exc

return "OK"

0 comments on commit cdf4435

Please sign in to comment.