Skip to content

Commit

Permalink
Merge pull request eriklindernoren#629 from eriklindernoren/feature/u…
Browse files Browse the repository at this point in the history
…se_cfg_hyperparams

Use cfg hyperparams
  • Loading branch information
Flova authored Mar 24, 2021
2 parents a1fad47 + 1b76270 commit 369b459
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 10 deletions.
27 changes: 23 additions & 4 deletions models.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,24 @@ def create_modules(module_defs):
Constructs module list of layer blocks from module configuration in module_defs
"""
hyperparams = module_defs.pop(0)
output_filters = [int(hyperparams["channels"])]
hyperparams.update({
'batch': int(hyperparams['batch']),
'subdivisions': int(hyperparams['subdivisions']),
'width': int(hyperparams['width']),
'height': int(hyperparams['height']),
'channels': int(hyperparams['channels']),
'momentum': float(hyperparams['momentum']),
'decay': float(hyperparams['decay']),
'learning_rate': float(hyperparams['learning_rate']),
'burn_in': int(hyperparams['burn_in']),
'max_batches': int(hyperparams['max_batches']),
'policy': hyperparams['policy'],
'lr_steps': list(zip(map(int, hyperparams["steps"].split(",")),
map(float, hyperparams["scales"].split(","))))
})
assert hyperparams["height"] == hyperparams["width"], \
"Height and width should be equal! Non square images are padded with zeros."
output_filters = [hyperparams["channels"]]
module_list = nn.ModuleList()
for module_i, module_def in enumerate(module_defs):
modules = nn.Sequential()
Expand Down Expand Up @@ -73,8 +90,9 @@ def create_modules(module_defs):
anchors = [anchors[i] for i in anchor_idxs]
num_classes = int(module_def["classes"])
img_size = int(hyperparams["height"])
ignore_thres = float(module_def["ignore_thresh"])
# Define detection layer
yolo_layer = YOLOLayer(anchors, num_classes, img_size)
yolo_layer = YOLOLayer(anchors, num_classes, ignore_thres, img_size)
modules.add_module(f"yolo_{module_i}", yolo_layer)
# Register module list and number of output filters
module_list.append(modules)
Expand All @@ -98,7 +116,7 @@ def forward(self, x):
class YOLOLayer(nn.Module):
"""Detection layer"""

def __init__(self, anchors, num_classes, img_dim=416):
def __init__(self, anchors, num_classes, ignore_thres, img_dim=416):
super(YOLOLayer, self).__init__()
self.anchors = anchors
self.num_anchors = len(anchors)
Expand All @@ -107,7 +125,7 @@ def __init__(self, anchors, num_classes, img_dim=416):
self.mse_loss = nn.MSELoss()
self.bce_loss = nn.BCELoss()
self.obj_scale = 1
self.noobj_scale = 100
self.noobj_scale = 0.5
self.metrics = {}
self.img_dim = img_dim
self.grid_size = 0 # grid size
Expand Down Expand Up @@ -252,6 +270,7 @@ def forward(self, x, targets=None):
yolo_outputs.append(x)
layer_outputs.append(x)
yolo_outputs = to_cpu(torch.cat(yolo_outputs, 1))
loss = loss / self.hyperparams['batch']
return yolo_outputs if targets is None else (loss, yolo_outputs)

def load_darknet_weights(self, weights_path):
Expand Down
36 changes: 30 additions & 6 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--epochs", type=int, default=100, help="number of epochs")
parser.add_argument("--batch_size", type=int, default=8, help="size of each image batch")
parser.add_argument("--gradient_accumulations", type=int, default=2, help="number of gradient accums before step")
parser.add_argument("--model_def", type=str, default="config/yolov3.cfg", help="path to model definition file")
parser.add_argument("--data_config", type=str, default="config/coco.data", help="path to data config file")
parser.add_argument("--pretrained_weights", type=str, help="if specified starts from checkpoint model")
Expand Down Expand Up @@ -73,14 +71,18 @@
dataset = ListDataset(train_path, multiscale=opt.multiscale_training, img_size=opt.img_size, transform=AUGMENTATION_TRANSFORMS)
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=opt.batch_size,
batch_size= model.hyperparams['batch'] // model.hyperparams['subdivisions'],
shuffle=True,
num_workers=opt.n_cpu,
pin_memory=True,
collate_fn=dataset.collate_fn,
)

optimizer = torch.optim.Adam(model.parameters())
optimizer = torch.optim.SGD(
model.parameters(),
lr=model.hyperparams['learning_rate'],
weight_decay=model.hyperparams['decay'],
momentum=model.hyperparams['momentum'])

metrics = [
"grid_size",
Expand Down Expand Up @@ -111,9 +113,31 @@
loss, outputs = model(imgs, targets)
loss.backward()

if batches_done % opt.gradient_accumulations == 0:
# Accumulates gradient before each step
###############
# Run optimizer
###############

if batches_done % model.hyperparams['subdivisions'] == 0:
# Adapt learning rate
# Get learning rate defined in cfg
lr = model.hyperparams['learning_rate']
if batches_done < model.hyperparams['burn_in']:
# Burn in
lr *= (batches_done / model.hyperparams['burn_in'])
else:
# Set and parse the learning rate to the steps defined in the cfg
for threshold, value in model.hyperparams['lr_steps']:
if batches_done > threshold:
lr *= value
# Log the learning rate
logger.scalar_summary("learning_rate", lr, batches_done)
# Set learning rate
for g in optimizer.param_groups:
g['lr'] = lr

# Run optimizer
optimizer.step()
# Reset gradients
optimizer.zero_grad()

# ----------------
Expand Down

0 comments on commit 369b459

Please sign in to comment.