Skip to content

Commit 0e341c5

Browse files
authored
Create one_cycle() function (#1836)
1 parent 7dddb1d commit 0e341c5

File tree

3 files changed

+9
-3
lines changed

3 files changed

+9
-3
lines changed

train.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from utils.datasets import create_dataloader
2929
from utils.general import labels_to_class_weights, increment_path, labels_to_image_weights, init_seeds, \
3030
fitness, strip_optimizer, get_latest_run, check_dataset, check_file, check_git_status, check_img_size, \
31-
print_mutation, set_logging
31+
print_mutation, set_logging, one_cycle
3232
from utils.google_utils import attempt_download
3333
from utils.loss import compute_loss
3434
from utils.plots import plot_images, plot_labels, plot_results, plot_evolution
@@ -126,12 +126,12 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
126126

127127
# Scheduler https://arxiv.org/pdf/1812.01187.pdf
128128
# https://pytorch.org/docs/stable/_modules/torch/optim/lr_scheduler.html#OneCycleLR
129-
lf = lambda x: ((1 + math.cos(x * math.pi / epochs)) / 2) * (1 - hyp['lrf']) + hyp['lrf'] # cosine
129+
lf = one_cycle(1, hyp['lrf'], epochs) # cosine 1->hyp['lrf']
130130
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
131131
# plot_lr_scheduler(optimizer, scheduler, epochs)
132132

133133
# Logging
134-
if wandb and wandb.run is None:
134+
if rank in [-1, 0] and wandb and wandb.run is None:
135135
opt.hyp = hyp # add hyperparameters
136136
wandb_run = wandb.init(config=opt, resume="allow",
137137
project='YOLOv5' if opt.project == 'runs/train' else Path(opt.project).stem,

utils/general.py

+5
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,11 @@ def clean_str(s):
102102
return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s)
103103

104104

105+
def one_cycle(y1=0.0, y2=1.0, steps=100):
106+
# lambda function for sinusoidal ramp from y1 to y2
107+
return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1
108+
109+
105110
def labels_to_class_weights(labels, nc=80):
106111
# Get class weights (inverse frequency) from training labels
107112
if labels[0] is None: # no labels loaded

utils/plots.py

+1
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,7 @@ def plot_lr_scheduler(optimizer, scheduler, epochs=300, save_dir=''):
190190
plt.xlim(0, epochs)
191191
plt.ylim(0)
192192
plt.savefig(Path(save_dir) / 'LR.png', dpi=200)
193+
plt.close()
193194

194195

195196
def plot_test_txt(): # from utils.plots import *; plot_test()

0 commit comments

Comments
 (0)