-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathdemo.py
66 lines (55 loc) · 2.34 KB
/
demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
'''
Download Merlin and test the model on sample data downloaded from Hugging Face.
'''
import os
import warnings
import torch
from merlin.data import download_sample_data
from merlin.data import DataLoader
from merlin import Merlin
# Silence all Python warnings so the demo output stays readable.
warnings.filterwarnings("ignore")

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Build the model, switch it to evaluation mode, and move it to the selected
# device. Using .to(device) instead of an unconditional .cuda() keeps the
# script runnable on machines without CUDA and consistent with the
# batch["image"].to(device) calls further down.
model = Merlin()
model.eval()
model.to(device)

# The sample data and its cache live in sibling directories next to this
# script. The cache path is joined explicitly rather than derived with
# str.replace, which would also rewrite any parent directory whose name
# happens to contain "abct_data".
script_dir = os.path.dirname(__file__)
data_dir = os.path.join(script_dir, "abct_data")
cache_dir = os.path.join(script_dir, "abct_data_cache")
# Local path to the sample NIfTI file returned by download_sample_data.
_sample_image_path = download_sample_data(data_dir)

# Report text paired with the sample image, assembled from adjacent string
# literals.
# NOTE(review): the first fragment has no trailing space, so "cyst." and
# "Liver" are joined without a separator; kept as-is to preserve the exact
# text passed to the model.
_sample_report = (
    "Lower thorax: A small low-attenuating fluid structure is noted in the right cardiophrenic angle in keeping with a tiny pericardial cyst."
    "Liver and biliary tree: Normal. Gallbladder: Normal. Spleen: Normal. Pancreas: Normal. Adrenal glands: Normal. "
    "Kidneys and ureters: Symmetric enhancement and excretion of the bilateral kidneys, with no striated nephrogram to suggest pyelonephritis. "
    "Urothelial enhancement bilaterally, consistent with urinary tract infection. No renal/ureteral calculi. No hydronephrosis. "
    "Gastrointestinal tract: Normal. Normal gas-filled appendix. Peritoneal cavity: No free fluid. "
    "Bladder: Marked urothelial enhancement consistent with cystitis. Uterus and ovaries: Normal. "
    "Vasculature: Patent. Lymph nodes: Normal. Abdominal wall: Normal. "
    "Musculoskeletal: Degenerative change of the spine."
)

# One record per study: the image path and its accompanying report text.
datalist = [{"image": _sample_image_path, "text": _sample_report}]
# Loader settings gathered in one place; the values are passed to DataLoader
# as keyword arguments exactly as before.
_loader_kwargs = {
    "datalist": datalist,
    "cache_dir": cache_dir,
    "batchsize": 8,
    "shuffle": True,
    "num_workers": 0,
}
dataloader = DataLoader(**_loader_kwargs)
# Run the model on every batch and report the shape of each output.
# torch.no_grad() disables gradient tracking for the forward pass: this is
# inference only, and model.eval() by itself does not stop autograd from
# recording the graph.
with torch.no_grad():
    for batch in dataloader:
        # The image tensor is moved to the selected device; the report text
        # is passed through unchanged.
        outputs = model(
            batch["image"].to(device),
            batch["text"],
        )
        print("\n================== Output Shapes ==================")
        print(f"Contrastive image embeddings shape: {outputs[0].shape}")
        print(f"Phenotype predictions shape: {outputs[1].shape}")
        print(f"Contrastive text embeddings shape: {outputs[2].shape}")
## Get the Image Embeddings
# Rebuild the model with ImageEmbedding=True and run it on the images only.
model = Merlin(ImageEmbedding=True)
model.eval()
# Move to the selected device rather than calling .cuda() unconditionally,
# which raises on machines without CUDA.
model.to(device)

# Inference only, so gradient tracking is disabled for the forward pass.
with torch.no_grad():
    for batch in dataloader:
        outputs = model(
            batch["image"].to(device),
        )
        print("\n================== Output Shapes ==================")
        print(f"Image embeddings shape (Can be used for downstream tasks): {outputs[0].shape}")