1
1
import inspect
2
+ import pathlib
2
3
from functools import lru_cache
3
4
from io import TextIOWrapper
4
5
from typing import Sequence , Dict , Union , TextIO , TypeVar , Optional , Literal
5
6
6
7
from black import Mode , format_str
8
+ from loguru import logger
7
9
from sympy import simplify , cse , symbols , numbered_symbols
8
10
from sympy .core .expr import Expr
9
11
from sympy .logic import boolalg
@@ -262,6 +264,7 @@ def all(self, exprs_src, style: Literal['pandas', 'polars_group', 'polars_over']
262
264
def _get_code (self ,
263
265
source : str , * more_sources : str ,
264
266
extra_codes : str ,
267
+ output_file : str ,
265
268
convert_xor : bool ,
266
269
style : Literal ['pandas' , 'polars_group' , 'polars_over' ] = 'polars_over' , template_file : str = 'template.py.j2' ,
267
270
date : str = 'date' , asset : str = 'asset' ) -> str :
@@ -278,6 +281,16 @@ def _get_code(self,
278
281
extra_codes ,
279
282
))
280
283
284
+ # 移回到cache,防止多次调用多次保存
285
+ if isinstance (output_file , TextIOWrapper ):
286
+ # 输出到控制台
287
+ output_file .write (code )
288
+ elif output_file is not None :
289
+ output_file = pathlib .Path (output_file )
290
+ logger .info (f'save to { output_file .absolute ()} ' )
291
+ with open (output_file , 'w' , encoding = 'utf-8' ) as f :
292
+ f .write (code )
293
+
281
294
return code
282
295
283
296
@@ -287,7 +300,7 @@ def _exec_code(code: str, df_input):
287
300
return globals_ ['df_output' ]
288
301
289
302
290
def _exec_file(file, df_input):
    """Load a Python source file and run it against ``df_input``.

    Parameters
    ----------
    file:
        Location of the source file to execute. It is passed straight to
        ``open``; callers in this module hand in ``pathlib.Path`` objects.
    df_input:
        Input frame, forwarded unchanged to ``_exec_code``.

    Returns
    -------
    Whatever ``_exec_code`` returns for the loaded source text.
    """
    # The generated file is written as UTF-8, so it is decoded the same way.
    with open(file, mode='r', encoding='utf-8') as handle:
        source_text = handle.read()

    # Execution is delegated; this function only handles the file read.
    return _exec_code(source_text, df_input)
@@ -308,14 +321,15 @@ def codegen_exec(df: Optional[DataFrame],
308
321
output_file : Union [str , TextIO , None ] = None ,
309
322
run_file : Union [bool , str ] = False ,
310
323
convert_xor : bool = False ,
311
- style : Literal ['pandas' , 'polars_group' , 'polars_over' ] = 'polars_over' , template_file : str = 'template.py.j2' ,
324
+ style : Literal ['pandas' , 'polars_group' , 'polars_over' ] = 'polars_over' ,
325
+ template_file : str = 'template.py.j2' ,
312
326
date : str = 'date' , asset : str = 'asset' ,
313
327
) -> Optional [DataFrame ]:
314
328
"""快速转换源代码并执行
315
329
316
330
Parameters
317
331
----------
318
- df: pl.DataFrame or pd.DataFrame
332
+ df: pl.DataFrame, pd.DataFrame, pl.LazyFrame
319
333
输入DataFrame
320
334
codes:
321
335
函数体。此部分中的表达式会被翻译成目标代码
@@ -350,12 +364,17 @@ def codegen_exec(df: Optional[DataFrame],
350
364
if df is not None :
351
365
if run_file is True :
352
366
assert output_file is not None , 'output_file is required'
367
+ output_file = pathlib .Path (output_file )
368
+ logger .info (f'run file "{ output_file .absolute ()} "' )
353
369
return _exec_file (output_file , df )
354
370
if run_file is not False :
355
371
run_file = str (run_file )
356
372
if run_file .endswith ('.py' ):
373
+ run_file = pathlib .Path (run_file )
374
+ logger .info (f'run file "{ run_file .absolute ()} "' )
357
375
return _exec_file (run_file , df )
358
376
else :
377
+ logger .info (f'run module "{ run_file } "' )
359
378
return _exec_module (run_file , df ) # 可断点调试
360
379
361
380
# 此代码来自于sympy.var
@@ -366,20 +385,14 @@ def codegen_exec(df: Optional[DataFrame],
366
385
more_sources = [c if isinstance (c , str ) else inspect .getsource (c ) for c in codes ]
367
386
368
387
code = _TOOL_ ._get_code (
369
- * more_sources , extra_codes = extra_codes ,
388
+ * more_sources ,
389
+ extra_codes = extra_codes ,
390
+ output_file = output_file ,
370
391
convert_xor = convert_xor ,
371
392
style = style , template_file = template_file ,
372
393
date = date , asset = asset ,
373
394
)
374
395
375
- if isinstance (output_file , TextIOWrapper ):
376
- # 输出到控制台
377
- output_file .write (code )
378
- elif output_file is not None :
379
- # 保存到文件
380
- with open (output_file , 'w' , encoding = 'utf-8' ) as f :
381
- f .write (code )
382
-
383
396
if df is None :
384
397
return None
385
398
else :
0 commit comments