diff --git a/qlib/data/data.py b/qlib/data/data.py index 809b8d1c32..1a6f3ee6ed 100644 --- a/qlib/data/data.py +++ b/qlib/data/data.py @@ -587,7 +587,10 @@ def dataset_processor(instruments_d, column_names, start_time, end_time, freq, i if len(new_data) > 0: data = pd.concat(new_data, names=["instrument"], sort=False) - data = DiskDatasetCache.cache_to_origin_data(data, column_names) + + # NOTE: InstProcessors may add new columns and using cache_to_origin_data will remove those added columns. + if not len(inst_processors): + data = DiskDatasetCache.cache_to_origin_data(data, column_names) else: data = pd.DataFrame( index=pd.MultiIndex.from_arrays([[], []], names=("instrument", "datetime")), diff --git a/qlib/data/dataset/loader.py b/qlib/data/dataset/loader.py index e9d6f98866..3d0fd95465 100644 --- a/qlib/data/dataset/loader.py +++ b/qlib/data/dataset/loader.py @@ -10,7 +10,13 @@ from typing import Tuple, Union, List from qlib.data import D -from qlib.utils import load_dataset, init_instance_by_config, time_to_slc_point +from qlib.utils import ( + load_dataset, + init_instance_by_config, + time_to_slc_point, + remove_fields_space, + remove_repeat_field, +) from qlib.log import get_module_logger from qlib.utils.serial import Serializable @@ -215,7 +221,12 @@ def load_group_df( self.inst_processors if isinstance(self.inst_processors, list) else self.inst_processors.get(gp_name, []) ) df = D.features(instruments, exprs, start_time, end_time, freq=freq, inst_processors=inst_processors) - df.columns = names + # NOTE: InstProcessors may add new columns + if len(inst_processors): + df.rename(columns=dict(zip(remove_repeat_field(remove_fields_space(exprs)), names)), inplace=True) + else: + df.columns = names + if self.swap_level: df = df.swaplevel().sort_index() # NOTE: if swaplevel, return return df