From c1121b961d440f5ace52b1edd729323b5dfd24f9 Mon Sep 17 00:00:00 2001 From: lixu Date: Sun, 5 Feb 2023 09:35:25 +0800 Subject: [PATCH 1/7] Fix: QlibDataLoader drops the cols added by inst_processor --- qlib/data/dataset/loader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/qlib/data/dataset/loader.py b/qlib/data/dataset/loader.py index cc9ecd7c41..e7d9c9c143 100644 --- a/qlib/data/dataset/loader.py +++ b/qlib/data/dataset/loader.py @@ -211,7 +211,7 @@ def load_group_df( df = D.features( instruments, exprs, start_time, end_time, freq=freq, inst_processors=self.inst_processor.get(gp_name, []) ) - df.columns = names + df.rename(columns=dict(zip(exprs, names)), inplace=True) if self.swap_level: df = df.swaplevel().sort_index() # NOTE: if swaplevel, return return df From 762f75e5d4323eeebf5ba0e54e04f64d8ecc4ea9 Mon Sep 17 00:00:00 2001 From: lixu Date: Sun, 5 Feb 2023 10:19:38 +0800 Subject: [PATCH 2/7] Fix: QlibDataLoader drops the cols added by inst_processor --- qlib/data/data.py | 1 - 1 file changed, 1 deletion(-) diff --git a/qlib/data/data.py b/qlib/data/data.py index 73edf9f010..005849cf00 100644 --- a/qlib/data/data.py +++ b/qlib/data/data.py @@ -587,7 +587,6 @@ def dataset_processor(instruments_d, column_names, start_time, end_time, freq, i if len(new_data) > 0: data = pd.concat(new_data, names=["instrument"], sort=False) - data = DiskDatasetCache.cache_to_origin_data(data, column_names) else: data = pd.DataFrame( index=pd.MultiIndex.from_arrays([[], []], names=("instrument", "datetime")), From 8d2a5dd231a2d55c71f4dfa7cff2a660ca93c91b Mon Sep 17 00:00:00 2001 From: Xu Li Date: Mon, 6 Feb 2023 21:04:33 +0800 Subject: [PATCH 3/7] Fix: normalize field names before replacing the column names --- qlib/data/dataset/loader.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/qlib/data/dataset/loader.py b/qlib/data/dataset/loader.py index e7d9c9c143..ff6abeec83 100644 --- a/qlib/data/dataset/loader.py +++ b/qlib/data/dataset/loader.py @@ -10,7 +10,7 @@ from typing import Tuple, Union, List from qlib.data import D -from qlib.utils import load_dataset, init_instance_by_config, time_to_slc_point +from qlib.utils import load_dataset, init_instance_by_config, time_to_slc_point, normalize_cache_fields from qlib.log import get_module_logger from qlib.utils.serial import Serializable @@ -211,7 +211,7 @@ def load_group_df( df = D.features( instruments, exprs, start_time, end_time, freq=freq, inst_processors=self.inst_processor.get(gp_name, []) ) - df.rename(columns=dict(zip(exprs, names)), inplace=True) + df.rename(columns=dict(zip(normalize_cache_fields(exprs), names)), inplace=True) if self.swap_level: df = df.swaplevel().sort_index() # NOTE: if swaplevel, return return df From 32b498afd4d7618ee1ab0eac75be8a2deeae9a04 Mon Sep 17 00:00:00 2001 From: Xu Li Date: Sun, 7 May 2023 11:55:47 +0800 Subject: [PATCH 4/7] Only use the new logic when InstProcessor presents --- qlib/data/data.py | 4 +++ qlib/data/dataset/loader.py | 50 +++++++++++++++++++++---------------- 2 files changed, 32 insertions(+), 22 deletions(-) diff --git a/qlib/data/data.py b/qlib/data/data.py index 11c7e5b992..1a6f3ee6ed 100644 --- a/qlib/data/data.py +++ b/qlib/data/data.py @@ -587,6 +587,10 @@ def dataset_processor(instruments_d, column_names, start_time, end_time, freq, i if len(new_data) > 0: data = pd.concat(new_data, names=["instrument"], sort=False) + + # NOTE: InstProcessors may add new columns and using cache_to_origin_data will remove those added columns. + if not len(inst_processors): + data = DiskDatasetCache.cache_to_origin_data(data, column_names) else: data = pd.DataFrame( index=pd.MultiIndex.from_arrays([[], []], names=("instrument", "datetime")), diff --git a/qlib/data/dataset/loader.py b/qlib/data/dataset/loader.py index cb12cee602..7afabb83f9 100644 --- a/qlib/data/dataset/loader.py +++ b/qlib/data/dataset/loader.py @@ -10,7 +10,8 @@ from typing import Tuple, Union, List from qlib.data import D -from qlib.utils import load_dataset, init_instance_by_config, time_to_slc_point, normalize_cache_fields +from qlib.utils import load_dataset, init_instance_by_config, time_to_slc_point, normalize_cache_fields, \ + remove_fields_space, remove_repeat_field from qlib.log import get_module_logger from qlib.utils.serial import Serializable @@ -103,13 +104,13 @@ def _parse_fields_info(self, fields_info: Union[list, tuple]) -> Tuple[list, lis @abc.abstractmethod def load_group_df( - self, - instruments, - exprs: list, - names: list, - start_time: Union[str, pd.Timestamp] = None, - end_time: Union[str, pd.Timestamp] = None, - gp_name: str = None, + self, + instruments, + exprs: list, + names: list, + start_time: Union[str, pd.Timestamp] = None, + end_time: Union[str, pd.Timestamp] = None, + gp_name: str = None, ) -> pd.DataFrame: """ load the dataframe for specific group @@ -148,12 +149,12 @@ class QlibDataLoader(DLWParser): """Same as QlibDataLoader. The fields can be define by config""" def __init__( - self, - config: Tuple[list, tuple, dict], - filter_pipe: List = None, - swap_level: bool = True, - freq: Union[str, dict] = "day", - inst_processors: Union[dict, list] = None, + self, + config: Tuple[list, tuple, dict], + filter_pipe: List = None, + swap_level: bool = True, + freq: Union[str, dict] = "day", + inst_processors: Union[dict, list] = None, ): """ Parameters @@ -194,13 +195,13 @@ def __init__( ), f"freq(={self.freq}), inst_processors(={self.inst_processors}) cannot be None/empty" def load_group_df( - self, - instruments, - exprs: list, - names: list, - start_time: Union[str, pd.Timestamp] = None, - end_time: Union[str, pd.Timestamp] = None, - gp_name: str = None, + self, + instruments, + exprs: list, + names: list, + start_time: Union[str, pd.Timestamp] = None, + end_time: Union[str, pd.Timestamp] = None, + gp_name: str = None, ) -> pd.DataFrame: if instruments is None: warnings.warn("`instruments` is not set, will load all stocks") @@ -215,7 +216,12 @@ def load_group_df( self.inst_processors if isinstance(self.inst_processors, list) else self.inst_processors.get(gp_name, []) ) df = D.features(instruments, exprs, start_time, end_time, freq=freq, inst_processors=inst_processors) - df.rename(columns=dict(zip(normalize_cache_fields(exprs), names)), inplace=True) + # NOTE: InstProcessors may add new columns + if len(inst_processors): + df.rename(columns=dict(zip(remove_repeat_field(remove_fields_space(exprs)), names)), inplace=True) + else: + df.columns = names + if self.swap_level: df = df.swaplevel().sort_index() # NOTE: if swaplevel, return return df From 643859cdad93cdd4563c4ca45c5cc0d07cbed8c8 Mon Sep 17 00:00:00 2001 From: Cadenza-Li <362237642@qq.com> Date: Fri, 25 Aug 2023 11:51:17 +0800 Subject: [PATCH 5/7] activate CI --- qlib/data/dataset/loader.py | 1 + 1 file changed, 1 insertion(+) diff --git a/qlib/data/dataset/loader.py b/qlib/data/dataset/loader.py index 7afabb83f9..9f0e4752fd 100644 --- a/qlib/data/dataset/loader.py +++ b/qlib/data/dataset/loader.py @@ -216,6 +216,7 @@ def load_group_df( self.inst_processors if isinstance(self.inst_processors, list) else self.inst_processors.get(gp_name, []) ) df = D.features(instruments, exprs, start_time, end_time, freq=freq, inst_processors=inst_processors) + # NOTE: InstProcessors may add new columns if len(inst_processors): df.rename(columns=dict(zip(remove_repeat_field(remove_fields_space(exprs)), names)), inplace=True) From d302c0637896c00a3f92df7687f9116e5fbaa7ce Mon Sep 17 00:00:00 2001 From: Cadenza-Li <128388363+Fivele-Li@users.noreply.github.com> Date: Fri, 25 Aug 2023 12:55:44 +0800 Subject: [PATCH 6/7] fix black; --- qlib/data/dataset/loader.py | 50 ++++++++++++++++++++----------------- 1 file changed, 27 insertions(+), 23 deletions(-) diff --git a/qlib/data/dataset/loader.py b/qlib/data/dataset/loader.py index 9f0e4752fd..6aa7edc524 100644 --- a/qlib/data/dataset/loader.py +++ b/qlib/data/dataset/loader.py @@ -10,8 +10,13 @@ from typing import Tuple, Union, List from qlib.data import D -from qlib.utils import load_dataset, init_instance_by_config, time_to_slc_point, normalize_cache_fields, \ - remove_fields_space, remove_repeat_field +from qlib.utils import ( + load_dataset, + init_instance_by_config, + time_to_slc_point, + remove_fields_space, + remove_repeat_field +) from qlib.log import get_module_logger from qlib.utils.serial import Serializable @@ -104,13 +109,13 @@ def _parse_fields_info(self, fields_info: Union[list, tuple]) -> Tuple[list, lis @abc.abstractmethod def load_group_df( - self, - instruments, - exprs: list, - names: list, - start_time: Union[str, pd.Timestamp] = None, - end_time: Union[str, pd.Timestamp] = None, - gp_name: str = None, + self, + instruments, + exprs: list, + names: list, + start_time: Union[str, pd.Timestamp] = None, + end_time: Union[str, pd.Timestamp] = None, + gp_name: str = None, ) -> pd.DataFrame: """ load the dataframe for specific group @@ -149,12 +154,12 @@ class QlibDataLoader(DLWParser): """Same as QlibDataLoader. The fields can be define by config""" def __init__( - self, - config: Tuple[list, tuple, dict], - filter_pipe: List = None, - swap_level: bool = True, - freq: Union[str, dict] = "day", - inst_processors: Union[dict, list] = None, + self, + config: Tuple[list, tuple, dict], + filter_pipe: List = None, + swap_level: bool = True, + freq: Union[str, dict] = "day", + inst_processors: Union[dict, list] = None, ): """ Parameters @@ -195,13 +200,13 @@ def __init__( ), f"freq(={self.freq}), inst_processors(={self.inst_processors}) cannot be None/empty" def load_group_df( - self, - instruments, - exprs: list, - names: list, - start_time: Union[str, pd.Timestamp] = None, - end_time: Union[str, pd.Timestamp] = None, - gp_name: str = None, + self, + instruments, + exprs: list, + names: list, + start_time: Union[str, pd.Timestamp] = None, + end_time: Union[str, pd.Timestamp] = None, + gp_name: str = None, ) -> pd.DataFrame: if instruments is None: warnings.warn("`instruments` is not set, will load all stocks") @@ -216,7 +221,6 @@ def load_group_df( self.inst_processors if isinstance(self.inst_processors, list) else self.inst_processors.get(gp_name, []) ) df = D.features(instruments, exprs, start_time, end_time, freq=freq, inst_processors=inst_processors) - # NOTE: InstProcessors may add new columns if len(inst_processors): df.rename(columns=dict(zip(remove_repeat_field(remove_fields_space(exprs)), names)), inplace=True) From f6ed759a6107bbd313301d3d3d6106912d7a2276 Mon Sep 17 00:00:00 2001 From: Cadenza-Li <128388363+Fivele-Li@users.noreply.github.com> Date: Fri, 25 Aug 2023 13:09:09 +0800 Subject: [PATCH 7/7] fix black; --- qlib/data/dataset/loader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/qlib/data/dataset/loader.py b/qlib/data/dataset/loader.py index 6aa7edc524..3d0fd95465 100644 --- a/qlib/data/dataset/loader.py +++ b/qlib/data/dataset/loader.py @@ -15,7 +15,7 @@ init_instance_by_config, time_to_slc_point, remove_fields_space, - remove_repeat_field + remove_repeat_field, ) from qlib.log import get_module_logger from qlib.utils.serial import Serializable