Skip to content

Commit a36502a

Browse files
committedNov 17, 2024
支持赋值时元组解包
1 parent 7d24549 commit a36502a

File tree

10 files changed

+93
-12
lines changed

10 files changed

+93
-12
lines changed
 

‎README.md

+8
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,14 @@ df = codegen_exec(df, _code_block_1, _code_block_2) # 只执行,不保存代
157157
7. 支持`A[0]+B[1]+C[2]`,底层会转成`A+ts_delay(B,1)+ts_delay(C,2)`
158158
8. 支持`~A`,底层会转换成`Not(A)`
159159
9. `gp_`开头的函数都会返回对应的`cs_`函数。如`gp_func(A,B,C)`会替换成`cs_func(B,C)`,其中`A`用在了`groupby([date, A])`
160+
10. 支持`A,B,C=MACD()`元组解包,在底层会替换成
161+
162+
```python
163+
_x_0 = MACD()
164+
A = unpack(_x_0, 0)
165+
B = unpack(_x_0, 1)
166+
C = unpack(_x_0, 2)
167+
```
160168

161169
## 下划线开头的变量
162170

‎expr_codegen/_version.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.9.2"
1+
__version__ = "0.10.0"

‎expr_codegen/codes.py

+17
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,23 @@ def __init__(self, convert_xor):
1414
# ^ 是异或还是乘方呢?
1515
self.convert_xor = convert_xor
1616

17+
def visit_Assign(self, node):
18+
t = node.targets[0]
19+
nodes = []
20+
if isinstance(t, ast.Tuple):
21+
for i, dim in enumerate(t.dims):
22+
_v = ast.Call(
23+
func=ast.Name(id='unpack', ctx=ast.Load()),
24+
args=[node.value, ast.Constant(i)],
25+
keywords=[],
26+
)
27+
n = ast.Assign([dim], _v, ctx=ast.Load())
28+
nodes.append(n)
29+
return nodes
30+
31+
self.generic_visit(node)
32+
return node
33+
1734
def visit_Compare(self, node):
1835
assert len(node.comparators) == 1, f"不支持连续等号,请手工添加括号, {ast.unparse(node)}"
1936

‎expr_codegen/pandas/code.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -67,19 +67,19 @@ def codegen(exprs_ldl: ListDictList, exprs_src, syms_dst,
6767
exprs_dst.append(f"#" + '=' * 40 + func_name)
6868
else:
6969
va, ex, sym = kv
70-
func_code.append(f" # {va} = {ex}\n df[{va}] = {p.doprint(ex)}")
70+
func_code.append(f" # {va} = {ex}\n g[{va}] = {p.doprint(ex)}")
7171
exprs_dst.append(f"{va} = {ex}")
7272
if va not in syms_dst:
7373
syms_out.append(va)
7474

75+
if len(groupbys['sort']) == 0:
76+
groupbys['sort'] = f'df = df.sort_values(by=[_ASSET_, _DATE_]).reset_index(drop=True)'
7577
if k[0] == TS:
76-
if len(groupbys['sort']) == 0:
77-
groupbys['sort'] = f'df = df.sort_values(by=[_ASSET_, _DATE_]).reset_index(drop=True)'
7878
# 时序需要排序
79-
func_code = [f' df = df.sort_values(by=[_DATE_])'] + func_code
80-
elif k[0] == CS:
81-
if len(groupbys['sort']) == 0:
82-
groupbys['sort'] = f'df = df.sort_values(by=[_DATE_, _ASSET_]).reset_index(drop=True)'
79+
func_code = [f' g.df = df.sort_values(by=[_DATE_])'] + func_code
80+
else:
81+
# 时序需要排序
82+
func_code = [f' g.df = df'] + func_code
8383

8484
# polars风格代码列表
8585
funcs[func_name] = '\n'.join(func_code)

‎expr_codegen/pandas/helper.py

+35
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
"""
2+
A、B、C=MACD()无法生成DAG,所以变通的改成
3+
4+
A=unpack(MACD(),0)
5+
B=unpack(MACD(),1)
6+
C=unpack(MACD(),2)
7+
8+
cse能自动提取成
9+
10+
_x_0 = MACD()
11+
12+
但 df['_x_0'] 是无法放入tuple的,所以决定用另一个类来实现兼容
13+
14+
"""
15+
import pandas as pd
16+
17+
18+
class GlobalVariable(object):
19+
def __init__(self):
20+
self.dict = {}
21+
self.df = pd.DataFrame()
22+
23+
def __getitem__(self, item):
24+
if item in self.dict:
25+
return self.dict[item]
26+
return self.df[item]
27+
28+
def __setitem__(self, key, value):
29+
if isinstance(value, tuple):
30+
# tuple存字典中
31+
self.dict[key] = value
32+
# 占位,避免drop时报错
33+
self.df[key] = False
34+
else:
35+
self.df[key] = value

‎expr_codegen/pandas/printer.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def _print(self, expr, **kwargs) -> str:
5454
self._print_level -= 1
5555

5656
def _print_Symbol(self, expr):
57-
return f"df[{expr.name}]"
57+
return f"g[{expr.name}]"
5858

5959
def _print_Equality(self, expr):
6060
PREC = precedence(expr)

‎expr_codegen/pandas/ta.py

+11
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,16 @@
77
88
所以有必要使用类似于polars_ta的公共库,但因目前未找到合适库,所以以下是临时版,以后要独立出去
99
"""
10+
from typing import Tuple
11+
1012
import numpy as np
1113
import pandas as pd
1214

15+
try:
16+
import talib
17+
except:
18+
pass
19+
1320

1421
def abs_(x: pd.Series) -> pd.Series:
1522
return np.abs(x)
@@ -93,3 +100,7 @@ def ts_std_dev(x: pd.Series, d: int = 5, ddof: int = 0) -> pd.Series:
93100

94101
def ts_sum(x: pd.Series, d: int = 5) -> pd.Series:
95102
return x.rolling(d).sum()
103+
104+
105+
def ts_MACD(close: pd.Series, fastperiod: int = 12, slowperiod: int = 26, signalperiod: int = 9) -> Tuple[pd.Series, pd.Series, pd.Series]:
106+
return talib.MACD(close, fastperiod, slowperiod, signalperiod)

‎expr_codegen/pandas/template.py.j2

+7-3
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@ import numpy as np # noqa
66
import pandas as pd # noqa
77
from loguru import logger # noqa
88

9+
from expr_codegen.pandas.helper import GlobalVariable
910
from expr_codegen.pandas.ta import * # noqa
1011

11-
1212
{{ syms1 }}
1313

1414
{{ syms2 }}
@@ -19,15 +19,19 @@ _NONE_ = None
1919
_TRUE_ = True
2020
_FALSE_ = False
2121

22+
g = GlobalVariable()
23+
24+
def unpack(x: Tuple, idx: int = 0) -> pd.Series:
25+
return x[idx]
26+
2227
{%-for row in extra_codes %}
2328
{{ row-}}
2429
{% endfor %}
2530

26-
2731
{% for key, value in funcs.items() %}
2832
def {{ key }}(df: pd.DataFrame) -> pd.DataFrame:
2933
{{ value }}
30-
return df
34+
return g.df
3135
{% endfor %}
3236

3337
"""

‎expr_codegen/polars_group/template.py.j2

+3
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@ _NONE_ = None
3030
_TRUE_ = True
3131
_FALSE_ = False
3232

33+
def unpack(x: Expr, idx: int = 0) -> Expr:
34+
return x.struct[idx]
35+
3336
{%-for row in extra_codes %}
3437
{{ row-}}
3538
{% endfor %}

‎expr_codegen/polars_over/template.py.j2

+3
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@ _NONE_ = None
3030
_TRUE_ = True
3131
_FALSE_ = False
3232

33+
def unpack(x: Expr, idx: int = 0) -> Expr:
34+
return x.struct[idx]
35+
3336
{%-for row in extra_codes %}
3437
{{ row-}}
3538
{% endfor %}

0 commit comments

Comments
 (0)
Please sign in to comment.