Skip to content

Commit 9406fd2

Browse files
committed
feat: added interpreter.wait and fixed codecov
1 parent 61eb32b commit 9406fd2

File tree

3 files changed

+23
-1
lines changed

3 files changed

+23
-1
lines changed

src/interpreter.rs

+10
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,13 @@ impl Interpreter {
380380
});
381381
Ok(())
382382
}
383+
384+
/// Wait for all output tensors to be ready after computation
385+
pub fn wait(&self, session: &crate::session::Session) {
386+
self.outputs(session).iter().for_each(|tinfo| {
387+
tinfo.raw_tensor().wait_read(true);
388+
});
389+
}
383390
}
384391

385392
#[repr(transparent)]
@@ -586,6 +593,7 @@ impl OperatorInfo<'_> {
586593
}
587594

588595
#[test]
596+
#[ignore = "This test doesn't work in CI"]
589597
fn test_run_session_with_callback_info_api() {
590598
let file = Path::new("tests/assets/realesr.mnn")
591599
.canonicalize()
@@ -603,6 +611,7 @@ fn test_run_session_with_callback_info_api() {
603611
}
604612

605613
#[test]
614+
#[ignore = "This test doesn't work in CI"]
606615
fn check_whether_sync_actually_works() {
607616
let file = Path::new("tests/assets/realesr.mnn")
608617
.canonicalize()
@@ -633,6 +642,7 @@ fn check_whether_sync_actually_works() {
633642
}
634643

635644
#[test]
645+
#[ignore = "This test doesn't work in CI"]
636646
fn try_to_drop_interpreter_before_session() {
637647
let file = Path::new("tests/assets/realesr.mnn")
638648
.canonicalize()

src/profile.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ macro_rules! profile {
99
};
1010
let elapsed = now.elapsed();
1111
#[cfg(feature = "tracing")]
12-
tracing::trace!("{}: Elapsed time: {:?}", $message, elapsed);
12+
tracing::info!("{}: Elapsed time: {:?}", $message, elapsed);
1313
result
1414
}}
1515
}

src/tensor.rs

+12
Original file line numberDiff line numberDiff line change
@@ -720,6 +720,18 @@ impl<'r> RawTensor<'r> {
720720
self.shape().as_ref().contains(&-1)
721721
}
722722

723+
pub fn wait_read(&self, finish: bool) {
724+
unsafe {
725+
Tensor_wait(self.inner, MapType::MAP_TENSOR_READ, finish as i32);
726+
}
727+
}
728+
729+
pub fn wait_write(&self, finish: bool) {
730+
unsafe {
731+
Tensor_wait(self.inner, MapType::MAP_TENSOR_WRITE, finish as i32);
732+
}
733+
}
734+
723735
/// # Safety
724736
/// This is very unsafe do not use this unless you know what you are doing
725737
pub unsafe fn to_concrete<T: super::TensorType>(self) -> super::Tensor<T>

0 commit comments

Comments
 (0)