Skip to content

Commit

Permalink
Fix for issue #734 (#735)
Browse files Browse the repository at this point in the history
  • Loading branch information
kevin-duclos authored Feb 11, 2024
1 parent ee0cf2c commit 8883b8a
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 6 deletions.
11 changes: 8 additions & 3 deletions trackpy/motion.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,8 @@ def subtract_drift(traj, drift=None, inplace=False):
Parameters
----------
traj : DataFrame of trajectories, including columns x, y, and frame
traj : DataFrame of trajectories, including columns x, y, frame,
and particle (if there is more than one particle).
drift : optional DataFrame([x, y], index=frame) like output of
compute_drift(). If no drift is passed, drift is computed from traj.
Expand All @@ -308,8 +309,12 @@ def subtract_drift(traj, drift=None, inplace=False):
drift = compute_drift(traj)
if not inplace:
traj = traj.copy()
traj.set_index('frame', inplace=True, drop=False)
traj.sort_index(inplace=True)
if 'particle' in traj.columns:
traj.set_index(['frame', 'particle'], inplace=True, drop=False)
else:
traj.set_index(['frame'], inplace=True, drop=False)
# Order of particles is irrelevant for performance
traj.sort_index(level='frame', inplace=True)
for col in drift.columns:
traj[col] = traj[col].sub(drift[col], fill_value=0, level='frame')
return traj
Expand Down
26 changes: 23 additions & 3 deletions trackpy/tests/test_motion.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,15 @@ def random_walk(N):
def conformity(df):
""" Organize toy data to look like real data. Be strict about dtypes:
particle is a float and frame is an integer."""
df['frame'] = df['frame'].astype(np.int64) # pandas maps to int32 on windows!
df['particle'] = df['particle'].astype(float)
df['frame'] = df['frame'].astype(np.int64)
df['x'] = df['x'].astype(float)
df['y'] = df['y'].astype(float)
df.set_index('frame', drop=False, inplace=True)
return pandas_sort(df, by=['frame', 'particle'])
if 'particle' in df.columns:
df['particle'] = df['particle'].astype(float)
return pandas_sort(df, by=['frame', 'particle'])
else:
return pandas_sort(df, by=['frame'])


def assert_traj_equal(t1, t2):
Expand Down Expand Up @@ -58,12 +61,19 @@ def setUp(self):
for i in range(P)]
self.many_walks = conformity(pandas_concat(particles))

self.unlabeled_walks = self.many_walks.copy()
del self.unlabeled_walks['particle']

a = DataFrame({'x': np.arange(N), 'y': np.zeros(N),
'frame': np.arange(N), 'particle': np.zeros(N)})
b = DataFrame({'x': np.arange(1, N), 'y': Y + np.zeros(N - 1),
'frame': np.arange(1, N), 'particle': np.ones(N - 1)})
self.steppers = conformity(pandas_concat([a, b]))

# Single-particle trajectory with no particle label
self.single_stepper = conformity(a.copy())
del self.single_stepper['particle']

def test_no_drift(self):
N = 10
expected = DataFrame({'x': np.zeros(N), 'y': np.zeros(N)}).iloc[1:]
Expand Down Expand Up @@ -116,6 +126,16 @@ def test_subtract_constant_drift(self):
actual = tp.subtract_drift(add_drift(self.steppers, drift), drift)
assert_traj_equal(actual, self.steppers)

actual = tp.subtract_drift(add_drift(self.single_stepper, drift), drift)
assert_traj_equal(actual, self.single_stepper)

# Test that subtract_drift is OK without particle labels.
# In principle, Series.sub() may raise an error because
# the 'frame' index is duplicated.
# Don't check the result since we can't compare unlabeled trajectories!
actual = tp.subtract_drift(add_drift(self.unlabeled_walks, drift),
drift)


class TestMSD(StrictTestCase):
def setUp(self):
Expand Down

0 comments on commit 8883b8a

Please sign in to comment.