1
1
import inspect
2
2
from functools import lru_cache
3
3
from io import TextIOWrapper
4
- from typing import Sequence , Dict , Optional
4
+ from typing import Sequence , Dict , Union , TextIO , TypeVar , Optional , Literal
5
5
6
6
from black import Mode , format_str
7
7
from sympy import simplify , cse , symbols , numbered_symbols
12
12
from expr_codegen .expr import get_current_by_prefix , get_children , replace_exprs
13
13
from expr_codegen .model import dag_start , dag_end , dag_middle
14
14
15
+ try :
16
+ from pandas import DataFrame as _pd_DataFrame
17
+ except ImportError :
18
+ _pd_DataFrame = None
19
+
20
+ try :
21
+ from polars import DataFrame as _pl_DataFrame
22
+ except ImportError :
23
+ _pl_DataFrame = None
24
+
25
+ DataFrame = TypeVar ('DataFrame' , _pl_DataFrame , _pd_DataFrame )
26
+
15
27
# ===============================
16
28
# TypeError: expecting bool or Boolean, not `ts_delay(X, 3)`.
17
29
# ts_delay(X, 3) & ts_delay(Y, 3)
@@ -172,7 +184,7 @@ def dag(self, merge: bool, date, asset):
172
184
G = dag_middle (G , self .exprs_names , self .get_current_func , self .get_current_func_kwargs , date , asset )
173
185
return dag_end (G )
174
186
175
- def all (self , exprs_src , style : str = 'polars_over' , template_file : str = 'template.py.j2' ,
187
+ def all (self , exprs_src , style : Literal [ 'pandas' , 'polars_group' , 'polars_over' ] = 'polars_over' , template_file : str = 'template.py.j2' ,
176
188
replace : bool = True , regroup : bool = False , format : bool = True ,
177
189
date = 'date' , asset = 'asset' ,
178
190
alias : Dict [str , str ] = {},
@@ -207,7 +219,7 @@ def all(self, exprs_src, style: str = 'polars_over', template_file: str = 'templ
207
219
代码字符串
208
220
209
221
"""
210
- assert style in ('polars_group ' , 'polars_over ' , 'pandas ' )
222
+ assert style in ('pandas ' , 'polars_group ' , 'polars_over ' )
211
223
212
224
if replace :
213
225
exprs_src = replace_exprs (exprs_src )
@@ -260,8 +272,8 @@ def _get_code(self,
260
272
source : str , * more_sources : str ,
261
273
extra_codes : str , output_file : str ,
262
274
convert_xor : bool ,
263
- style = 'polars_over' , template_file = 'template.py.j2' ,
264
- date = 'date' , asset = 'asset' ) -> str :
275
+ style : Literal [ 'pandas' , 'polars_group' , 'polars_over' ] = 'polars_over' , template_file : str = 'template.py.j2' ,
276
+ date : str = 'date' , asset : str = 'asset' ) -> str :
265
277
"""通过字符串生成代码, 加了缓存,多次调用不重复生成"""
266
278
raw , exprs_dict = sources_to_exprs (self .globals_ , source , * more_sources , convert_xor = convert_xor )
267
279
@@ -286,19 +298,19 @@ def _get_code(self,
286
298
_TOOL_ = ExprTool ()
287
299
288
300
289
- def codegen_exec (df ,
301
+ def codegen_exec (df : Optional [ DataFrame ] ,
290
302
* codes ,
291
303
extra_codes : str = r'CS_SW_L1 = r"^sw_l1_\d+$"' ,
292
- output_file : Optional [str ] = None ,
304
+ output_file : Union [str , TextIO , None ] = None ,
293
305
convert_xor : bool = False ,
294
- style : str = 'polars_over' , template_file : str = 'template.py.j2' ,
306
+ style : Literal [ 'pandas' , 'polars_group' , 'polars_over' ] = 'polars_over' , template_file : str = 'template.py.j2' ,
295
307
date : str = 'date' , asset : str = 'asset'
296
- ):
308
+ ) -> Optional [ DataFrame ] :
297
309
"""快速转换源代码并执行
298
310
299
311
Parameters
300
312
----------
301
- df: pl.DataFrame
313
+ df: pl.DataFrame or pd.DataFrame
302
314
输入DataFrame
303
315
codes:
304
316
函数体。此部分中的表达式会被翻译成目标代码
@@ -309,7 +321,7 @@ def codegen_exec(df,
309
321
convert_xor: bool
310
322
^ 转成异或还是乘方
311
323
style: str
312
- 代码风格。可选值 ('polars_group', 'polars_over', 'pandas')
324
+ 代码风格。可选值 'pandas', 'polars_group', 'polars_over'
313
325
template_file: str
314
326
代码模板
315
327
date: str
@@ -338,6 +350,6 @@ def codegen_exec(df,
338
350
)
339
351
340
352
if df is None :
341
- return df
353
+ return None
342
354
else :
343
355
return _TOOL_ .exec (code , df )
0 commit comments