[Feat] Updating the evaluation script #188

Merged · 5 commits · Jun 3, 2024
11 changes: 8 additions & 3 deletions rl4co/models/nn/env_embeddings/context.py
@@ -118,9 +118,14 @@
(-1,) if td["first_node"].dim() == 1 else (td["first_node"].size(-1), -1)
)
if td["i"][(0,) * td["i"].dim()].item() < 1: # get first item fast
context_embedding = self.W_placeholder[None, :].expand(
batch_size, self.W_placeholder.size(-1)
)
if len(td.batch_size) < 2:
context_embedding = self.W_placeholder[None, :].expand(
batch_size, self.W_placeholder.size(-1)
)
else:
context_embedding = self.W_placeholder[None, None, :].expand(
batch_size, td.batch_size[1], self.W_placeholder.size(-1)
)
else:
context_embedding = gather_by_index(
embeddings,
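
The new branch above covers TensorDicts with a 2-D batch size, as produced when each instance is replicated for sampling. A minimal sketch of the broadcasting difference; embed_dim and the batch dimensions here are illustrative assumptions, not values from the PR:

import torch

embed_dim = 128
W_placeholder = torch.randn(2 * embed_dim)  # learned placeholder for first/last-node context

# 1-D batch, e.g. td.batch_size == (32,)
ctx = W_placeholder[None, :].expand(32, W_placeholder.size(-1))             # [32, 256]

# 2-D batch, e.g. td.batch_size == (32, 10) after replicating instances
ctx = W_placeholder[None, None, :].expand(32, 10, W_placeholder.size(-1))   # [32, 10, 256]
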
107 changes: 92 additions & 15 deletions rl4co/tasks/eval.py
@@ -161,19 +161,25 @@

name = "sampling"

def __init__(self, env, samples, softmax_temp=None, **kwargs):
def __init__(self, env, samples, softmax_temp=None, temperature=1.0, top_p=0.0, top_k=0, **kwargs):
check_unused_kwargs(self, kwargs)
super().__init__(env, kwargs.get("progress", True))

self.samples = samples
self.softmax_temp = softmax_temp
self.temperature = temperature
self.top_p = top_p
self.top_k = top_k

def _inner(self, policy, td):
out = policy(
td.clone(),
decode_type="sampling",
num_starts=self.samples,
multistart=True,
temperature=self.temperature,
top_p=self.top_p,
top_k=self.top_k,
multisample=True,
return_actions=True,
softmax_temp=self.softmax_temp,
select_best=True,
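
A usage sketch for the updated evaluator: the keyword names come from the diff above, the values are illustrative, and the call signature follows the eval_fn(policy, dataloader) convention used later in this file:

evaluator = SamplingEval(
    env,
    samples=1280,       # sampled solutions per instance
    softmax_temp=1.0,   # temperature applied to the softmax over logits
    temperature=1.0,    # decoding temperature passed to the policy
    top_p=0.95,         # nucleus sampling threshold; 0.0 disables it
    top_k=0,            # top-k filtering; 0 disables it
)
retvals = evaluator(policy, dataloader)
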
@@ -331,8 +337,10 @@
max_batch_size=4096,
start_batch_size=8192,
auto_batch_size=True,
save_results=False,
save_fname="results.npz",
samples=1280,
softmax_temp=1.0,
num_augment=8,
force_dihedral_8=True,
**kwargs,
):
num_loc = getattr(env.generator, "num_loc", None)
@@ -341,28 +349,28 @@
"greedy": {"func": GreedyEval, "kwargs": {}},
"sampling": {
"func": SamplingEval,
"kwargs": {"samples": 100, "softmax_temp": 1.0},
"kwargs": {"samples": samples, "softmax_temp": softmax_temp},
},
"multistart_greedy": {
"func": GreedyMultiStartEval,
"kwargs": {"num_starts": num_loc},
},
"augment_dihedral_8": {
"func": AugmentationEval,
"kwargs": {"num_augment": 8, "force_dihedral_8": True},
"kwargs": {"num_augment": num_augment, "force_dihedral_8": force_dihedral_8},
},
"augment": {"func": AugmentationEval, "kwargs": {"num_augment": 8}},
"augment": {"func": AugmentationEval, "kwargs": {"num_augment": num_augment}},
"multistart_greedy_augment_dihedral_8": {
"func": GreedyMultiStartAugmentEval,
"kwargs": {
"num_augment": 8,
"force_dihedral_8": True,
"num_augment": num_augment,
"force_dihedral_8": force_dihedral_8,
"num_starts": num_loc,
},
},
"multistart_greedy_augment": {
"func": GreedyMultiStartAugmentEval,
"kwargs": {"num_augment": 8, "num_starts": num_loc},
"kwargs": {"num_augment": num_augment, "num_starts": num_loc},
},
}

@@ -397,9 +405,78 @@
# Run evaluation
retvals = eval_fn(policy, dataloader)

# Save results
if save_results:
print("Saving results to {}".format(save_fname))
np.savez(save_fname, **retvals)

return retvals
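
Results written by the save_results branch can be read back with NumPy; a small sketch (the available keys depend on what the chosen evaluator returns):

import numpy as np

results = np.load("results.npz", allow_pickle=True)
print(results.files)  # e.g. actions and rewards collected during evaluation
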


if __name__ == "__main__":
import os
import pickle
import argparse
import importlib
import torch
from rl4co.envs import get_env

parser = argparse.ArgumentParser()

# Environment
parser.add_argument("--problem", type=str, default="tsp", help="Problem to solve")
parser.add_argument("--generator_params", type=dict, default={"num_loc": 50}, help="Generator parameters for the environment")
parser.add_argument("--data_path", type=str, default="data/tsp/tsp50_test_seed1234.npz", help="Path of the test data npz file")

# Model
parser.add_argument("--model", type=str, default="AttentionModel", help="The class name of the valid model")
parser.add_argument("--ckpt_path", type=str, default="checkpoints/am-tsp50.ckpt", help="The path of the checkpoint file")
parser.add_argument("--device", type=str, default="cuda:1", help="Device to run the evaluation")

# Evaluation
parser.add_argument("--method", type=str, default="greedy", help="Evaluation method, support 'greedy', 'sampling',\
'multistart_greedy', 'augment_dihedral_8', 'augment', 'multistart_greedy_augment_dihedral_8',\
'multistart_greedy_augment'")
parser.add_argument("--temperature", type=float, default=1.0, help="Temperature for sampling")
parser.add_argument("--top_p", type=float, default=0.0, help="Top-p for sampling, from 0.0 to 1.0, 0.0 means not activated")
parser.add_argument("--top_k", type=int, default=0, help="Top-k for sampling")
parser.add_argument("--save_results", type=bool, default=True, help="Whether to save the evaluation results")
parser.add_argument("--save_path", type=str, default="results", help="The root path to save the results")

parser.add_argument("--samples", type=int, default=1280, help="Number of samples for sampling method")
parser.add_argument("--softmax_temp", type=float, default=1.0, help="Temperature for softmax in the sampling method")
parser.add_argument("--num_augment", type=int, default=8, help="Number of augmentations for augmentation method")
parser.add_argument("--force_dihedral_8", type=bool, default=True, help="Force the use of 8 augmentations for augmentation method")

opts = parser.parse_args()

# Init the environment
env = get_env(opts.problem, generator_params=opts.generator_params)

# Load the test data
dataset = env.dataset(filename=opts.data_path)

# Load the model from checkpoint
model_root = importlib.import_module("rl4co.models.zoo")
model_cls = getattr(model_root, opts.model)
model = model_cls.load_from_checkpoint(opts.ckpt_path, load_baseline=False)
model = model.to(opts.device)

# Evaluate
result = evaluate_policy(
env=env,
policy=model.policy,
dataset=dataset,
method=opts.method,
temperature=opts.temperature,
top_p=opts.top_p,
top_k=opts.top_k,
samples=opts.samples,
softmax_temp=opts.softmax_temp,
num_augment=opts.num_augment,
force_dihedral_8=opts.force_dihedral_8,
)

# Save the results
if opts.save_results:
if not os.path.exists(opts.save_path):
os.makedirs(opts.save_path)
save_fname = f"{env.name}{env.generator.num_loc}-{opts.model}-{opts.method}-temp-{opts.temperature}-top_p-{opts.top_p}-top_k-{opts.top_k}.pkl"
save_path = os.path.join(opts.save_path, save_fname)
with open(save_path, "wb") as f:
pickle.dump(result, f)
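
An example invocation of the new CLI, assuming a trained checkpoint exists at the default path; the flag values are illustrative:

python rl4co/tasks/eval.py --problem tsp --data_path data/tsp/tsp50_test_seed1234.npz \
    --model AttentionModel --ckpt_path checkpoints/am-tsp50.ckpt \
    --method sampling --samples 1280 --temperature 1.0 --top_p 0.95 --device cuda:0

Two argparse caveats apply to the flags as written: type=bool arguments such as --save_results treat any non-empty string as True, and --generator_params uses type=dict, which cannot parse a dict literal from a command-line string, so in practice the default is used unless the script is edited.
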
45 changes: 28 additions & 17 deletions rl4co/utils/decoding.py
@@ -202,6 +202,7 @@
mask_logits: Whether to mask logits of infeasible actions. Defaults to True.
tanh_clipping: Tanh clipping (https://arxiv.org/abs/1611.09940). Defaults to 0.
multistart: Whether to use multistart decoding. Defaults to False.
multisample: Whether to use sampling decoding. Defaults to False.
num_starts: Number of starts for multistart decoding. Defaults to None.
"""

@@ -215,6 +216,7 @@
mask_logits: bool = True,
tanh_clipping: float = 0,
multistart: bool = False,
multisample: bool = False,
num_starts: Optional[int] = None,
select_start_nodes_fn: Optional[callable] = None,
improvement_method_mode: bool = False,
@@ -228,6 +230,7 @@
self.mask_logits = mask_logits
self.tanh_clipping = tanh_clipping
self.multistart = multistart
self.multisample = multisample
self.num_starts = num_starts
self.select_start_nodes_fn = select_start_nodes_fn
self.improvement_method_mode = improvement_method_mode
@@ -262,9 +265,13 @@
"""Pre decoding hook. This method is called before the main decoding operation."""

# Multi-start decoding. If num_starts is None, we use the number of actions in the action mask
if self.multistart:
if self.multistart or self.multisample:
if self.num_starts is None:
self.num_starts = env.get_num_starts(td)
if self.multisample:
log.warn(
f"num_starts is not provided for sampling, using num_starts={self.num_starts}"
)
else:
if self.num_starts is not None:
if self.num_starts >= 1:
@@ -276,25 +283,29 @@

# Multi-start decoding: first action is chosen by ad-hoc node selection
if self.num_starts >= 1:
if action is None: # if action is provided, we use it as the first action
if self.select_start_nodes_fn is not None:
action = self.select_start_nodes_fn(td, env, self.num_starts)
if self.multistart:
if action is None: # if action is provided, we use it as the first action
if self.select_start_nodes_fn is not None:
action = self.select_start_nodes_fn(td, env, self.num_starts)
else:
action = env.select_start_nodes(td, num_starts=self.num_starts)

# Expand td to batch_size * num_starts
td = batchify(td, self.num_starts)

td.set("action", action)
td = env.step(td)["next"]
# first logprobs is 0, so p = logprobs.exp() = 1
if self.store_all_logp:
logprobs = torch.zeros_like(td["action_mask"]) # [B, N]
else:
action = env.select_start_nodes(td, num_starts=self.num_starts)
logprobs = torch.zeros_like(action, device=td.device) # [B]

# Expand td to batch_size * num_starts
td = batchify(td, self.num_starts)

td.set("action", action)
td = env.step(td)["next"]
# first logprobs is 0, so p = logprobs.exp() = 1
if self.store_all_logp:
logprobs = torch.zeros_like(td["action_mask"]) # [B, N]
self.logprobs.append(logprobs)
self.actions.append(action)
else:
logprobs = torch.zeros_like(action, device=td.device) # [B]

self.logprobs.append(logprobs)
self.actions.append(action)
# Expand td to batch_size * num_samples
td = batchify(td, self.num_starts)

return td, env, self.num_starts
