Skip to content

Commit d20490d

Browse files
committed
fully adapt engine to persistent interprocessing
1 parent c1bfb91 commit d20490d

File tree

5 files changed

+567
-19
lines changed

5 files changed

+567
-19
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
package io.bioimage.modelrunner.pytorch;
2+
3+
import java.io.IOException;
4+
import java.net.URISyntaxException;
5+
import java.util.HashMap;
6+
import java.util.LinkedHashMap;
7+
import java.util.List;
8+
import java.util.Map;
9+
import java.util.Scanner;
10+
11+
import io.bioimage.modelrunner.apposed.appose.Types;
12+
import io.bioimage.modelrunner.apposed.appose.Service.RequestType;
13+
import io.bioimage.modelrunner.apposed.appose.Service.ResponseType;
14+
15+
public class JavaWorker {
16+
17+
private static LinkedHashMap<String, Object> tasks = new LinkedHashMap<String, Object>();
18+
19+
private final String uuid;
20+
21+
private final PytorchInterface pi;
22+
23+
private boolean cancelRequested = false;
24+
25+
public static void main(String[] args) {
26+
27+
try(Scanner scanner = new Scanner(System.in)){
28+
PytorchInterface pi;
29+
try {
30+
pi = new PytorchInterface(false);
31+
} catch (IOException | URISyntaxException e) {
32+
return;
33+
}
34+
35+
while (true) {
36+
String line;
37+
try {
38+
if (!scanner.hasNextLine()) break;
39+
line = scanner.nextLine().trim();
40+
} catch (Exception e) {
41+
break;
42+
}
43+
44+
if (line.isEmpty()) break;
45+
Map<String, Object> request = Types.decode(line);
46+
String uuid = (String) request.get("task");
47+
String requestType = (String) request.get("requestType");
48+
49+
if (requestType.equals(RequestType.EXECUTE.toString())) {
50+
String script = (String) request.get("script");
51+
Map<String, Object> inputs = (Map<String, Object>) request.get("inputs");
52+
JavaWorker task = new JavaWorker(uuid, pi);
53+
tasks.put(uuid, task);
54+
task.start(script, inputs);
55+
} else if (requestType.equals(RequestType.CANCEL.toString())) {
56+
JavaWorker task = (JavaWorker) tasks.get(uuid);
57+
if (task == null) {
58+
System.err.println("No such task: " + uuid);
59+
continue;
60+
}
61+
task.cancelRequested = true;
62+
} else {
63+
break;
64+
}
65+
}
66+
}
67+
68+
}
69+
70+
private JavaWorker(String uuid, PytorchInterface pi) {
71+
this.uuid = uuid;
72+
this.pi = pi;
73+
}
74+
75+
private void executeScript(String script, Map<String, Object> inputs) {
76+
Map<String, Object> binding = new LinkedHashMap<String, Object>();
77+
binding.put("task", this);
78+
if (inputs != null)
79+
binding.putAll(binding);
80+
81+
this.reportLaunch();
82+
try {
83+
if (script.equals("loadModel")) {
84+
pi.loadModel((String) inputs.get("modelFolder"), null);
85+
} else if (script.equals("inference")) {
86+
pi.runFromShmas((List<String>) inputs.get("inputs"), (List<String>) inputs.get("outputs"));
87+
} else if (script.equals("close")) {
88+
pi.closeModel();
89+
}
90+
} catch(Exception ex) {
91+
this.fail(Types.stackTrace(ex));
92+
return;
93+
}
94+
this.reportCompletion();
95+
}
96+
97+
private void start(String script, Map<String, Object> inputs) {
98+
new Thread(() -> executeScript(script, inputs), "Appose-" + this.uuid).start();
99+
}
100+
101+
private void reportLaunch() {
102+
respond(ResponseType.LAUNCH, null);
103+
}
104+
105+
private void reportCompletion() {
106+
respond(ResponseType.COMPLETION, null);
107+
}
108+
109+
private void update(String message, Integer current, Integer maximum) {
110+
LinkedHashMap<String, Object> args = new LinkedHashMap<String, Object>();
111+
112+
if (message != null)
113+
args.put("message", message);
114+
115+
if (current != null)
116+
args.put("current", current);
117+
118+
if (maximum != null)
119+
args.put("maximum", maximum);
120+
this.respond(ResponseType.UPDATE, args);
121+
}
122+
123+
private void respond(ResponseType responseType, Map<String, Object> args) {
124+
Map<String, Object> response = new HashMap<String, Object>();
125+
response.put("task", uuid);
126+
response.put("responseType", responseType);
127+
if (args != null)
128+
response.putAll(args);
129+
try {
130+
System.out.println(Types.encode(response));
131+
System.out.flush();
132+
} catch(Exception ex) {
133+
this.fail(Types.stackTrace(ex.getCause()));
134+
}
135+
}
136+
137+
private void cancel() {
138+
this.respond(ResponseType.CANCELATION, null);
139+
}
140+
141+
private void fail(String error) {
142+
Map<String, Object> args = null;
143+
if (error != null) {
144+
args = new HashMap<String, Object>();
145+
args.put("error", error);
146+
}
147+
respond(ResponseType.FAILURE, args);
148+
}
149+
150+
}

src/main/java/io/bioimage/modelrunner/pytorch/PytorchInterface.java

+6-18
Original file line numberDiff line numberDiff line change
@@ -24,39 +24,30 @@
2424
import io.bioimage.modelrunner.apposed.appose.Types;
2525
import io.bioimage.modelrunner.apposed.appose.Service.Task;
2626
import io.bioimage.modelrunner.apposed.appose.Service.TaskStatus;
27-
import io.bioimage.modelrunner.bioimageio.BioimageioRepo;
28-
import io.bioimage.modelrunner.bioimageio.download.DownloadTracker;
29-
import io.bioimage.modelrunner.bioimageio.download.DownloadTracker.TwoParameterConsumer;
3027
import io.bioimage.modelrunner.engine.DeepLearningEngineInterface;
3128
import io.bioimage.modelrunner.exceptions.LoadModelException;
3229
import io.bioimage.modelrunner.exceptions.RunModelException;
30+
import io.bioimage.modelrunner.pytorch.shm.ShmBuilder;
31+
import io.bioimage.modelrunner.pytorch.shm.TensorBuilder;
3332
import io.bioimage.modelrunner.pytorch.tensor.ImgLib2Builder;
3433
import io.bioimage.modelrunner.pytorch.tensor.NDArrayBuilder;
35-
import io.bioimage.modelrunner.pytorch.tensor.shm.NDArrayShmBuilder;
3634
import io.bioimage.modelrunner.system.PlatformDetection;
3735
import io.bioimage.modelrunner.tensor.Tensor;
3836
import io.bioimage.modelrunner.tensor.shm.SharedMemoryArray;
3937
import io.bioimage.modelrunner.utils.CommonUtils;
4038
import net.imglib2.RandomAccessibleInterval;
41-
import net.imglib2.img.array.ArrayImgs;
4239
import net.imglib2.type.NativeType;
4340
import net.imglib2.type.numeric.RealType;
44-
import net.imglib2.type.numeric.real.FloatType;
4541
import net.imglib2.util.Cast;
4642
import net.imglib2.util.Util;
4743

48-
import java.io.BufferedReader;
4944
import java.io.File;
5045
import java.io.IOException;
51-
import java.io.InputStreamReader;
5246
import java.io.UnsupportedEncodingException;
53-
import java.lang.reflect.Type;
54-
import java.net.MalformedURLException;
5547
import java.net.URISyntaxException;
5648
import java.net.URL;
5749
import java.net.URLDecoder;
5850
import java.nio.charset.StandardCharsets;
59-
import java.nio.file.FileAlreadyExistsException;
6051
import java.nio.file.Files;
6152
import java.nio.file.Path;
6253
import java.nio.file.Paths;
@@ -70,15 +61,12 @@
7061
import java.util.Map;
7162

7263
import com.google.gson.Gson;
73-
import com.google.gson.reflect.TypeToken;
7464

75-
import ai.djl.MalformedModelException;
76-
import ai.djl.engine.EngineException;
7765
import ai.djl.inference.Predictor;
66+
import ai.djl.ndarray.NDArray;
7867
import ai.djl.ndarray.NDList;
7968
import ai.djl.ndarray.NDManager;
8069
import ai.djl.repository.zoo.Criteria;
81-
import ai.djl.repository.zoo.ModelNotFoundException;
8270
import ai.djl.repository.zoo.ModelZoo;
8371
import ai.djl.repository.zoo.ZooModel;
8472
import ai.djl.training.util.ProgressBar;
@@ -271,15 +259,15 @@ void run(List<Tensor<T>> inputTensors, List<Tensor<R>> outputTensors)
271259
}
272260
}
273261

274-
protected void runFromShmas(List<String> inputs, List<String> outputs) throws IOException {
262+
protected void runFromShmas(List<String> inputs, List<String> outputs) throws IOException, RunModelException {
275263
try (NDManager manager = NDManager.newBaseManager()) {
276264
// Create the input lists of engine tensors (NDArrays) and their
277265
// corresponding names
278266
NDList inputList = new NDList();
279267
for (String ee : inputs) {
280268
Map<String, Object> decoded = Types.decode(ee);
281269
SharedMemoryArray shma = SharedMemoryArray.read((String) decoded.get(MEM_NAME_KEY));
282-
NDArray inT = io.bioimage.modelrunner.tensorflow.v2.api030.shm.TensorBuilder.build(shma);
270+
NDArray inT = TensorBuilder.build(shma, manager);
283271
if (PlatformDetection.isWindows()) shma.close();
284272
inputList.add(inT);
285273
}
@@ -470,7 +458,7 @@ else if (task.status == TaskStatus.CRASHED)
470458
model = null;
471459
}
472460

473-
/**
461+
/** TODO remove
474462
* Create the arguments needed to execute Pytorch in another
475463
* process with the corresponding tensors
476464
* @return the command used to call the separate process

0 commit comments

Comments
 (0)