Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add spiral target shape #192

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/data_morph/shapes/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ class ShapeFactory:
'scatter': points.Scatter,
'right_parab': points.RightParabola,
'up_parab': points.UpParabola,
'spiral': points.Spiral,
'diamond': polygons.Diamond,
'rectangle': polygons.Rectangle,
'rings': circles.Rings,
Expand Down
49 changes: 49 additions & 0 deletions src/data_morph/shapes/points.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Shapes that are composed of points."""

import itertools
import math
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's stick with numpy for this.

Suggested change
import math

from numbers import Number

import numpy as np
Expand Down Expand Up @@ -304,3 +305,51 @@ def distance(self, x: Number, y: Number) -> int:
Always returns 0 to allow for scattering of the points.
"""
return 0


class Spiral(PointCollection):
"""
Class for the spiral shape.

.. plot::
:scale: 100
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's keep all the plots the same size:

Suggested change
:scale: 100
:scale: 75

:caption:
This shape is generated using the panda dataset.

from data_morph.data.loader import DataLoader
from data_morph.shapes.points import Spiral

_ = Spiral(DataLoader.load_dataset('panda')).plot()

Parameters
----------
dataset : Dataset
The starting dataset to morph into other shapes.

Notes
-----
The formula for a spiral can be found here:
https://en.wikipedia.org/wiki/Archimedean_spiral
Comment on lines +329 to +332
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is great, thank you!

"""

def __init__(self, dataset: Dataset) -> None:
xmin, xmax = dataset.data_bounds.x_bounds
ymin, ymax = dataset.data_bounds.y_bounds

# Coordinates of centre
cx = dataset.df.x.mean()
cy = dataset.df.y.mean()

# Max radius
radius = min(xmax - xmin, ymax - ymin) / 2

# Number of rotations
num_rotations = 3

t = np.linspace(0, 1, num=200)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add a comment for this one too?

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Two other thoughts here:

  1. Can we use fewer points? This would speed up the distance calculations.
  2. The points are farther apart as the spiral grows. Can we reduce the density at the center and fill in the outer parts better?


# x and y calculations for a spiral
x = (t * radius) * np.cos(2 * num_rotations * math.pi * t) + cx
y = (t * radius) * np.sin(2 * num_rotations * math.pi * t) + cy
Comment on lines +352 to +353
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about we modify this to start at the top when the y range is bigger than the x range? For example, on the dog shape, the outermost point of the spiral would be at the dog's head instead of way off to the right.

Screenshot 2024-07-16 at 6 53 12 PM

Comment on lines +352 to +353
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
x = (t * radius) * np.cos(2 * num_rotations * math.pi * t) + cx
y = (t * radius) * np.sin(2 * num_rotations * math.pi * t) + cy
x = (t * radius) * np.cos(2 * num_rotations * np.pi * t) + cx
y = (t * radius) * np.sin(2 * num_rotations * np.pi * t) + cy


super().__init__(*np.stack([x, y], axis=1))
7 changes: 7 additions & 0 deletions tests/shapes/test_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,13 @@ class TestScatter(PointsModuleTestBase):
distance_test_cases = [[(20, 50), 0.0], [(30, 60), 0.0], [(-500, -150), 0.0]]


class TestSpiral(PointsModuleTestBase):
"""Test the Spiral class."""

shape_name = 'spiral'
distance_test_cases = [[(20, 60), 0.0], [(70, 90), 50.0]]


class ParabolaTestBase(PointsModuleTestBase):
"""Base test class for parabolic shapes."""

Expand Down
Loading