-
Notifications
You must be signed in to change notification settings - Fork 80
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Feat] Add DACT and NeuOpt Improvement Models #184
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great job!!
Most comments are [Minor], they are just suggestions and not priorities right now
@@ -291,3 +291,57 @@ def __setstate__(self, state): | |||
self.__dict__.update(state) | |||
self.rng = torch.manual_seed(0) | |||
self.rng.set_state(state["rng"]) | |||
|
|||
|
|||
class ImprovementEnvBase(RL4COEnvBase, metaclass=abc.ABCMeta): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[Minor] I'm good with the current for the moment, however this covers more 1)routing 2)euclidean space only
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can keep this for now though 👍
rl4co/models/zoo/n2s/encoder.py
Outdated
@@ -194,7 +196,7 @@ def __init__( | |||
feedforward_hidden=feedforward_hidden, | |||
) | |||
|
|||
assert self.env_name in ["pdp_ruin_repair"], NotImplementedError() | |||
assert self.env_name in ["pdp_ruin_repair", "tsp_kopt"], NotImplementedError() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does it necessarily have to be of this kind, i.e., not generalizable for now to new envs? (at least the base logic)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good point! Removed already!
@@ -212,7 +212,7 @@ def forward( | |||
td.set("action", N2S_action) | |||
|
|||
if return_embeds: | |||
outdict["embeds"] = h_wave | |||
outdict["embeds"] = h_wave.detach() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do we detach()
here? They are not used anywhere else?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is because the critic takes the detached values by design for my methdos
rl4co/models/zoo/neuopt/decoder.py
Outdated
self.init_parameters() | ||
|
||
def init_parameters(self) -> None: | ||
for param in self.parameters(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[Minor] maybe not needed?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yeah removed
|
||
self.init_parameters() | ||
|
||
def init_parameters(self) -> None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[Minor] same comment as above
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This one is kept since we have nn.Parameters that need initialization
rl4co/models/zoo/neuopt/policy.py
Outdated
""" | ||
|
||
# Encoder: get encoder output and initial embeddings from initial state | ||
NFE, _ = self.encoder(td) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[Minor] if data is being passed around, i.e. embeddings in this case, I suggest changing naming with lowercase letters (i.e. NFE
-> nfe
), since uppercase is usually reserved for classes
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, changed already
tests/test_training.py
Outdated
trainer.test(model) | ||
|
||
|
||
def test_NeuOpt(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[Minor] same minor comment as above:
- variables, functions ...: lowecase
- classes: uppercase (at least first letter(
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, changed already
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great work! 🚀 Really good quality of code, I have few comments on it~
Awesome! No more comments, I think ready to merge @cbhua at your will |
Description
Types of changes
What types of changes does your code introduce? Remove all that do not apply:
Checklist
Go over all the following points, and put an
x
in all the boxes that apply.If you are unsure about any of these, don't hesitate to ask. We are here to help!