Skip to content

Commit

Permalink
undo commenting out
Browse files Browse the repository at this point in the history
  • Loading branch information
stefdoerr committed Feb 12, 2025
1 parent ffd4c89 commit 8e1384f
Showing 1 changed file with 55 additions and 55 deletions.
110 changes: 55 additions & 55 deletions tests/test_neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,61 +313,61 @@ def test_neighbor_autograds(
)


# @pytest.mark.parametrize("strategy", ["brute", "cell", "shared"])
# @pytest.mark.parametrize("n_batches", [1, 2, 3, 4])
# def test_large_size(strategy, n_batches):
# device = "cuda"
# cutoff = 1.76
# loop = False
# if device == "cuda" and not torch.cuda.is_available():
# pytest.skip("CUDA not available")
# torch.manual_seed(4321)
# num_atoms = int(32000 / n_batches)
# n_atoms_per_batch = torch.ones(n_batches, dtype=torch.int64) * num_atoms
# batch = torch.repeat_interleave(
# torch.arange(n_batches, dtype=torch.int64), n_atoms_per_batch
# ).to(device)
# cumsum = np.cumsum(np.concatenate([[0], n_atoms_per_batch]))
# lbox = 45.0
# pos = torch.rand(cumsum[-1], 3, device=device) * lbox
# # Ensure there is at least one pair
# pos[0, :] = torch.zeros(3)
# pos[1, :] = torch.zeros(3)
# pos.requires_grad = True
# # Find the particle appearing in the most pairs
# max_num_neighbors = 64
# ref_neighbors, ref_distance_vecs, ref_distances = compute_ref_neighbors(
# pos, batch, loop, True, cutoff, None
# )
# ref_neighbors, ref_distance_vecs, ref_distances = sort_neighbors(
# ref_neighbors, ref_distance_vecs, ref_distances
# )

# max_num_pairs = ref_neighbors.shape[1]

# # Must check without PBC since Distance does not support it
# box = None
# nl = OptimizedDistance(
# cutoff_lower=0.0,
# loop=loop,
# cutoff_upper=cutoff,
# max_num_pairs=max_num_pairs,
# strategy=strategy,
# box=box,
# return_vecs=True,
# include_transpose=True,
# resize_to_fit=True,
# )
# neighbors, distances, distance_vecs = nl(pos, batch)
# neighbors = neighbors.cpu().detach().numpy()
# distance_vecs = distance_vecs.cpu().detach().numpy()
# distances = distances.cpu().detach().numpy()
# neighbors, distance_vecs, distances = sort_neighbors(
# neighbors, distance_vecs, distances
# )
# assert np.allclose(neighbors, ref_neighbors)
# assert np.allclose(distances, ref_distances)
# assert np.allclose(distance_vecs, ref_distance_vecs)
@pytest.mark.parametrize("strategy", ["brute", "cell", "shared"])
@pytest.mark.parametrize("n_batches", [1, 2, 3, 4])
def test_large_size(strategy, n_batches):
device = "cuda"
cutoff = 1.76
loop = False
if device == "cuda" and not torch.cuda.is_available():
pytest.skip("CUDA not available")
torch.manual_seed(4321)
num_atoms = int(32000 / n_batches)
n_atoms_per_batch = torch.ones(n_batches, dtype=torch.int64) * num_atoms
batch = torch.repeat_interleave(
torch.arange(n_batches, dtype=torch.int64), n_atoms_per_batch
).to(device)
cumsum = np.cumsum(np.concatenate([[0], n_atoms_per_batch]))
lbox = 45.0
pos = torch.rand(cumsum[-1], 3, device=device) * lbox
# Ensure there is at least one pair
pos[0, :] = torch.zeros(3)
pos[1, :] = torch.zeros(3)
pos.requires_grad = True
# Find the particle appearing in the most pairs
max_num_neighbors = 64
ref_neighbors, ref_distance_vecs, ref_distances = compute_ref_neighbors(
pos, batch, loop, True, cutoff, None
)
ref_neighbors, ref_distance_vecs, ref_distances = sort_neighbors(
ref_neighbors, ref_distance_vecs, ref_distances
)

max_num_pairs = ref_neighbors.shape[1]

# Must check without PBC since Distance does not support it
box = None
nl = OptimizedDistance(
cutoff_lower=0.0,
loop=loop,
cutoff_upper=cutoff,
max_num_pairs=max_num_pairs,
strategy=strategy,
box=box,
return_vecs=True,
include_transpose=True,
resize_to_fit=True,
)
neighbors, distances, distance_vecs = nl(pos, batch)
neighbors = neighbors.cpu().detach().numpy()
distance_vecs = distance_vecs.cpu().detach().numpy()
distances = distances.cpu().detach().numpy()
neighbors, distance_vecs, distances = sort_neighbors(
neighbors, distance_vecs, distances
)
assert np.allclose(neighbors, ref_neighbors)
assert np.allclose(distances, ref_distances)
assert np.allclose(distance_vecs, ref_distance_vecs)


@pytest.mark.parametrize(
Expand Down

0 comments on commit 8e1384f

Please sign in to comment.