Skip to content

Commit

Permalink
[MERGE] download script merged from paper ready
Browse files Browse the repository at this point in the history
  • Loading branch information
fennecinspace committed Nov 13, 2023
2 parents 8b07cb6 + 56b1e49 commit 003f356
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 92 deletions.
15 changes: 12 additions & 3 deletions conf/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,22 @@ ml:
augmentation_percent_baseline: 0
baseline: True
epochs: 300
wandb:
entity: sdcn-nantes
project: sdcn-shit-testing
sampling:
enable: false
metric: brisque # brisque (smaller is better), dbcnn (bigger is better), ilniqe (smaller is better)
sample: best # whether smaller or bigger values are taken depends on the metric

wandb:
entity: sdcn-nantes
project: sdcn-shit-testing
download:
list_all: false
list_finished: true
list_running: false
sort: false
folder: [".", "models"]
download: false
query_filter: false

prompt:
template: vocabulary
Expand Down
42 changes: 21 additions & 21 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,23 +1,23 @@
accelerate
controlnet_aux
accelerate==0.22.0
controlnet_aux==0.0.6
diffusers==0.14.0
gdown
huggingface
hydra-core
jupyter
huggingface==0.0.1
transformers==4.33.0
mediapipe==0.10.0
matplotlib
numpy
opencv-contrib-python
pandas
pillow
pyiqa
pycocotools
tqdm
transformers
ultralytics
wandb
wget
xformers
spacy
scikit-learn
xformers==0.0.21
scikit-learn==1.3.2
pyiqa==0.1.8
spacy==3.7.2
gdown==4.7.1
hydra-core==1.3.2
jupyter==1.0.0
matplotlib==3.7.2
numpy==1.25.2
opencv-contrib-python==4.8.0.76
pandas==2.1.1
pillow==10.0.0
pycocotools==2.0.7
tqdm==4.66.1
wandb==0.15.12
wget==3.2
tabulate==0.9.0
130 changes: 62 additions & 68 deletions src/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,30 +2,34 @@
# "wandb login" in a terminal before using this script
# Run states: finished, running, crashed
# Important run params/functions : .id, .state, .files, .config, .summary, .history(), .file('filename.ext').download(), api.run(f"{ENTITY}/{PROJECT}/{run.id}")

import hydra
import os
import sys
import argparse
import pandas as pd
import wandb

from omegaconf import DictConfig
from pathlib import Path
from tabulate import tabulate
from tqdm import tqdm

import pandas as pd
from common import logger

BASE_PATH = os.path.dirname(os.path.abspath(__file__))
DEFAULT_DOWNLOAD_DIR = os.path.join(BASE_PATH, 'models')


class Downloader:
def __init__(self, entity, project):
    """Bind the downloader to a single W&B project.

    Stores the entity/project pair, precomputes the base URL of the
    project's runs page and creates the ``wandb.Api`` client
    (``timeout=25``) used by the other methods.

    Args:
        entity: W&B entity (user or team) that owns the project.
        project: W&B project name.
    """
    self.entity, self.project = entity, project
    # Base URL; a run id is appended by callers to build a per-run link.
    self.runs_url = f'https://wandb.ai/{entity}/{project}/runs/'
    self.api = wandb.Api(timeout=25)


def get_runs(self, query_filter = None):
finished, running, other = [], [], []

print("Loading Runs DB...")
logger.info("Loading Runs DB...")
if query_filter:
finished_query = self.api.runs(f"{self.entity}/{self.project}", {
"display_name": {"$regex": f".*{query_filter}.*"}
Expand All @@ -34,8 +38,8 @@ def get_runs(self, query_filter = None):
finished_query = self.api.runs(f"{self.entity}/{self.project}", per_page = 4)
finished_query[0]
l = finished_query.length
print("Loading Runs to memory...")

logger.info("Loading Runs to memory...")
for i in tqdm(range(l)):
run = finished_query[i]
if run.state == 'finished':
Expand All @@ -46,10 +50,9 @@ def get_runs(self, query_filter = None):
other += [run]
return finished, running, other


def check_runs(self, runs, query_filter):
print('Checking list of finished runs...')
finished, running, other = self.get_runs(query_filter)
logger.info('Checking list of finished runs...')
finished, _, _ = self.get_runs(query_filter)
for run in finished:
if run.id in runs:
runs[runs.index(run.id)] = run
Expand All @@ -59,36 +62,34 @@ def check_runs(self, runs, query_filter):
runs_found = []
for run in runs:
if type(run).__name__ == 'str':
print(f'Could not find {run}')
logger.info(f'Could not find {run}')
else:
runs_found += [run]
return runs_found


def download_model(self, run, download_dir = DEFAULT_DOWNLOAD_DIR):
    """Download the ``best`` model artifact of ``run`` into ``download_dir``.

    The artifact ``run_<run.id>_model:best`` is fetched through the W&B API.
    If the download produced a ``best.pt`` file, it is renamed to
    ``<run.id>.<run.name>.pt``. Nothing is fetched when that target file
    already exists. Every failure is logged and swallowed so that one bad
    run does not abort a batch of downloads.

    Args:
        run: W&B run object; ``run.id`` and ``run.name`` are read, and the
            object itself is interpolated into the failure message.
        download_dir: Directory the weights are stored in.
    """
    filename = f'{run.id}.{run.name}.pt'
    download_location = os.path.join(download_dir, filename)

    # Guard clause: weights already present locally, nothing to do.
    if os.path.exists(download_location):
        logger.info(f"Skipping : '{filename}' already exists")
        return

    try:
        model_artifact = self.api.artifact(f'{self.entity}/{self.project}/run_{run.id}_model:best')
        model_artifact.download(download_dir)

        # The artifact lands as a generic 'best.pt'; give it a per-run name
        # so several runs can share the same directory.
        best_path = os.path.join(download_dir, 'best.pt')
        if os.path.exists(best_path):
            os.rename(best_path, download_location)
        else:
            logger.info('Skipping : Could not download.')
    except Exception as e:
        # Deliberately broad: this is a best-effort batch step, the error is
        # reported once and the caller moves on to the next run.
        logger.info(f'Could not download {run}. E:{e}')

def save_summary(self, runs, download_dir = DEFAULT_DOWNLOAD_DIR, project = 'results'):
summary_list, config_list, name_list = [], [], []
for run in runs:
for run in runs:
summary_list.append(run.summary._json_dict)
config_list.append(
{k: v for k,v in run.config.items()
if not k.startswith('_')})

name_list.append(run.name)

runs_df = pd.DataFrame({
Expand All @@ -97,83 +98,76 @@ def save_summary(self, runs, download_dir = DEFAULT_DOWNLOAD_DIR, project = 'res
"name": name_list
})

print('Saving WANDB summary to:', os.path.join(download_dir, f'{project}.csv'))
runs_df.to_csv(os.path.join(download_dir, f'{project}.csv'))
save_path = os.path.join(download_dir, f'{project}.csv')
logger.info(f'Saving WANDB summary to: {save_path}')
runs_df.to_csv(save_path)


@hydra.main(version_base=None, config_path="../conf", config_name="config")
def main(cfg: DictConfig) -> None:
    """List W&B runs and optionally download their best model weights.

    All settings are read from ``cfg['ml']['wandb']``: ``entity``,
    ``project`` and the ``download`` sub-section (``list_all``,
    ``list_finished``, ``list_running``, ``folder``, ``download``,
    ``query_filter`` and the optional ``runs`` list).

    Steps:
        1. Load the runs of the project, optionally filtered by name.
        2. Select the runs to process. An explicit ``runs`` list wins, then
           ``list_all``, ``list_finished``, ``list_running``.
        3. Log a table showing which selected runs already have weights in
           the target folder.
        4. If ``download`` is enabled, fetch the missing weights.
        5. Save a CSV summary of the selected runs (best effort).

    Args:
        cfg: Hydra/OmegaConf configuration loaded from ``../conf/config``.
    """
    wandb_cfg = cfg['ml']['wandb']
    download_params = wandb_cfg['download']
    entity = wandb_cfg['entity']
    project = wandb_cfg['project']
    downloader = Downloader(entity, project)

    # A false/empty `query_filter` in the config means "no name filtering".
    query_filter = download_params['query_filter'] or None
    finished, running, other = downloader.get_runs(query_filter)

    # Pick the selection mode once so the report below cannot disagree with
    # the set of runs that was actually processed.
    if 'runs' in download_params:
        mode = 'runs'
        # check_runs() replaces ids by run objects inside the list it is
        # given; hand it a plain list, since an OmegaConf list cannot hold
        # arbitrary objects.
        requested_runs = list(download_params['runs'])
        runs_to_process = downloader.check_runs(requested_runs, query_filter)
    elif download_params['list_all']:
        mode = 'all'
        runs_to_process = finished + running + other
    elif download_params['list_finished']:
        mode = 'finished'
        runs_to_process = finished
    elif download_params['list_running']:
        mode = 'running'
        runs_to_process = running
    else:
        mode = None
        runs_to_process = []

    # Weights live in <folder parts...>/<project>, created on demand.
    folder_path = Path(os.path.abspath(os.path.join(*download_params['folder'], project)))
    folder_path.mkdir(parents=True, exist_ok=True)
    folder = str(folder_path)

    downloaded_nb = 0
    runs_to_print = []
    runs_to_download = []
    for run in runs_to_process:
        download_location = os.path.join(folder, f'{run.id}.{run.name}.pt')
        if os.path.exists(download_location):
            downloaded = 'Yes'
            downloaded_nb += 1
        else:
            downloaded = 'No'
            runs_to_download.append(run)

        runs_to_print.append([run.id, run.name, run.state, downloaded, downloader.runs_url + run.id])

    logger.info(tabulate(runs_to_print, headers=['ID', 'NAME', 'STATUS', 'DOWNLOADED', 'URL']))

    if mode == 'all':
        logger.info(f'Models: {len(runs_to_process)}, Downloaded: {downloaded_nb} (In : {folder}). Can not download {len(running)}, still running.')
    elif mode in ('runs', 'finished'):
        logger.info(f'Models: {len(runs_to_process)}, Downloaded: {downloaded_nb} (In : {folder})')
    elif mode == 'running':
        logger.info(f'Models: {len(running)}. Still running. Not able to download weights yet.')

    if download_params['download'] and runs_to_download:
        for run in runs_to_download:
            logger.info(f'Downloading : {run.id}.{run.name}.pt')
            downloader.download_model(run, folder)

    # The summary is a convenience output: a failure here is reported but
    # must not turn an otherwise successful listing/download into a crash.
    try:
        logger.info('Saving Runs summary !')
        downloader.save_summary(runs_to_process, folder, project)
    except Exception as e:
        logger.info('Could not save summary.')
        logger.info(e)


# Script entry point. `main` is wrapped by @hydra.main, which parses the
# command line itself and injects the composed config, so it is called with
# no arguments and no argparse front-end is used.
if __name__ == "__main__":
    main()

0 comments on commit 003f356

Please sign in to comment.