-
Notifications
You must be signed in to change notification settings - Fork 61
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #119 from m3dev/pandas-type-check-framework
check pandas column type in dump
- Loading branch information
Showing
11 changed files
with
241 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
from typing import Dict, Any | ||
|
||
import gokart | ||
import pandas as pd | ||
|
||
|
||
# Please define a class which inherits `gokart.PandasTypeConfig`. | ||
# **In practice, please import `SamplePandasTypeConfig` in `__init__`.** | ||
class SamplePandasTypeConfig(gokart.PandasTypeConfig): | ||
task_namespace = 'sample_pandas_type_check' | ||
|
||
@classmethod | ||
def type_dict(cls) -> Dict[str, Any]: | ||
return {'int_column': int} | ||
|
||
|
||
class SampleTask(gokart.TaskOnKart): | ||
# Please set the same `task_namespace` as `SamplePandasTypeConfig`. | ||
task_namespace = 'sample_pandas_type_check' | ||
|
||
def run(self): | ||
df = pd.DataFrame(dict(int_column=['a'])) | ||
self.dump(df) # This line causes PandasTypeError, because expected type is `int`, but `str` is passed. | ||
|
||
|
||
if __name__ == '__main__': | ||
gokart.run([ | ||
'sample_pandas_type_check.SampleTask', | ||
'--local-scheduler', | ||
'--rerun']) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
from abc import abstractmethod | ||
from logging import getLogger | ||
from typing import Dict, Any | ||
|
||
import luigi | ||
import numpy as np | ||
import pandas as pd | ||
from luigi.task_register import Register | ||
|
||
logger = getLogger(__name__) | ||
|
||
|
||
class PandasTypeError(Exception): | ||
pass | ||
|
||
|
||
class PandasTypeConfig(luigi.Config): | ||
@classmethod | ||
@abstractmethod | ||
def type_dict(cls) -> Dict[str, Any]: | ||
pass | ||
|
||
@classmethod | ||
def check(cls, df: pd.DataFrame): | ||
for column_name, column_type in cls.type_dict().items(): | ||
cls._check_column(df, column_name, column_type) | ||
|
||
@classmethod | ||
def _check_column(cls, df, column_name, column_type): | ||
if column_name not in df.columns: | ||
return | ||
|
||
if not np.all(list(map(lambda x: isinstance(x, column_type), df[column_name]))): | ||
not_match = next(filter(lambda x: not isinstance(x, column_type), df[column_name])) | ||
raise PandasTypeError(f'expected type is "{column_type}", but "{type(not_match)}" is passed in column "{column_name}".') | ||
|
||
|
||
class PandasTypeConfigMap(luigi.Config): | ||
"""To initialize this class only once, this inherits luigi.Config.""" | ||
|
||
def __init__(self, *args, **kwargs) -> None: | ||
super(PandasTypeConfigMap, self).__init__(*args, **kwargs) | ||
task_names = Register.task_names() | ||
task_classes = [Register.get_task_cls(task_name) for task_name in task_names] | ||
self._map = {task_class.task_namespace: task_class for task_class in task_classes if | ||
issubclass(task_class, PandasTypeConfig) and task_class != PandasTypeConfig} | ||
|
||
def check(self, obj, task_namespace: str): | ||
if type(obj) == pd.DataFrame and task_namespace in self._map: | ||
self._map[task_namespace].check(obj) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
import unittest | ||
from logging import getLogger | ||
from typing import Dict, Any | ||
|
||
import pandas as pd | ||
from luigi.mock import MockTarget, MockFileSystem | ||
from mock import patch | ||
|
||
import gokart | ||
from gokart.pandas_type_config import PandasTypeConfig | ||
|
||
logger = getLogger(__name__) | ||
|
||
|
||
class TestPandasTypeConfig(PandasTypeConfig): | ||
task_namespace = 'test_pandas_type_check_framework' | ||
|
||
@classmethod | ||
def type_dict(cls) -> Dict[str, Any]: | ||
return {'system_cd': int} | ||
|
||
|
||
class _DummyFailTask(gokart.TaskOnKart): | ||
task_namespace = 'test_pandas_type_check_framework' | ||
rerun = True | ||
|
||
def output(self): | ||
return self.make_target('dummy.pkl') | ||
|
||
def run(self): | ||
df = pd.DataFrame(dict(system_cd=['1'])) | ||
self.dump(df) | ||
|
||
|
||
class _DummyFailWithNoneTask(gokart.TaskOnKart): | ||
task_namespace = 'test_pandas_type_check_framework' | ||
rerun = True | ||
|
||
def output(self): | ||
return self.make_target('dummy.pkl') | ||
|
||
def run(self): | ||
df = pd.DataFrame(dict(system_cd=[1, None])) | ||
self.dump(df) | ||
|
||
|
||
class _DummySuccessTask(gokart.TaskOnKart): | ||
task_namespace = 'test_pandas_type_check_framework' | ||
rerun = True | ||
|
||
def output(self): | ||
return self.make_target('dummy.pkl') | ||
|
||
def run(self): | ||
df = pd.DataFrame(dict(system_cd=[1])) | ||
self.dump(df) | ||
|
||
|
||
class TestPandasTypeCheckFramework(unittest.TestCase): | ||
def setUp(self) -> None: | ||
MockFileSystem().clear() | ||
|
||
@patch('sys.argv', new=['main', 'test_pandas_type_check_framework._DummyFailTask', '--log-level=CRITICAL', '--local-scheduler', '--no-lock']) | ||
@patch('luigi.LocalTarget', new=lambda path, **kwargs: MockTarget(path, **kwargs)) | ||
def test_fail(self): | ||
with self.assertRaises(SystemExit) as exit_code: | ||
gokart.run() | ||
self.assertNotEqual(exit_code.exception.code, 0) # raise Error | ||
|
||
@patch('sys.argv', new=['main', 'test_pandas_type_check_framework._DummyFailWithNoneTask', '--log-level=CRITICAL', '--local-scheduler', '--no-lock']) | ||
@patch('luigi.LocalTarget', new=lambda path, **kwargs: MockTarget(path, **kwargs)) | ||
def test_fail_with_None(self): | ||
with self.assertRaises(SystemExit) as exit_code: | ||
gokart.run() | ||
self.assertNotEqual(exit_code.exception.code, 0) # raise Error | ||
|
||
@patch('sys.argv', new=['main', 'test_pandas_type_check_framework._DummySuccessTask', '--log-level=CRITICAL', '--local-scheduler', '--no-lock']) | ||
@patch('luigi.LocalTarget', new=lambda path, **kwargs: MockTarget(path, **kwargs)) | ||
def test_success(self): | ||
with self.assertRaises(SystemExit) as exit_code: | ||
gokart.run() | ||
self.assertEqual(exit_code.exception.code, 0) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
from datetime import datetime, date | ||
from typing import Dict, Any | ||
from unittest import TestCase | ||
|
||
import numpy as np | ||
import pandas as pd | ||
|
||
from gokart import PandasTypeConfig | ||
from gokart.pandas_type_config import PandasTypeError | ||
|
||
|
||
class _DummyPandasTypeConfig(PandasTypeConfig): | ||
|
||
@classmethod | ||
def type_dict(cls) -> Dict[str, Any]: | ||
return {'int_column': int, 'datetime_column': datetime, 'array_column': np.ndarray} | ||
|
||
|
||
class TestPandasTypeConfig(TestCase): | ||
def test_int_fail(self): | ||
df = pd.DataFrame(dict(int_column=['1'])) | ||
with self.assertRaises(PandasTypeError): | ||
_DummyPandasTypeConfig().check(df) | ||
|
||
def test_int_success(self): | ||
df = pd.DataFrame(dict(int_column=[1])) | ||
_DummyPandasTypeConfig().check(df) | ||
|
||
def test_datetime_fail(self): | ||
df = pd.DataFrame(dict(datetime_column=[date(2019, 1, 12)])) | ||
with self.assertRaises(PandasTypeError): | ||
_DummyPandasTypeConfig().check(df) | ||
|
||
def test_datetime_success(self): | ||
df = pd.DataFrame(dict(datetime_column=[datetime(2019, 1, 12, 0, 0, 0)])) | ||
_DummyPandasTypeConfig().check(df) | ||
|
||
def test_array_fail(self): | ||
df = pd.DataFrame(dict(array_column=[[1, 2]])) | ||
with self.assertRaises(PandasTypeError): | ||
_DummyPandasTypeConfig().check(df) | ||
|
||
def test_array_success(self): | ||
df = pd.DataFrame(dict(array_column=[np.array([1, 2])])) | ||
_DummyPandasTypeConfig().check(df) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters