Skip to content

Commit

Permalink
Ban module level imports for cv2, matplotlib and numpy (#292)
Browse files Browse the repository at this point in the history
  • Loading branch information
iurisilvio authored Jul 22, 2024
1 parent d4a72d9 commit 3620fda
Show file tree
Hide file tree
Showing 7 changed files with 55 additions and 17 deletions.
8 changes: 7 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ exclude = [
"tests/manual/debugme.py", # file is intentionally broken
]


# Allow unused variables when underscore-prefixed.
dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"

Expand All @@ -92,6 +91,13 @@ convention = "google"
# Preserve types, even if a file imports `from __future__ import annotations`.
keep-runtime-typing = true

[tool.ruff.lint.flake8-tidy-imports]
banned-module-level-imports = [
"cv2",
"matplotlib",
"numpy",
]

[tool.mypy]
python_version = "3.8"
exclude = ["^build/"]
Expand Down
13 changes: 10 additions & 3 deletions roboflow/core/version.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import copy
import json
import os
Expand All @@ -6,9 +8,8 @@
import time
import zipfile
from importlib import import_module
from typing import Optional, Union
from typing import TYPE_CHECKING, Optional, Union

import numpy as np
import requests
import yaml
from dotenv import load_dotenv
Expand All @@ -28,7 +29,6 @@
)
from roboflow.core.dataset import Dataset
from roboflow.models.classification import ClassificationModel
from roboflow.models.inference import InferenceModel
from roboflow.models.instance_segmentation import InstanceSegmentationModel
from roboflow.models.keypoint_detection import KeypointDetectionModel
from roboflow.models.object_detection import ObjectDetectionModel
Expand All @@ -37,6 +37,11 @@
from roboflow.util.general import write_line
from roboflow.util.versions import get_wrong_dependencies_versions, print_warn_for_wrong_dependencies_versions

if TYPE_CHECKING:
import numpy as np

from roboflow.models.inference import InferenceModel

load_dotenv()


Expand Down Expand Up @@ -401,6 +406,8 @@ def live_plot(epochs, mAP, loss, title=""):
loss: Union[np.ndarray, list]

if "roboflow-train" in models.keys():
import numpy as np

# training has started
epochs = np.array([int(epoch["epoch"]) for epoch in models["roboflow-train"]["epochs"]])
mAP = np.array([float(epoch["mAP"]) for epoch in models["roboflow-train"]["epochs"]])
Expand Down
8 changes: 5 additions & 3 deletions roboflow/core/workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@
import sys
from typing import Any, List

import numpy as np
import requests
from numpy import ndarray
from PIL import Image

from roboflow.adapters import rfapi
Expand Down Expand Up @@ -407,6 +405,8 @@ def active_learning(
use_localhost: (bool) = determines if local http format used or remote endpoint
local_server: (str) = local http address for inference server, use_localhost must be True for this to be used
""" # noqa: E501 // docs
import numpy as np

prediction_results = []

# ensure that all fields of conditionals have a key:value pair
Expand Down Expand Up @@ -528,7 +528,9 @@ def active_learning(

# return predictions with filenames if globbed images from dir,
# otherwise return latest prediction result
return prediction_results if type(raw_data_location) is not ndarray else prediction_results[-1]["predictions"]
return (
prediction_results if type(raw_data_location) is not np.ndarray else prediction_results[-1]["predictions"]
)

def __str__(self):
projects = self.projects()
Expand Down
6 changes: 4 additions & 2 deletions roboflow/models/object_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
import random
import urllib

import cv2
import numpy as np
import requests
from PIL import Image

Expand Down Expand Up @@ -178,6 +176,9 @@ def predict( # type: ignore[override]
original_dimensions = None
# If image is local image
if not hosted:
import cv2
import numpy as np

if isinstance(image_path, str):
image = Image.open(image_path).convert("RGB")
dimensions = image.size
Expand Down Expand Up @@ -294,6 +295,7 @@ def webcam(
stroke (int): Stroke width for bounding box
labels (bool): Whether to show labels on bounding box
""" # noqa: E501 // docs
import cv2

os.environ["OPENCV_VIDEOIO_PRIORITY_MSMF"] = "0"

Expand Down
7 changes: 5 additions & 2 deletions roboflow/util/image_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
import os
import urllib

import cv2
import numpy as np
import requests
import yaml
from PIL import Image
Expand Down Expand Up @@ -40,6 +38,9 @@ def mask_image(image, encoded_mask, transparency=60):
:param transparency: alpha transparency of masks for semantic overlays
:returns: CV2 image / numpy.ndarray matrix
"""
import cv2
import numpy as np

np_data = np.fromstring(base64.b64decode(encoded_mask), np.uint8) # type: ignore[no-overload]
mask = cv2.imdecode(np_data, cv2.IMREAD_UNCHANGED)

Expand Down Expand Up @@ -71,6 +72,8 @@ def validate_image_path(image_path):


def file2jpeg(image_path):
import cv2

img = cv2.imread(image_path)
image = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
pilImage = Image.fromarray(image)
Expand Down
27 changes: 22 additions & 5 deletions roboflow/util/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,7 @@
import urllib.request
import warnings

import cv2
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import numpy as np
import requests
from matplotlib import patches
from PIL import Image

from roboflow.config import (
Expand All @@ -29,6 +24,8 @@ def plot_image(image_path):
:param image_path: path of image to be plotted (can be hosted or local)
:return:
"""
import matplotlib.pyplot as plt

validate_image_path(image_path)
try:
img = Image.open(image_path)
Expand All @@ -52,6 +49,8 @@ def plot_annotation(axes, prediction=None, stroke=1, transparency=60, colors=Non
:param transparency: alpha transparency of masks for semantic overlays
:return:
"""
from matplotlib import patches

# Object Detection annotation

colors = {} if colors is None else colors
Expand Down Expand Up @@ -88,6 +87,8 @@ def plot_annotation(axes, prediction=None, stroke=1, transparency=60, colors=Non
polygon = patches.Polygon(points, linewidth=stroke, edgecolor=stroke_color, facecolor="none")
axes.add_patch(polygon)
elif prediction["prediction_type"] == SEMANTIC_SEGMENTATION_MODEL:
import matplotlib.image as mpimg

encoded_mask = prediction["segmentation_mask"]
mask_bytes = io.BytesIO(base64.b64decode(encoded_mask))
mask = mpimg.imread(mask_bytes, format="JPG")
Expand Down Expand Up @@ -121,6 +122,9 @@ def json(self):
return self.json_prediction

def __load_image(self):
import cv2
import numpy as np

if "http://" in self.image_path:
req = urllib.request.urlopen(self.image_path)
arr = np.asarray(bytearray(req.read()), dtype=np.uint8)
Expand All @@ -131,6 +135,8 @@ def __load_image(self):
return cv2.imread(self.image_path)

def plot(self, stroke=1):
import matplotlib.pyplot as plt

# Exception to check if image path exists
validate_image_path(self["image_path"])
_, axes = plot_image(self["image_path"])
Expand All @@ -146,6 +152,9 @@ def save(self, output_path="predictions.jpg", stroke=2, transparency=60):
:param stroke: line width to use when drawing rectangles and polygons
:param transparency: alpha transparency of masks for semantic overlays
"""
import cv2
import numpy as np

image = self.__load_image()
stroke_color = (255, 0, 0)

Expand Down Expand Up @@ -302,6 +311,8 @@ def add_prediction(self, prediction=None):
self.predictions.append(prediction)

def plot(self, stroke=1):
import matplotlib.pyplot as plt

if len(self) > 0:
validate_image_path(self.base_image_path)
_, axes = plot_image(self.base_image_path)
Expand All @@ -311,6 +322,9 @@ def plot(self, stroke=1):
plt.show()

def __load_image(self):
import cv2
import numpy as np

# Check if it is a hosted image and open image as needed
if "http://" in self.base_image_path or "https://" in self.base_image_path:
req = urllib.request.urlopen(self.base_image_path)
Expand All @@ -322,6 +336,9 @@ def __load_image(self):
return cv2.imread(self.base_image_path)

def save(self, output_path="predictions.jpg", stroke=2):
import cv2
import numpy as np

# Load image based on image path as an array
image = self.__load_image()
stroke_color = (255, 0, 0)
Expand Down
3 changes: 2 additions & 1 deletion tests/models/test_object_detection.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import unittest

import numpy as np
import responses
from PIL import UnidentifiedImageError
from requests.exceptions import HTTPError
Expand Down Expand Up @@ -83,6 +82,8 @@ def test_predict_with_local_image_request(self):

@responses.activate
def test_predict_with_a_numpy_array_request(self):
import numpy as np

np_array = np.ones((100, 100, 1), dtype=np.uint8)
instance = ObjectDetectionModel(self.api_key, self.version_id, version=self.version)

Expand Down

0 comments on commit 3620fda

Please sign in to comment.