diff --git a/src/data_morph/shapes/factory.py b/src/data_morph/shapes/factory.py index 2b2db2af..63a12c1b 100644 --- a/src/data_morph/shapes/factory.py +++ b/src/data_morph/shapes/factory.py @@ -49,6 +49,7 @@ class ShapeFactory: 'scatter': points.Scatter, 'right_parab': points.RightParabola, 'up_parab': points.UpParabola, + 'spiral': points.Spiral, 'diamond': polygons.Diamond, 'rectangle': polygons.Rectangle, 'rings': circles.Rings, diff --git a/src/data_morph/shapes/points.py b/src/data_morph/shapes/points.py index 158be2a8..697ba83b 100644 --- a/src/data_morph/shapes/points.py +++ b/src/data_morph/shapes/points.py @@ -1,6 +1,7 @@ """Shapes that are composed of points.""" import itertools +import math from numbers import Number import numpy as np @@ -304,3 +305,51 @@ def distance(self, x: Number, y: Number) -> int: Always returns 0 to allow for scattering of the points. """ return 0 + + +class Spiral(PointCollection): + """ + Class for the spiral shape. + + .. plot:: + :scale: 100 + :caption: + This shape is generated using the panda dataset. + + from data_morph.data.loader import DataLoader + from data_morph.shapes.points import Spiral + + _ = Spiral(DataLoader.load_dataset('panda')).plot() + + Parameters + ---------- + dataset : Dataset + The starting dataset to morph into other shapes. + + Notes + ----- + The formula for a spiral can be found here: + https://en.wikipedia.org/wiki/Archimedean_spiral + """ + + def __init__(self, dataset: Dataset) -> None: + xmin, xmax = dataset.data_bounds.x_bounds + ymin, ymax = dataset.data_bounds.y_bounds + + # Coordinates of centre + cx = dataset.df.x.mean() + cy = dataset.df.y.mean() + + # Max radius + radius = min(xmax - xmin, ymax - ymin) / 2 + + # Number of rotations + num_rotations = 3 + + t = np.linspace(0, 1, num=200) + + # x and y calculations for a spiral + x = (t * radius) * np.cos(2 * num_rotations * math.pi * t) + cx + y = (t * radius) * np.sin(2 * num_rotations * math.pi * t) + cy + + super().__init__(*np.stack([x, y], axis=1)) diff --git a/tests/shapes/test_points.py b/tests/shapes/test_points.py index 842d831d..dbd273a7 100644 --- a/tests/shapes/test_points.py +++ b/tests/shapes/test_points.py @@ -88,6 +88,13 @@ class TestScatter(PointsModuleTestBase): distance_test_cases = [[(20, 50), 0.0], [(30, 60), 0.0], [(-500, -150), 0.0]] +class TestSpiral(PointsModuleTestBase): + """Test the Spiral class.""" + + shape_name = 'spiral' + distance_test_cases = [[(20, 60), 0.0], [(70, 90), 50.0]] + + class ParabolaTestBase(PointsModuleTestBase): """Base test class for parabolic shapes."""