
Commit 6ec2be0

feat(tensorboard): add temperature logging to tensorboard
1 parent d84b9c5

File tree (2 files changed: +32 -4 lines)

  .pylintrc
  src/python/training/train.py

.pylintrc (+2 -1)
@@ -48,7 +48,8 @@ disable=import-error,
     broad-except,
     redefined-builtin,
     trailing-newlines,
-    dangerous-default-value
+    dangerous-default-value,
+    too-many-lines
 
 [REPORTS]
 
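The lint change adds too-many-lines (pylint's warning for modules exceeding max-module-lines, 1000 by default) to the globally disabled checks, presumably because train.py has grown past that limit. A narrower, hypothetical alternative, not part of this commit, would be a per-module suppression in the offending file only:

    # Hypothetical alternative (not in this commit): silence the check only in
    # the file that exceeds the limit, by placing this directive at the top of
    # src/python/training/train.py instead of disabling it repository-wide.
    # pylint: disable=too-many-lines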
src/python/training/train.py (+30 -3)
@@ -639,6 +639,33 @@ def _add_params(module, prefix=""):
         writer.add_scalar("test/acc", acc, epoch)
         writer.add_scalar("test/acc5", acc5, epoch)
         writer.add_scalar("test/loss", loss, epoch)
+        # calc temperature
+        all_temps = []
+
+        def _get_temps(module, prefix=""):
+            for name, p in module.named_parameters(recurse=False):
+                if isinstance(module, (HalutConv2d, HalutLinear)):
+                    if name == "temperature":
+                        # pylint: disable=cell-var-from-loop
+                        all_temps.append(p)
+                        continue
+
+            for child_name, child_module in module.named_children():
+                child_prefix = f"{prefix}.{child_name}" if prefix != "" else child_name
+                _get_temps(child_module, prefix=child_prefix)
+
+        _get_temps(model)
+        del _get_temps
+        print("all_temps", all_temps)
+        avg = 0.0
+        length = 0
+        for temp in all_temps:
+            avg += temp.item()
+            length += 1
+
+        avg_temp = avg / length
+        writer.add_scalar("train/temp", avg_temp, epoch)
+
         writer.flush()
         if model_ema:
             evaluate(
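The new block recursively walks the model, collects the scalar temperature parameter from every HalutConv2d / HalutLinear layer, averages the values, and logs the mean to TensorBoard as train/temp. Note that the division by length assumes at least one such layer exists. Below is a minimal, hypothetical sketch of the same traversal using Module.modules() with a guard for the empty case; average_temperature is not part of this commit, and the Halut layer classes are assumed to come from this repository:

    # Hypothetical helper (not in this commit): same averaging as above, but a
    # flat traversal via Module.modules() that is safe when no Halut layer exists.
    def average_temperature(model, layer_types):
        temps = [
            m.temperature.item()
            for m in model.modules()
            if isinstance(m, layer_types) and hasattr(m, "temperature")
        ]
        return sum(temps) / len(temps) if temps else None

    # Sketch of usage inside the epoch loop (HalutConv2d / HalutLinear are
    # project classes and are not defined here):
    # avg_temp = average_temperature(model, (HalutConv2d, HalutLinear))
    # if avg_temp is not None:
    #     writer.add_scalar("train/temp", avg_temp, epoch)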
@@ -689,9 +716,9 @@ def _add_params(module, prefix=""):
             )
             # optimizer_lr_all [[0.0005], [0.0050], [0.0050], [0.0050]]
             optimizer_lr_local = optimizer_lr_all[0].item()
-            if optimizer_lr_local < args.lr * 1e-4:
-                print("learning rate too small, stop training")
-                break
+            # if optimizer_lr_local < args.lr * 1e-4:
+            #     print("learning rate too small, stop training")
+            #     break
 
         if args.distributed:
             torch.distributed.barrier()
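The second hunk disables the early stop that ended training once the first parameter group's learning rate fell below args.lr * 1e-4. A small worked instance of that threshold, assuming a base learning rate of 5e-3 as the [0.0050] entries in the in-code comment suggest (the actual args.lr is not shown in this diff):

    # Illustrative numbers only; args.lr = 5e-3 is an assumption.
    base_lr = 5e-3
    stop_threshold = base_lr * 1e-4             # 5e-7
    optimizer_lr_local = 1e-7                   # example value late in a decaying schedule
    print(optimizer_lr_local < stop_threshold)  # True -> the old guard would have stopped training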

0 commit comments
