Skip to content

Commit

Permalink
Make plant positions Point3d (#169)
Browse files Browse the repository at this point in the history
* wip: make plants 3d

* editing tests, plant provider and plant

* 3d point fixing

---------

Co-authored-by: Lukas Baecker <[email protected]>
  • Loading branch information
pascalzauberzeug and LukasBaecker authored Sep 5, 2024
1 parent 60f32ae commit 916d5a2
Show file tree
Hide file tree
Showing 9 changed files with 39 additions and 38 deletions.
19 changes: 10 additions & 9 deletions field_friend/automations/implements/weeding_implement.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import rosys
from nicegui import ui
from rosys.geometry import Point, Pose
from rosys.geometry import Point3d, Pose

from ...hardware import ChainAxis
from .implement import Implement
Expand Down Expand Up @@ -44,9 +44,9 @@ def __init__(self, name: str, system: 'System', persistence_key: str = 'weeding
self.start_time: Optional[float] = None
self.last_pose: Optional[Pose] = None
self.driven_distance: float = 0.0
self.crops_to_handle: dict[str, Point] = {}
self.weeds_to_handle: dict[str, Point] = {}
self.last_punches: deque[rosys.geometry.Point] = deque(maxlen=5)
self.crops_to_handle: dict[str, Point3d] = {}
self.weeds_to_handle: dict[str, Point3d] = {}
self.last_punches: deque[Point3d] = deque(maxlen=5)
self.next_punch_y_position: float = 0

rosys.on_repeat(self._update_time_and_distance, 0.1)
Expand Down Expand Up @@ -123,8 +123,8 @@ async def _check_hardware_ready(self) -> bool:

def _has_plants_to_handle(self) -> bool:
relative_crop_positions = {
c.id: self.system.odometer.prediction.relative_point(c.position)
for c in self.system.plant_provider.get_relevant_crops(self.system.odometer.prediction.point)
c.id: Point3d.from_point(self.system.odometer.prediction.relative_point(c.position))
for c in self.system.plant_provider.get_relevant_crops(self.system.odometer.prediction.point_3d())
if self.cultivated_crop is None or c.type == self.cultivated_crop
}
upcoming_crop_positions = {
Expand All @@ -136,8 +136,8 @@ def _has_plants_to_handle(self) -> bool:
self.crops_to_handle = sorted_crops

relative_weed_positions = {
w.id: self.system.odometer.prediction.relative_point(w.position)
for w in self.system.plant_provider.get_relevant_weeds(self.system.odometer.prediction.point)
w.id: Point3d.from_point(self.system.odometer.prediction.relative_point(w.position))
for w in self.system.plant_provider.get_relevant_weeds(self.system.odometer.prediction.point_3d())
if w.type in self.relevant_weeds
}
upcoming_weed_positions = {
Expand All @@ -151,7 +151,8 @@ def _has_plants_to_handle(self) -> bool:
offset = self.system.field_friend.DRILL_RADIUS + \
self.crop_safety_distance - crop_position.distance(weed_position)
if offset > 0:
safe_weed_position = weed_position.polar(offset, crop_position.direction(weed_position))
safe_weed_position = Point3d.from_point(Point3d.projection(weed_position).polar(
offset, Point3d.projection(crop_position).direction(weed_position)))
upcoming_weed_positions[weed] = safe_weed_position
self.log.info(f'Moved weed {weed} from {weed_position} to {safe_weed_position} ' +
f'by {offset} to safe {crop} at {crop_position}')
Expand Down
10 changes: 5 additions & 5 deletions field_friend/automations/implements/weeding_screw.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@ def __init__(self, system: 'System') -> None:
async def start_workflow(self) -> None:
await super().start_workflow()
try:
punch_position = self.system.odometer.prediction.transform(
rosys.geometry.Point(x=self.system.field_friend.WORK_X, y=self.next_punch_y_position))
punch_position = self.system.odometer.prediction.transform3d(
rosys.geometry.Point3d(x=self.system.field_friend.WORK_X, y=self.next_punch_y_position, z=0))
self.last_punches.append(punch_position)
await self.system.puncher.punch(y=self.next_punch_y_position, depth=self.weed_screw_depth)
punched_weeds = [weed.id for weed in self.system.plant_provider.get_relevant_weeds(self.system.odometer.prediction.point)
punched_weeds = [weed.id for weed in self.system.plant_provider.get_relevant_weeds(self.system.odometer.prediction.point_3d())
if weed.position.distance(punch_position) <= self.system.field_friend.DRILL_RADIUS]
for weed_id in punched_weeds:
self.system.plant_provider.remove_weed(weed_id)
Expand All @@ -52,8 +52,8 @@ async def get_stretch(self, max_distance: float) -> float:
self.log.info(f'Found {len(weeds_in_range)} weeds in range: {weeds_in_range}')
for next_weed_id, next_weed_position in weeds_in_range.items():
# next_weed_position.x += 0.01 # NOTE somehow this helps to mitigate an offset we experienced in the tests
weed_world_position = self.system.odometer.prediction.transform(next_weed_position)
crops = self.system.plant_provider.get_relevant_crops(self.system.odometer.prediction.point)
weed_world_position = self.system.odometer.prediction.transform3d(next_weed_position)
crops = self.system.plant_provider.get_relevant_crops(self.system.odometer.prediction.point_3d())
if self.cultivated_crop and not any(c.position.distance(weed_world_position) < self.max_crop_distance for c in crops):
self.log.info('Skipping weed because it is to far from the cultivated crops')
continue
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def update_target(self) -> None:
self.target = self.odometer.prediction.transform(rosys.geometry.Point(x=distance, y=0))

async def _drive(self, distance: float) -> None:
row = self.plant_provider.get_relevant_crops(self.odometer.prediction.point, max_distance=1.0)
row = self.plant_provider.get_relevant_crops(point=self.odometer.prediction.point_3d(), max_distance=1.0)
if len(row) >= 3:
points_array = np.array([(p.position.x, p.position.y) for p in row])
# Fit a line using least squares
Expand Down
15 changes: 8 additions & 7 deletions field_friend/automations/plant.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,33 @@
import uuid
from collections import deque
from dataclasses import dataclass, field
from typing import Optional
from uuid import uuid4

from rosys.geometry import Point
from rosys.geometry import Point3d
from rosys.vision import Image


@dataclass(slots=True, kw_only=True)
class Plant:
id: str = field(default_factory=lambda: str(uuid4()))
type: str
positions: deque[Point] = field(default_factory=lambda: deque(maxlen=20))
positions: deque[Point3d] = field(default_factory=lambda: deque(maxlen=20))
detection_time: float
confidences: deque[float] = field(default_factory=lambda: deque(maxlen=20))
detection_image: Optional[Image] = None

@property
def position(self) -> Point:
def position(self) -> Point3d:
"""Calculate the middle position of all points"""
total_x = sum(point.x for point in self.positions)
total_y = sum(point.y for point in self.positions)
total_x = sum(point3d.x for point3d in self.positions)
total_y = sum(point3d.y for point3d in self.positions)
total_z = sum(point3d.z for point3d in self.positions)

middle_x = total_x / len(self.positions)
middle_y = total_y / len(self.positions)
middle_z = total_z / len(self.positions)

return Point(x=middle_x, y=middle_y)
return Point3d(x=middle_x, y=middle_y, z=middle_z)

@property
def confidence(self) -> float:
Expand Down
3 changes: 1 addition & 2 deletions field_friend/automations/plant_locator.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,11 +103,10 @@ async def _detect_plants(self) -> None:
if world_point_3d is None:
self.log.error('could not generate world point of detection, calibration error')
continue
world_point = world_point_3d.projection()
plant = Plant(type=d.category_name,
detection_time=rosys.time(),
detection_image=new_image)
plant.positions.append(world_point)
plant.positions.append(world_point_3d)
plant.confidences.append(d.confidence)
if d.category_name in self.weed_category_names and d.confidence >= self.minimum_weed_confidence:
# self.log.info('weed found')
Expand Down
14 changes: 7 additions & 7 deletions field_friend/automations/plant_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import rosys
from nicegui import ui
from rosys.geometry import Point
from rosys.geometry import Point3d

from .plant import Plant

Expand Down Expand Up @@ -129,18 +129,18 @@ def _add_crop_prediction(self, plant: Plant) -> None:
crop_1 = sorted_crops[0]
crop_2 = sorted_crops[1]

yaw = crop_2.position.direction(crop_1.position)
prediction = crop_1.position.polar(self.crop_spacing, yaw)
yaw = crop_2.position.projection().direction(crop_1.position.projection())
prediction = crop_1.position.projection().polar(self.crop_spacing, yaw)

if plant.position.distance(prediction) > self.match_distance:
if plant.position.projection().distance(prediction) > self.match_distance:
return
plant.positions.append(prediction)
plant.positions.append(Point3d.from_point(prediction, 0))
plant.confidences.append(self.prediction_confidence)

def get_relevant_crops(self, point: Point, *, max_distance=0.5) -> list[Plant]:
def get_relevant_crops(self, point: Point3d, *, max_distance=0.5) -> list[Plant]:
return [c for c in self.crops if c.position.distance(point) <= max_distance and c.confidence >= self.minimum_combined_crop_confidence]

def get_relevant_weeds(self, point: Point, *, max_distance=0.5) -> list[Plant]:
def get_relevant_weeds(self, point: Point3d, *, max_distance=0.5) -> list[Plant]:
return [w for w in self.weeds if w.position.distance(point) <= max_distance and w.confidence >= self.minimum_combined_weed_confidence]

def settings_ui(self) -> None:
Expand Down
4 changes: 2 additions & 2 deletions tests/test_navigation.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ async def test_follow_crops_no_direction(system: System, detector: rosys.vision.
assert system.automator.is_running
await forward(until=lambda: not system.automator.is_running, timeout=300)
assert not system.automator.is_running, 'automation should stop if no crops are detected anymore'
assert system.odometer.prediction.distance(rosys.geometry.Point(x=0, y=0)) == pytest.approx(2.0, abs=0.1)
assert system.odometer.prediction.distance(rosys.geometry.Point3d(x=0, y=0, z=0)) == pytest.approx(2.0, abs=0.1)
assert system.odometer.prediction.point.x == pytest.approx(2.0, abs=0.1)
assert system.odometer.prediction.point.y == pytest.approx(0, abs=0.01)
assert system.odometer.prediction.yaw_deg == pytest.approx(0, abs=1.0)
Expand All @@ -139,7 +139,7 @@ async def test_follow_crops_empty(system: System, detector: rosys.vision.Detecto
async def test_follow_crops_straight(system: System, detector: rosys.vision.DetectorSimulation):
for i in range(10):
x = i/10
p = rosys.geometry.Point3d(x=x, y=0, z=0)
p = rosys.geometry.Point3d(x=x)
detector.simulated_objects.append(rosys.vision.SimulatedObject(category_name='maize', position=p))
system.current_navigation = system.follow_crops_navigation
assert isinstance(system.current_navigation.implement, Recorder)
Expand Down
8 changes: 4 additions & 4 deletions tests/test_plant_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,20 @@ def test_extracting_relevant_crops():
plants = PlantProvider()
for i in range(20):
plants.add_crop(create_crop(i/10.0, 0))
crops = plants.get_relevant_crops(rosys.geometry.Point(x=1.0, y=0), max_distance=0.45)
crops = plants.get_relevant_crops(rosys.geometry.Point3d(x=1.0, y=0, z=0), max_distance=0.45)
assert len(crops) == 9
# TODO do not clear list; better to use weighted average in confidence property
plants.crops[10].confidences.clear()
plants.crops[10].confidences.append(0.4)
crops = plants.get_relevant_crops(rosys.geometry.Point(x=1.0, y=0), max_distance=0.45)
crops = plants.get_relevant_crops(rosys.geometry.Point3d(x=1.0, y=0, z=0), max_distance=0.45)
assert len(crops) == 8, 'crops with a confidence of less than PlantProvider.MINIMUM_COMBINED_CROP_CONFIDENCE should be ignored'


def create_crop(x: float, y: float) -> Plant:
"""Creates a maize plant with three observed positions at the given coordinates."""
plant = Plant(type='maize', detection_time=rosys.time())
for _ in range(3):
plant.positions.append(rosys.geometry.Point(x=x, y=y))
plant.positions.append(rosys.geometry.Point3d(x=x, y=y, z=0))
plant.confidences.append(0.9)
return plant

Expand All @@ -32,7 +32,7 @@ def test_crop_prediction():

def add_crop(x):
plant = Plant(type='maize', detection_time=rosys.time())
plant.positions.append(rosys.geometry.Point(x=x, y=0))
plant.positions.append(rosys.geometry.Point3d(x=x, y=0, z=0))
plant.confidences.append(confidence)
plant_provider.add_crop(plant)

Expand Down
2 changes: 1 addition & 1 deletion tests/test_recorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@ async def test_locating_of_plants(system: System, detector: rosys.vision.Detecto
await forward(20)
assert len(system.plant_provider.crops) == 1
assert system.plant_provider.crops[0].type == 'sugar_beet'
assert_point(system.plant_provider.crops[0].position, rosys.geometry.Point(x=0.212, y=0.03))
assert_point(system.plant_provider.crops[0].position, rosys.geometry.Point3d(x=0.212, y=0.03, z=0))

0 comments on commit 916d5a2

Please sign in to comment.