@@ -639,6 +639,33 @@ def _add_params(module, prefix=""):
         writer.add_scalar("test/acc", acc, epoch)
         writer.add_scalar("test/acc5", acc5, epoch)
         writer.add_scalar("test/loss", loss, epoch)
+        # calc temperature
+        all_temps = []
+
+        def _get_temps(module, prefix=""):
+            for name, p in module.named_parameters(recurse=False):
+                if isinstance(module, (HalutConv2d, HalutLinear)):
+                    if name == "temperature":
+                        # pylint: disable=cell-var-from-loop
+                        all_temps.append(p)
+                        continue
+
+            for child_name, child_module in module.named_children():
+                child_prefix = f"{prefix}.{child_name}" if prefix != "" else child_name
+                _get_temps(child_module, prefix=child_prefix)
+
+        _get_temps(model)
+        del _get_temps
+        print("all_temps", all_temps)
+        avg = 0.0
+        length = 0
+        for temp in all_temps:
+            avg += temp.item()
+            length += 1
+
+        avg_temp = avg / length
+        writer.add_scalar("train/temp", avg_temp, epoch)
+
         writer.flush()
         if model_ema:
             evaluate(
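The traversal added above mirrors the file's existing `_add_params` helper, but `torch.nn.Module.modules()` already yields every descendant module, so the same per-epoch average can be collected without hand-rolled recursion. A minimal sketch under that assumption (the `mean_temperature` name is hypothetical; `HalutConv2d`/`HalutLinear` are the project's own layer types):

import torch

def mean_temperature(model: torch.nn.Module) -> float:
    # model.modules() walks the module tree recursively, replacing the
    # manual named_children() recursion in the diff above.
    temps = [
        p.item()
        for m in model.modules()
        if isinstance(m, (HalutConv2d, HalutLinear))  # project layer types
        for name, p in m.named_parameters(recurse=False)
        if name == "temperature"
    ]
    # Same quantity the diff logs as "train/temp"; like the diff's own
    # averaging loop, this divides by zero if no layer has a temperature.
    return sum(temps) / len(temps)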
@@ -689,9 +716,9 @@ def _add_params(module, prefix=""):
         )
         # optimizer_lr_all [[0.0005], [0.0050], [0.0050], [0.0050]]
         optimizer_lr_local = optimizer_lr_all[0].item()
-        if optimizer_lr_local < args.lr * 1e-4:
-            print("learning rate too small, stop training")
-            break
+        # if optimizer_lr_local < args.lr * 1e-4:
+        #     print("learning rate too small, stop training")
+        #     break

         if args.distributed:
             torch.distributed.barrier()
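The guard disabled here stopped training once the gathered learning rate fell four orders of magnitude below the initial `args.lr`; the commit keeps the code as comments rather than deleting it. For reference, a minimal single-process sketch of the same check, assuming a standard `torch.optim` optimizer instead of the distributed `optimizer_lr_all` tensor:

# param_groups[0]["lr"] reflects the current LR after any scheduler steps.
current_lr = optimizer.param_groups[0]["lr"]
if current_lr < args.lr * 1e-4:
    print("learning rate too small, stop training")
    # break out of the epoch loop here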