
[Feat] Add DACT and NeuOpt Improvement Models #184

Merged
merged 8 commits into from
Jun 7, 2024

Conversation

yining043
Contributor

@yining043 yining043 commented May 31, 2024

Description

  • Add k-opt env for TSP.
  • Add DACT and NeuOpt models.
  • Enhance the support for improvement methods.
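For readers unfamiliar with improvement methods: a k-opt move rewires a tour by removing k edges and reconnecting the segments. As a plain-Python illustration of the simplest case (2-opt segment reversal; this is a generic sketch, not the PR's `tsp_kopt` env code):

```python
import math

def tour_length(tour, coords):
    # Total length of the closed tour over the given 2D coordinates
    return sum(
        math.dist(coords[tour[i]], coords[tour[(i + 1) % len(tour)]])
        for i in range(len(tour))
    )

def two_opt_step(tour, coords):
    # Return the best neighbor of `tour` under a single 2-opt move,
    # i.e. reversing one contiguous segment (equivalently, swapping 2 edges)
    best, best_len = tour, tour_length(tour, coords)
    for i in range(1, len(tour) - 1):
        for j in range(i + 1, len(tour)):
            cand = tour[:i] + tour[i:j + 1][::-1] + tour[j + 1:]
            cand_len = tour_length(cand, coords)
            if cand_len < best_len:
                best, best_len = cand, cand_len
    return best

coords = [(0, 0), (1, 0), (0, 1), (1, 1)]
tour = [0, 1, 2, 3]                    # self-crossing tour, length 2 + 2*sqrt(2)
improved = two_opt_step(tour, coords)  # uncrossed tour of length 4.0
```

Learned improvement models such as DACT and NeuOpt replace the exhaustive inner loops with a policy that proposes which move to apply at each step.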

Types of changes

What types of changes does your code introduce? Remove all that do not apply:

  • New feature (non-breaking change which adds core functionality)
  • Breaking change (fix or feature that would cause existing functionality to change)

Checklist

Go over all the following points, and put an x in all the boxes that apply.
If you are unsure about any of these, don't hesitate to ask. We are here to help!

  • My change requires a change to the documentation.
  • I have updated the tests accordingly (required for a bug fix or a new feature).
  • I have updated the documentation accordingly.

@fedebotu fedebotu self-requested a review May 31, 2024 12:44
Member

@fedebotu fedebotu left a comment

Great job!!

Most comments are [Minor]; they are just suggestions, not priorities right now

@@ -291,3 +291,57 @@ def __setstate__(self, state):
self.__dict__.update(state)
self.rng = torch.manual_seed(0)
self.rng.set_state(state["rng"])


class ImprovementEnvBase(RL4COEnvBase, metaclass=abc.ABCMeta):
Member

[Minor] I'm good with the current version for the moment; however, this base class covers only 1) routing and 2) Euclidean-space problems

Member

We can keep this for now though 👍

@@ -194,7 +196,7 @@ def __init__(
feedforward_hidden=feedforward_hidden,
)

assert self.env_name in ["pdp_ruin_repair"], NotImplementedError()
assert self.env_name in ["pdp_ruin_repair", "tsp_kopt"], NotImplementedError()
Member

Does it necessarily have to be restricted like this, i.e., not generalizable to new envs for now? (at least the base logic)

Contributor Author

Good point! Removed already!
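As an aside on the quoted check: `assert cond, NotImplementedError()` only uses the exception instance as the assertion message, and assertions vanish under `python -O`. A more extensible pattern could keep supported envs in one registry and raise an informative error (names here are hypothetical, not the PR's actual code):

```python
# Hypothetical registry of envs this improvement policy supports
SUPPORTED_IMPROVEMENT_ENVS = {"pdp_ruin_repair", "tsp_kopt"}

def check_env_supported(env_name: str) -> None:
    # Raise explicitly instead of `assert env_name in [...], NotImplementedError()`,
    # so the check survives `-O` and reports what *is* supported
    if env_name not in SUPPORTED_IMPROVEMENT_ENVS:
        raise NotImplementedError(
            f"Env '{env_name}' is not supported by this improvement policy; "
            f"supported: {sorted(SUPPORTED_IMPROVEMENT_ENVS)}"
        )
```

Adding a new env then means registering its name rather than editing every assert.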

@@ -212,7 +212,7 @@ def forward(
td.set("action", N2S_action)

if return_embeds:
outdict["embeds"] = h_wave
outdict["embeds"] = h_wave.detach()
Member

Why do we detach() here? They are not used anywhere else?

Contributor Author

This is because the critic takes the detached values by design in my methods
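For illustration, a minimal PyTorch sketch (stand-in modules, not the PR's actual actor/critic) of why the embeddings handed to the critic are detached: the critic's backward pass then cannot push gradients into the actor's encoder.

```python
import torch
import torch.nn as nn

encoder = nn.Linear(4, 8)  # stand-in for the actor's encoder
critic = nn.Linear(8, 1)   # critic consumes the detached embeddings

x = torch.randn(2, 4)
h_wave = encoder(x)        # embeddings used by the actor's policy
embeds = h_wave.detach()   # what gets stored in outdict["embeds"]

value = critic(embeds).sum()
value.backward()

# Only the critic receives gradients; the encoder is untouched
assert encoder.weight.grad is None
assert critic.weight.grad is not None
```

Detaching here keeps the critic's value loss from interfering with the actor's representation learning.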

self.init_parameters()

def init_parameters(self) -> None:
for param in self.parameters():
Member

[Minor] maybe not needed?

Contributor Author

yeah removed


self.init_parameters()

def init_parameters(self) -> None:
Member

[Minor] same comment as above

Contributor Author

This one is kept since we have nn.Parameters that need initialization
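For context, a sketch of why an explicit `init_parameters` is kept when raw `nn.Parameter` tensors are present: unlike `nn.Linear`, a bare parameter gets no default initialization. The module below is hypothetical, but the fan-in-scaled uniform init mirrors the loop quoted in the diff above:

```python
import math
import torch
import torch.nn as nn

class CriticHead(nn.Module):  # hypothetical module for illustration
    def __init__(self, embed_dim: int = 128):
        super().__init__()
        # torch.empty allocates uninitialized memory: without init_parameters,
        # this parameter would contain arbitrary values
        self.project = nn.Parameter(torch.empty(embed_dim, 1))
        self.init_parameters()

    def init_parameters(self) -> None:
        # Uniform init scaled by the last dimension, matching the quoted loop
        for param in self.parameters():
            stdv = 1.0 / math.sqrt(param.size(-1))
            param.data.uniform_(-stdv, stdv)
```

Modules built only from `nn.Linear`/`nn.Embedding` layers already self-initialize, which is why the earlier, otherwise-identical loop could be dropped.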

"""

# Encoder: get encoder output and initial embeddings from initial state
NFE, _ = self.encoder(td)
Member

[Minor] if data is being passed around, i.e. embeddings in this case, I suggest lowercase naming (i.e. NFE -> nfe), since uppercase is usually reserved for classes

Contributor Author

Yeah, changed already

trainer.test(model)


def test_NeuOpt():
Member

[Minor] same minor comment as above:

  • variables, functions ...: lowercase
  • classes: uppercase (at least the first letter)

Contributor Author

Yeah, changed already

Member

@cbhua cbhua left a comment

Great work! 🚀 Really good code quality; I have a few comments on it~

@fedebotu
Member

fedebotu commented Jun 4, 2024

Awesome! No more comments; I think it's ready to merge. @cbhua, at your will

@fedebotu fedebotu merged commit 3ffd855 into ai4co:main Jun 7, 2024
12 checks passed
@fedebotu fedebotu added this to the 0.5.0 milestone Jun 19, 2024