Skip to content

Commit

Permalink
Merge pull request #119 from m3dev/pandas-type-check-framework
Browse files Browse the repository at this point in the history
check pandas column type in dump
  • Loading branch information
nishiba authored Feb 3, 2020
2 parents 1085de3 + 217ed8c commit 14db6bc
Show file tree
Hide file tree
Showing 11 changed files with 241 additions and 10 deletions.
30 changes: 30 additions & 0 deletions examples/sample_pandas_type_check.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from typing import Dict, Any

import gokart
import pandas as pd


# Define a class that inherits from `gokart.PandasTypeConfig` to declare the expected column types.
# **In practice, import `SamplePandasTypeConfig` in your package's `__init__` so that it is registered before any task calls `dump`.**
class SamplePandasTypeConfig(gokart.PandasTypeConfig):
    """Expected column types for tasks in the `sample_pandas_type_check` namespace."""

    task_namespace = 'sample_pandas_type_check'

    @classmethod
    def type_dict(cls) -> Dict[str, Any]:
        """Return the column-name -> expected-type mapping used by the type check."""
        expected_types = {'int_column': int}
        return expected_types


class SampleTask(gokart.TaskOnKart):
    """Sample task that dumps a DataFrame violating `SamplePandasTypeConfig`."""

    # The `task_namespace` must be the same as the one set on `SamplePandasTypeConfig`.
    task_namespace = 'sample_pandas_type_check'

    def run(self):
        """Build a DataFrame whose `int_column` holds a `str` and dump it."""
        frame = pd.DataFrame({'int_column': ['a']})
        # This dump causes PandasTypeError: the expected type is `int`, but a `str` is passed.
        self.dump(frame)


if __name__ == '__main__':
    # Run the sample task with the local scheduler, forcing a rerun.
    command_line_args = [
        'sample_pandas_type_check.SampleTask',
        '--local-scheduler',
        '--rerun',
    ]
    gokart.run(command_line_args)
1 change: 1 addition & 0 deletions gokart/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from gokart.pandas_type_config import PandasTypeConfig
from gokart.parameter import TaskInstanceParameter, ListTaskInstanceParameter
from gokart.task import TaskOnKart
from gokart.info import make_tree_info, tree_info
Expand Down
50 changes: 50 additions & 0 deletions gokart/pandas_type_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from abc import abstractmethod
from logging import getLogger
from typing import Dict, Any

import luigi
import numpy as np
import pandas as pd
from luigi.task_register import Register

# Module-level logger named after this module.
logger = getLogger(__name__)


class PandasTypeError(Exception):
    """Raised when a DataFrame column holds a value of an unexpected type."""


class PandasTypeConfig(luigi.Config):
    """Base class for declaring the expected value type of DataFrame columns.

    Subclasses implement `type_dict`; `check` then verifies that every value
    in each declared column is an instance of the declared type.
    """

    @classmethod
    @abstractmethod
    def type_dict(cls) -> Dict[str, Any]:
        """Return a mapping from column name to the type its values must be instances of."""

    @classmethod
    def check(cls, df: pd.DataFrame):
        """Validate `df` against every entry of `type_dict`.

        Raises:
            PandasTypeError: if a declared column contains a value that is not
                an instance of its declared type.
        """
        for column_name, column_type in cls.type_dict().items():
            cls._check_column(df, column_name, column_type)

    @classmethod
    def _check_column(cls, df, column_name, column_type):
        """Validate one column; columns absent from `df` are silently accepted."""
        if column_name not in df.columns:
            return

        # Single pass that stops at the first offending value, instead of
        # materializing a list of booleans and then scanning the column again.
        for value in df[column_name]:
            if not isinstance(value, column_type):
                raise PandasTypeError(f'expected type is "{column_type}", but "{type(value)}" is passed in column "{column_name}".')


class PandasTypeConfigMap(luigi.Config):
    """Lookup from task namespace to the `PandasTypeConfig` subclass declared for it.

    To initialize this class only once, this inherits luigi.Config.
    """

    def __init__(self, *args, **kwargs) -> None:
        """Collect every registered `PandasTypeConfig` subclass, keyed by its `task_namespace`."""
        super().__init__(*args, **kwargs)
        registered_classes = (Register.get_task_cls(task_name) for task_name in Register.task_names())
        # The base class itself declares no types, so it is excluded.
        # If several configs share a namespace, the last one visited wins.
        self._map = {
            task_class.task_namespace: task_class
            for task_class in registered_classes
            if issubclass(task_class, PandasTypeConfig) and task_class is not PandasTypeConfig
        }

    def check(self, obj, task_namespace: str):
        """Type-check `obj` if it is a DataFrame and a config exists for `task_namespace`.

        Non-DataFrame objects and namespaces without a config are accepted
        without any check.

        Raises:
            PandasTypeError: propagated from the matching config's `check`.
        """
        # isinstance (rather than an exact type comparison) so that DataFrame
        # subclasses are validated too instead of silently bypassing the check.
        if isinstance(obj, pd.DataFrame) and task_namespace in self._map:
            self._map[task_namespace].check(obj)
3 changes: 3 additions & 0 deletions gokart/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@

import luigi
import pandas as pd
from luigi.task_register import Register

import gokart
from gokart.file_processor import FileProcessor
from gokart.pandas_type_config import PandasTypeConfigMap
from gokart.target import TargetOnKart

logger = getLogger(__name__)
Expand Down Expand Up @@ -190,6 +192,7 @@ def _pd_concat(dfs):
return data

def dump(self, obj, target: Union[None, str, TargetOnKart] = None) -> None:
    """Type-check `obj` for this task's namespace, then dump it to the output target.

    `target` is forwarded to `self._get_output_target` to resolve the target.
    The check runs first, so a `PandasTypeError` prevents anything from being written.
    """
    type_checker = PandasTypeConfigMap()
    type_checker.check(obj, task_namespace=self.task_namespace)
    output_target = self._get_output_target(target)
    output_target.dump(obj)

def make_unique_id(self):
Expand Down
File renamed without changes.
18 changes: 14 additions & 4 deletions test/test_info.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import unittest
from unittest.mock import patch, MagicMock
from unittest.mock import patch

import luigi
import luigi.mock
from luigi.mock import MockFileSystem, MockTarget
from luigi.task_register import Register

import gokart
import gokart.info

Expand Down Expand Up @@ -34,8 +37,11 @@ def run(self):


class TestInfo(unittest.TestCase):
@patch('luigi.LocalTarget', new=lambda path, **kwargs: luigi.mock.MockTarget(path, **kwargs))
def test_make_tree_info(self):
def setUp(self) -> None:
MockFileSystem().clear()

@patch('luigi.LocalTarget', new=lambda path, **kwargs: MockTarget(path, **kwargs))
def test_make_tree_info_pending(self):
task = _Task(param=1, sub=_SubTask(param=2))

# check before running
Expand All @@ -45,8 +51,12 @@ def test_make_tree_info(self):
└─-\(PENDING\) _SubTask\[[a-z0-9]*\]"""
self.assertRegex(tree, expected)

@patch('luigi.LocalTarget', new=lambda path, **kwargs: MockTarget(path, **kwargs))
def test_make_tree_info_complete(self):
task = _Task(param=1, sub=_SubTask(param=2))

# check after sub task runs
luigi.build([task], local_scheduler=True, log_level='CRITICAL')
luigi.build([task], local_scheduler=True)
tree = gokart.info.make_tree_info(task)
expected = r"""
└─-\(COMPLETE\) _Task\[[a-z0-9]*\]
Expand Down
82 changes: 82 additions & 0 deletions test/test_pandas_type_check_framework.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import unittest
from logging import getLogger
from typing import Dict, Any

import pandas as pd
from luigi.mock import MockTarget, MockFileSystem
from mock import patch

import gokart
from gokart.pandas_type_config import PandasTypeConfig

# Module-level logger; not referenced elsewhere in the visible part of this test module.
logger = getLogger(__name__)


class TestPandasTypeConfig(PandasTypeConfig):
    """Type config for the test namespace: `system_cd` values must be `int`."""

    task_namespace = 'test_pandas_type_check_framework'

    @classmethod
    def type_dict(cls) -> Dict[str, Any]:
        """Return the expected type of each checked column."""
        expected_types = {'system_cd': int}
        return expected_types


class _DummyFailTask(gokart.TaskOnKart):
    """Task that dumps a `str` into the `int`-typed `system_cd` column."""

    task_namespace = 'test_pandas_type_check_framework'
    rerun = True

    def output(self):
        """Return the target this task dumps to."""
        return self.make_target('dummy.pkl')

    def run(self):
        """Dump a DataFrame whose `system_cd` value is the string '1'."""
        frame = pd.DataFrame({'system_cd': ['1']})
        self.dump(frame)


class _DummyFailWithNoneTask(gokart.TaskOnKart):
    """Task that dumps a `system_cd` column containing a missing value."""

    task_namespace = 'test_pandas_type_check_framework'
    rerun = True

    def output(self):
        """Return the target this task dumps to."""
        return self.make_target('dummy.pkl')

    def run(self):
        """Dump a DataFrame whose `system_cd` column mixes an int with None."""
        frame = pd.DataFrame({'system_cd': [1, None]})
        self.dump(frame)


class _DummySuccessTask(gokart.TaskOnKart):
    """Task that dumps a valid `int` value into the `system_cd` column."""

    task_namespace = 'test_pandas_type_check_framework'
    rerun = True

    def output(self):
        """Return the target this task dumps to."""
        return self.make_target('dummy.pkl')

    def run(self):
        """Dump a DataFrame whose `system_cd` value is the int 1."""
        frame = pd.DataFrame({'system_cd': [1]})
        self.dump(frame)


class TestPandasTypeCheckFramework(unittest.TestCase):
    """End-to-end tests: run the dummy tasks through `gokart.run` and inspect the exit code."""

    def setUp(self) -> None:
        # Start every test from an empty mock file system.
        MockFileSystem().clear()

    def _run_and_get_exit_code(self):
        """Run `gokart.run()` (expected to end with SystemExit) and return its exit code."""
        with self.assertRaises(SystemExit) as raised:
            gokart.run()
        return raised.exception.code

    @patch('sys.argv', new=['main', 'test_pandas_type_check_framework._DummyFailTask', '--log-level=CRITICAL', '--local-scheduler', '--no-lock'])
    @patch('luigi.LocalTarget', new=lambda path, **kwargs: MockTarget(path, **kwargs))
    def test_fail(self):
        # A non-zero exit code means the run raised an error.
        self.assertNotEqual(self._run_and_get_exit_code(), 0)

    @patch('sys.argv', new=['main', 'test_pandas_type_check_framework._DummyFailWithNoneTask', '--log-level=CRITICAL', '--local-scheduler', '--no-lock'])
    @patch('luigi.LocalTarget', new=lambda path, **kwargs: MockTarget(path, **kwargs))
    def test_fail_with_None(self):
        # A non-zero exit code means the run raised an error.
        self.assertNotEqual(self._run_and_get_exit_code(), 0)

    @patch('sys.argv', new=['main', 'test_pandas_type_check_framework._DummySuccessTask', '--log-level=CRITICAL', '--local-scheduler', '--no-lock'])
    @patch('luigi.LocalTarget', new=lambda path, **kwargs: MockTarget(path, **kwargs))
    def test_success(self):
        self.assertEqual(self._run_and_get_exit_code(), 0)
45 changes: 45 additions & 0 deletions test/test_pandas_type_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from datetime import datetime, date
from typing import Dict, Any
from unittest import TestCase

import numpy as np
import pandas as pd

from gokart import PandasTypeConfig
from gokart.pandas_type_config import PandasTypeError


class _DummyPandasTypeConfig(PandasTypeConfig):
    """Type config used by the unit tests: one int, one datetime and one ndarray column."""

    @classmethod
    def type_dict(cls) -> Dict[str, Any]:
        """Return the expected type of each checked column."""
        expected_types = {
            'int_column': int,
            'datetime_column': datetime,
            'array_column': np.ndarray,
        }
        return expected_types


class TestPandasTypeConfig(TestCase):
    """Unit tests for `PandasTypeConfig.check`, exercised through `_DummyPandasTypeConfig`."""

    def _assert_rejected(self, frame):
        """Assert that checking `frame` raises PandasTypeError."""
        with self.assertRaises(PandasTypeError):
            _DummyPandasTypeConfig().check(frame)

    def _assert_accepted(self, frame):
        """Check `frame`; any exception raised by the check fails the test."""
        _DummyPandasTypeConfig().check(frame)

    def test_int_fail(self):
        # The string '1' is not an int.
        self._assert_rejected(pd.DataFrame({'int_column': ['1']}))

    def test_int_success(self):
        self._assert_accepted(pd.DataFrame({'int_column': [1]}))

    def test_datetime_fail(self):
        # A plain date is not a datetime.
        self._assert_rejected(pd.DataFrame({'datetime_column': [date(2019, 1, 12)]}))

    def test_datetime_success(self):
        self._assert_accepted(pd.DataFrame({'datetime_column': [datetime(2019, 1, 12, 0, 0, 0)]}))

    def test_array_fail(self):
        # A Python list is not a numpy array.
        self._assert_rejected(pd.DataFrame({'array_column': [[1, 2]]}))

    def test_array_success(self):
        self._assert_accepted(pd.DataFrame({'array_column': [np.array([1, 2])]}))
5 changes: 4 additions & 1 deletion test/test_restore_task_by_id.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,11 @@ def run(self):
self.dump('test')


@patch('luigi.LocalTarget', new=lambda path, **kwargs: luigi.mock.MockTarget(path, **kwargs))
class RestoreTaskByIDTest(unittest.TestCase):
def setUp(self) -> None:
luigi.mock.MockFileSystem().clear()

@patch('luigi.LocalTarget', new=lambda path, **kwargs: luigi.mock.MockTarget(path, **kwargs))
def test(self):
task = _DummyTask(sub_task=_SubDummyTask(param=10))
luigi.build([task], local_scheduler=True, log_level="CRITICAL")
Expand Down
7 changes: 4 additions & 3 deletions test/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,12 @@ class _DummyTask(gokart.TaskOnKart):
class RunTest(unittest.TestCase):
def setUp(self):
luigi.configuration.LuigiConfigParser._instance = None
luigi.mock.MockFileSystem().clear()
os.environ.clear()

@patch('sys.argv', new=['main', f'{__name__}._DummyTask', '--param', 'test', '--log-level=CRITICAL', '--local-scheduler'])
def test_run(self):
config_file_path = os.path.join(os.path.dirname(__name__), 'test_config.ini')
config_file_path = os.path.join(os.path.dirname(__name__), 'config', 'test_config.ini')
luigi.configuration.LuigiConfigParser.add_config_path(config_file_path)
os.environ.setdefault('test_param', 'test')
with self.assertRaises(SystemExit) as exit_code:
Expand All @@ -29,15 +30,15 @@ def test_run(self):

@patch('sys.argv', new=['main', f'{__name__}._DummyTask', '--log-level=CRITICAL', '--local-scheduler'])
def test_run_with_undefined_environ(self):
config_file_path = os.path.join(os.path.dirname(__name__), 'test_config.ini')
config_file_path = os.path.join(os.path.dirname(__name__), 'config', 'test_config.ini')
luigi.configuration.LuigiConfigParser.add_config_path(config_file_path)
with self.assertRaises(luigi.parameter.MissingParameterException) as missing_parameter:
gokart.run()

@patch('sys.argv', new=['main', '--tree-info-mode=simple', '--tree-info-output-path=tree.txt', f'{__name__}._DummyTask', '--param', 'test', '--log-level=CRITICAL', '--local-scheduler'])
@patch('luigi.LocalTarget', new=lambda path, **kwargs: luigi.mock.MockTarget(path, **kwargs))
def test_run_tree_info(self):
config_file_path = os.path.join(os.path.dirname(__name__), 'test_config.ini')
config_file_path = os.path.join(os.path.dirname(__name__), 'config', 'test_config.ini')
luigi.configuration.LuigiConfigParser.add_config_path(config_file_path)
os.environ.setdefault('test_param', 'test')
tree_info = gokart.tree_info(mode='simple', output_path='tree.txt')
Expand Down
10 changes: 8 additions & 2 deletions test/test_task_on_kart.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,22 @@ class _DummyTask(gokart.TaskOnKart):
param = luigi.IntParameter(default=1)
list_param = luigi.ListParameter(default=['a', 'b'])
bool_param = luigi.BoolParameter()

def output(self):
return None


class _DummyTaskA(gokart.TaskOnKart):
    """Dummy task in this module's namespace that declares no output."""

    task_namespace = __name__

    def output(self):
        """Return None: this task has no output target."""
        return None


@inherits(_DummyTaskA)
class _DummyTaskB(gokart.TaskOnKart):
task_namespace = __name__

def output(self):
return None

Expand All @@ -40,15 +43,18 @@ def requires(self):
@inherits(_DummyTaskB)
class _DummyTaskC(gokart.TaskOnKart):
    """Dummy task that requires a `_DummyTaskB` cloned from itself and has no output."""

    task_namespace = __name__

    def output(self):
        """Return None: this task has no output target."""
        return None

    def requires(self):
        """Return the `_DummyTaskB` obtained by cloning this task."""
        upstream_task = self.clone(_DummyTaskB)
        return upstream_task


class _DummyTaskD(gokart.TaskOnKart):
    """Dummy task in this module's namespace with no overrides."""

    task_namespace = __name__


class TaskTest(unittest.TestCase):
def setUp(self):
_DummyTask.clear_instance_cache()
Expand Down Expand Up @@ -266,10 +272,10 @@ def test_use_rerun_with_inherits(self):
def test_significant_flag(self):
def _make_task(significant: bool, has_required_task: bool):
class _MyDummyTaskA(gokart.TaskOnKart):
task_namespace = __name__
task_namespace = f'{__name__}_{significant}_{has_required_task}'

class _MyDummyTaskB(gokart.TaskOnKart):
task_namespace = __name__
task_namespace = f'{__name__}_{significant}_{has_required_task}'

def requires(self):
if has_required_task:
Expand Down

0 comments on commit 14db6bc

Please sign in to comment.