Skip to content
Permalink

Comparing changes

Choose two branches to see what’s changed or to start a new pull request. If you need to, you can also compare across forks or learn more about diff comparisons.

Open a pull request

Create a new pull request by comparing changes across two branches. If you need to, you can also compare across forks. Learn more about diff comparisons here.
base repository: wukan1986/expr_codegen
Failed to load repositories. Confirm that selected base ref is valid, then try again.
Loading
base: v0.10.9
Choose a base ref
...
head repository: wukan1986/expr_codegen
Failed to load repositories. Confirm that selected head ref is valid, then try again.
Loading
compare: main
Choose a head ref
  • 6 commits
  • 8 files changed
  • 1 contributor

Commits on Jan 9, 2025

  1. 分钟数据示例

    wukan1986 committed Jan 9, 2025
    Copy the full SHA
    cdcff0f View commit details

Commits on Jan 11, 2025

  1. Copy the full SHA
    43ccc91 View commit details

Commits on Jan 12, 2025

  1. 添加日志

    wukan1986 committed Jan 12, 2025
    Copy the full SHA
    53c2166 View commit details

Commits on Feb 17, 2025

  1. 修复模板中隐含不足

    wukan1986 committed Feb 17, 2025
    Copy the full SHA
    44ceb9b View commit details

Commits on Feb 18, 2025

  1. Copy the full SHA
    3e57ba0 View commit details

Commits on Mar 20, 2025

  1. 修复cse失败的bug

    wukan1986 committed Mar 20, 2025
    Copy the full SHA
    38e24d4 View commit details
22 changes: 18 additions & 4 deletions examples/demo_min.py
Original file line number Diff line number Diff line change
@@ -22,7 +22,7 @@
np.random.seed(42)

ASSET_COUNT = 500
DATE_COUNT = 250 * 24 * 60 * 1
DATE_COUNT = 250 * 24 * 10 * 1
DATE = pd.date_range(datetime(2020, 1, 1), periods=DATE_COUNT, freq='1min').repeat(ASSET_COUNT)
ASSET = [f'A{i:04d}' for i in range(ASSET_COUNT)] * DATE_COUNT

@@ -66,8 +66,23 @@
# !!!使用时一定要分清分组是用哪个字段
date='datetime', asset='_asset_date')
# ---
logger.info('1分钟转15分钟线开始')
df1 = df.sort('asset', 'datetime').group_by_dynamic('datetime', every="15m", closed='left', label="left", group_by=['asset', 'trading_day']).agg(
open_dt=pl.first("datetime"),
close_dt=pl.last("datetime"),
OPEN=pl.first("OPEN"),
HIGH=pl.max("HIGH"),
LOW=pl.min("LOW"),
CLOSE=pl.last("CLOSE"),
VOLUME=pl.sum("VOLUME"),
OPEN_INTEREST=pl.last("OPEN_INTEREST"),
)
logger.info('1分钟转15分钟线结束')
print(df1)
# ---
logger.info('1分钟转日线开始')
df = df.sort('asset', 'datetime').group_by('asset', 'trading_day', maintain_order=True).agg(
# 也可以使用group_by_dynamic,只是日线隐含了label="left"
df1 = df.sort('asset', 'datetime').group_by('asset', 'trading_day', maintain_order=True).agg(
open_dt=pl.first("datetime"),
close_dt=pl.last("datetime"),
OPEN=pl.first("OPEN"),
@@ -78,5 +93,4 @@
OPEN_INTEREST=pl.last("OPEN_INTEREST"),
)
logger.info('1分钟转日线结束')
print(df)
# df.write_csv('output.csv')
print(df1)
10 changes: 7 additions & 3 deletions examples/demo_tdx.py
Original file line number Diff line number Diff line change
@@ -100,11 +100,15 @@ def _code_block_2():
# =====================================
logger.info('计算开始')
t1 = time.perf_counter()
df = codegen_exec(df.lazy(), _code_block_1, _code_block_2, output_file=sys.stdout)
df = codegen_exec(df, _code_block_1, _code_block_2, output_file='1_out.py', run_file=False, over_null=None)
t2 = time.perf_counter()
print(t2 - t1)
df = codegen_exec(df, _code_block_1, _code_block_2, output_file='1_out.py', run_file=True, over_null=None)
t3 = time.perf_counter()
df = codegen_exec(df, _code_block_1, _code_block_2, output_file='1_out.py', run_file=True, over_null=None)
t4 = time.perf_counter()
print(t2 - t1, t3 - t2, t4 - t3)
logger.info('计算结束')
df = df.filter(
~pl.col('is_st'),
)
print(df.collect())
print(df)
2 changes: 1 addition & 1 deletion expr_codegen/_version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.10.9"
__version__ = "0.10.14"
14 changes: 14 additions & 0 deletions expr_codegen/model.py
Original file line number Diff line number Diff line change
@@ -137,13 +137,27 @@ def chain_create(nested_list):
last_min = float('inf')
# 最小不重复的一行记录
last_row = None
last_rows = set()
for row in product(*neighbor_inter):
# 判断两两是否重复,重复为1,反之为0
result = sum([x == y for x, y in zip(row[:-1], row[1:])])
if last_min > result:
last_min = result
last_row = row
if result == 0:
last_rows.add(last_row)
last_min = float('inf')
continue
last_rows.add(last_row)
last_rows = list(last_rows)

# last_rows中有多个满足条件的,优先保证最后一组ts在最前,ts后可提前filter减少计算量
last_row = last_rows[0]
for row in last_rows:
if row[-1] is None:
continue
if row[-1][0] == 'ts':
last_row = row
break

# 如何移动才是难点 如果两个连续 ts/ts,那么如何移动
7 changes: 4 additions & 3 deletions expr_codegen/pandas/template.py.j2
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# this code is auto generated by the expr_codegen
# https://github.com/wukan1986/expr_codegen
# 此段代码由 expr_codegen 自动生成,欢迎提交 issue 或 pull request
from typing import Tuple

import numpy as np # noqa
import pandas as pd # noqa
@@ -63,6 +64,6 @@ def main(df: pd.DataFrame) -> pd.DataFrame:

return df

if __name__ in ("__main__", "builtins"):
# TODO: 数据加载或外部传入
df_output = main(df_input)
# if __name__ in ("__main__", "builtins"):
# # TODO: 数据加载或外部传入
# df_output = main(df_input)
8 changes: 4 additions & 4 deletions expr_codegen/polars_group/template.py.j2
Original file line number Diff line number Diff line change
@@ -31,7 +31,7 @@ _NONE_ = None
_TRUE_ = True
_FALSE_ = False

def unpack(x: Expr, idx: int = 0) -> Expr:
def unpack(x: pl.Expr, idx: int = 0) -> pl.Expr:
return x.struct[idx]

{%-for row in extra_codes %}
@@ -78,6 +78,6 @@ def main(df: DataFrame) -> DataFrame:

return df

if __name__ in ("__main__", "builtins"):
# TODO: 数据加载或外部传入
df_output = main(df_input)
# if __name__ in ("__main__", "builtins"):
# # TODO: 数据加载或外部传入
# df_output = main(df_input)
8 changes: 4 additions & 4 deletions expr_codegen/polars_over/template.py.j2
Original file line number Diff line number Diff line change
@@ -31,7 +31,7 @@ _NONE_ = None
_TRUE_ = True
_FALSE_ = False

def unpack(x: Expr, idx: int = 0) -> Expr:
def unpack(x: pl.Expr, idx: int = 0) -> pl.Expr:
return x.struct[idx]

{%-for row in extra_codes %}
@@ -78,6 +78,6 @@ def main(df: DataFrame) -> DataFrame:

return df

if __name__ in ("__main__", "builtins"):
# TODO: 数据加载或外部传入
df_output = main(df_input)
# if __name__ in ("__main__", "builtins"):
# # TODO: 数据加载或外部传入
# df_output = main(df_input)
61 changes: 39 additions & 22 deletions expr_codegen/tool.py
Original file line number Diff line number Diff line change
@@ -82,9 +82,6 @@ def extract(self, expr, date, asset):
表达式列表
"""
# 抽取前先化简
expr = simplify2(expr)

exprs = []
syms = []
get_children(self.get_current_func, self.get_current_func_kwargs,
@@ -110,6 +107,9 @@ def merge(self, date, asset, **kwargs):
-------
表达式列表
"""
# 抽取前先化简
kwargs = {k: simplify2(v) for k, v in kwargs.items()}

exprs_syms = [self.extract(v, date, asset) for v in kwargs.values()]
exprs = []
syms = []
@@ -120,6 +120,7 @@ def merge(self, date, asset, **kwargs):
syms = sorted(set(syms), key=syms.index)
# 如果目标有重复表达式,这里会混乱
exprs = sorted(set(exprs), key=exprs.index)
# 这里不能合并简化与未简化的表达式,会导致cse时失败,需要简化表达式合并
exprs = exprs + list(kwargs.values())

# print(exprs)
@@ -257,6 +258,8 @@ def all(self, exprs_src, style: Literal['pandas', 'polars_group', 'polars_over']
extra_codes=extra_codes,
**kwargs)

logger.info(f'code is generated')

if format:
# 格式化。在遗传算法中没有必要
codes = format_str(codes, mode=Mode(line_length=600, magic_trailing_comma=True))
@@ -299,25 +302,30 @@ def _get_code(self,
return code


def _exec_code(code: str, df_input):
globals_ = {'df_input': df_input}
@lru_cache(maxsize=64, typed=True)
def _get_func_from_code(code: str):
logger.info(f'get func from code')
globals_ = {}
exec(code, globals_)
return globals_['df_output']

return globals_['main']

def _exec_file(file, df_input):
file = pathlib.Path(file)
logger.info(f'run file "{file.absolute()}"')
with open(file, 'r', encoding='utf-8') as f:
code = f.read()
return _exec_code(code, df_input)


def _exec_module(module: str, df_input):
@lru_cache(maxsize=64, typed=True)
def _get_func_from_module(module: str):
""""可下断点调试"""
m = __import__(module, fromlist=['*'])
logger.info(f'run module {m}')
return m.main(df_input)
logger.info(f'get func from module {m}')
return m.main


@lru_cache(maxsize=64, typed=True)
def _get_func_from_file(file: str):
file = pathlib.Path(file)
logger.info(f'get func from file "{file.absolute()}"')
with open(file, 'r', encoding='utf-8') as f:
globals_ = {}
exec(f.read(), globals_)
return globals_['main']


_TOOL_ = ExprTool()
@@ -347,7 +355,7 @@ def codegen_exec(df: Optional[DataFrame],
output_file: str| TextIOBase
保存生成的目标代码到文件中
run_file: bool or str
是否不生成脚本,直接运行代码。
是否不生成脚本,直接运行代码。注意:带缓存功能,多次调用不重复生成
- 如果是True,会自动从output_file中读取代码
- 如果是字符串,会自动从run_file中读取代码
- 如果是模块名,会自动从模块中读取代码(可调试)
@@ -374,17 +382,25 @@ def codegen_exec(df: Optional[DataFrame],
-------
DataFrame
Notes
-----
处处都有缓存,所以在公式研发阶段要留意日志输出。以免一直调试的旧公式
1. 确保重新生成了代码 `code is generated`
2. 通过代码得到了函数 `get func from code`
"""
if df is not None:
# 以下代码都带缓存功能
if run_file is True:
assert output_file is not None, 'output_file is required'
return _exec_file(output_file, df)
return _get_func_from_file(output_file)(df)
if run_file is not False:
run_file = str(run_file)
if run_file.endswith('.py'):
return _exec_file(run_file, df)
return _get_func_from_file(run_file)(df)
else:
return _exec_module(run_file, df) # 可断点调试
return _get_func_from_module(run_file)(df) # 可断点调试

# 此代码来自于sympy.var
frame = inspect.currentframe().f_back
@@ -407,4 +423,5 @@ def codegen_exec(df: Optional[DataFrame],
if df is None:
return None
else:
return _exec_code(code, df)
# 代码一样时就从缓存中取出函数
return _get_func_from_code(code)(df)