Skip to content

Commit 2415b65

Browse files
committed
Add a method to get all descendants of a node
1 parent 012ea1d commit 2415b65

File tree

5 files changed

+80
-35
lines changed

5 files changed

+80
-35
lines changed

Cargo.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ edition = "2021"
1212

1313
[dependencies]
1414
memchr = "2.2.1"
15-
pyo3 = { version = "0.17", optional = true }
15+
pyo3 = { version = "0.18", optional = true }
1616
quick-xml = "0.27"
1717
serde = { version = "1.0.104", features = ["derive"] }
1818
serde_json = "1.0.44"

README.md

+6-1
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,12 @@ In NCBI, it only accounts for *scientific names* and not synonyms.
110110

111111
#### `tax.children(tax_id: str) -> List[TaxonomyNode]`
112112

113-
Returns all nodes below the given tax id.
113+
Returns all direct nodes below the given tax id.
114+
115+
#### `tax.descendants(tax_id: str) -> List[TaxonomyNode]`
116+
117+
Returns all nodes below the given tax id.
118+
Equivalent to running `tax.children` recursively on the initial result of `tax.children(tax_id)`.
114119

115120
#### `tax.lineage(tax_id: str) -> List[TaxonomyNode]`
116121

src/base.rs

+24
Original file line numberDiff line numberDiff line change
@@ -324,6 +324,17 @@ impl<'t> Taxonomy<'t, &'t str> for GeneralTaxonomy {
324324
.collect()
325325
}
326326

327+
fn descendants(&'t self, tax_id: &'t str) -> TaxonomyResult<Vec<&'t str>> {
328+
let children: HashSet<&str> = self
329+
.traverse(tax_id)?
330+
.map(|(n, _)| n)
331+
.filter(|t| *t != tax_id)
332+
.collect();
333+
let mut children: Vec<&str> = children.into_iter().collect();
334+
children.sort_unstable();
335+
Ok(children)
336+
}
337+
327338
fn parent(&'t self, tax_id: &str) -> TaxonomyResult<Option<(&'t str, f32)>> {
328339
let idx = self.to_internal_index(tax_id)?;
329340
if idx == 0 {
@@ -377,6 +388,15 @@ impl<'t> Taxonomy<'t, InternalIndex> for GeneralTaxonomy {
377388
}
378389
}
379390

391+
fn descendants(&'t self, tax_id: InternalIndex) -> TaxonomyResult<Vec<InternalIndex>> {
392+
let children: HashSet<InternalIndex> = self
393+
.traverse(tax_id)?
394+
.map(|(n, _)| n)
395+
.filter(|n| *n != tax_id)
396+
.collect();
397+
Ok(children.into_iter().collect())
398+
}
399+
380400
fn parent(&'t self, idx: InternalIndex) -> TaxonomyResult<Option<(InternalIndex, f32)>> {
381401
if idx == 0 {
382402
return Ok(None);
@@ -452,6 +472,10 @@ mod tests {
452472
let tax = create_test_taxonomy();
453473
assert_eq!(Taxonomy::<&str>::len(&tax), 6);
454474
assert_eq!(tax.children("1").unwrap(), vec!["2", "1000", "101", "102"]);
475+
assert_eq!(
476+
tax.descendants("1").unwrap(),
477+
vec!["1000", "101", "102", "2", "562"]
478+
);
455479
assert_eq!(tax.name("562").unwrap(), "Escherichia coli");
456480
assert_eq!(tax.rank("562").unwrap(), TaxRank::Species);
457481
assert_eq!(tax.parent("562").unwrap(), Some(("2", 1.0)));

src/python.rs

+25-32
Original file line numberDiff line numberDiff line change
@@ -290,12 +290,13 @@ impl Taxonomy {
290290
///
291291
/// Find a node by its name, Raises an exception if not found.
292292
fn find_all_by_name(&self, name: &str) -> PyResult<Vec<TaxonomyNode>> {
293-
Ok(self
293+
let res = self
294294
.tax
295295
.find_all_by_name(name)
296-
.iter()
297-
.map(|tax_id| self.as_node(tax_id).unwrap())
298-
.collect::<Vec<TaxonomyNode>>())
296+
.into_iter()
297+
.map(|tax_id| self.as_node(tax_id))
298+
.collect::<PyResult<Vec<TaxonomyNode>>>()?;
299+
Ok(res)
299300
}
300301

301302
/// parent_with_distance(self, tax_id: str, /, at_rank: str)
@@ -345,22 +346,24 @@ impl Taxonomy {
345346
/// children(self, tax_id: str)
346347
/// --
347348
///
348-
/// Return a list of child taxonomy nodes from the node id provided.
349+
/// Return a list of direct child taxonomy nodes from the node id provided.
349350
fn children(&self, tax_id: &str) -> PyResult<Vec<TaxonomyNode>> {
350-
let children: Vec<&str> = py_try!(self.tax.children(tax_id));
351-
let mut res = Vec::with_capacity(children.len());
352-
for key in children {
353-
let child = self.node(key);
354-
if let Some(c) = child {
355-
res.push(c);
356-
} else {
357-
return Err(PyErr::new::<TaxonomyError, _>(format!(
358-
"Node {} is missing in children",
359-
key
360-
)));
361-
}
362-
}
351+
let res = py_try!(self.tax.children(tax_id))
352+
.into_iter()
353+
.map(|tax_id| self.as_node(tax_id))
354+
.collect::<PyResult<Vec<TaxonomyNode>>>()?;
355+
Ok(res)
356+
}
363357

358+
/// descendants(self, tax_id: str)
359+
/// --
360+
///
361+
/// Return a list of all child taxonomy nodes from the node id provided.
362+
fn descendants(&self, tax_id: &str) -> PyResult<Vec<TaxonomyNode>> {
363+
let res = py_try!(self.tax.descendants(tax_id))
364+
.into_iter()
365+
.map(|tax_id| self.as_node(tax_id))
366+
.collect::<PyResult<Vec<TaxonomyNode>>>()?;
364367
Ok(res)
365368
}
366369

@@ -370,20 +373,10 @@ impl Taxonomy {
370373
/// Return a list of all the parent taxonomy nodes of the node id provided
371374
/// (including that node itself).
372375
fn lineage(&self, tax_id: &str) -> PyResult<Vec<TaxonomyNode>> {
373-
let lineage: Vec<&str> = py_try!(self.tax.lineage(tax_id));
374-
let mut res = Vec::with_capacity(lineage.len());
375-
for key in lineage {
376-
let ancestor = self.node(key);
377-
if let Some(a) = ancestor {
378-
res.push(a);
379-
} else {
380-
return Err(PyErr::new::<TaxonomyError, _>(format!(
381-
"Node {} is missing in lineage",
382-
key
383-
)));
384-
}
385-
}
386-
376+
let res = py_try!(self.tax.lineage(tax_id))
377+
.into_iter()
378+
.map(|tax_id| self.as_node(tax_id))
379+
.collect::<PyResult<Vec<TaxonomyNode>>>()?;
387380
Ok(res)
388381
}
389382

src/taxonomy.rs

+24-1
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,12 @@ where
1515
/// Returns the root node of the entire tree.
1616
fn root(&'t self) -> T;
1717

18-
/// Returns a [Vec] of all the child IDs of the given tax_id.
18+
/// Returns a [Vec] of all the direct children IDs of the given tax_id.
1919
fn children(&'t self, tax_id: T) -> TaxonomyResult<Vec<T>>;
2020

21+
/// Returns a [Vec] of all the children IDs of the given tax_id.
22+
fn descendants(&'t self, tax_id: T) -> TaxonomyResult<Vec<T>>;
23+
2124
/// Returns the parent of the given taxonomic node and the distance to said parent.
2225
/// The parent of the root node will return [None].
2326
fn parent(&'t self, tax_id: T) -> TaxonomyResult<Option<(T, f32)>>;
@@ -208,6 +211,17 @@ pub(crate) mod tests {
208211
})
209212
}
210213

214+
fn descendants(&'t self, tax_id: u32) -> TaxonomyResult<Vec<u32>> {
215+
let children: HashSet<u32> = self
216+
.traverse(tax_id)?
217+
.map(|(n, _)| n)
218+
.filter(|n| *n != tax_id)
219+
.collect();
220+
let mut children: Vec<u32> = children.into_iter().collect();
221+
children.sort_unstable();
222+
Ok(children)
223+
}
224+
211225
fn parent(&self, tax_id: u32) -> TaxonomyResult<Option<(u32, f32)>> {
212226
Ok(match tax_id {
213227
131567 => Some((1, 1.)),
@@ -265,6 +279,15 @@ pub(crate) mod tests {
265279
assert_eq!(tax.is_empty(), false);
266280
}
267281

282+
#[test]
283+
fn test_descendants() {
284+
let tax = MockTax;
285+
assert_eq!(
286+
tax.descendants(2).unwrap(),
287+
vec![22, 1046, 1224, 1236, 53452, 56812, 61598, 62322, 135613, 135622, 765909]
288+
);
289+
}
290+
268291
#[test]
269292
fn test_lca() {
270293
let tax = MockTax;

0 commit comments

Comments
 (0)