Skip to content

Commit 2f0e3e0

Browse files
committed
优化文件生成提示
1 parent 6662261 commit 2f0e3e0

File tree

2 files changed

+26
-13
lines changed

2 files changed

+26
-13
lines changed

expr_codegen/_version.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.10.4"
1+
__version__ = "0.10.5"

expr_codegen/tool.py

+25-12
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
import inspect
2+
import pathlib
23
from functools import lru_cache
34
from io import TextIOWrapper
45
from typing import Sequence, Dict, Union, TextIO, TypeVar, Optional, Literal
56

67
from black import Mode, format_str
8+
from loguru import logger
79
from sympy import simplify, cse, symbols, numbered_symbols
810
from sympy.core.expr import Expr
911
from sympy.logic import boolalg
@@ -262,6 +264,7 @@ def all(self, exprs_src, style: Literal['pandas', 'polars_group', 'polars_over']
262264
def _get_code(self,
263265
source: str, *more_sources: str,
264266
extra_codes: str,
267+
output_file: str,
265268
convert_xor: bool,
266269
style: Literal['pandas', 'polars_group', 'polars_over'] = 'polars_over', template_file: str = 'template.py.j2',
267270
date: str = 'date', asset: str = 'asset') -> str:
@@ -278,6 +281,16 @@ def _get_code(self,
278281
extra_codes,
279282
))
280283

284+
# 移回到cache,防止多次调用多次保存
285+
if isinstance(output_file, TextIOWrapper):
286+
# 输出到控制台
287+
output_file.write(code)
288+
elif output_file is not None:
289+
output_file = pathlib.Path(output_file)
290+
logger.info(f'save to {output_file.absolute()}')
291+
with open(output_file, 'w', encoding='utf-8') as f:
292+
f.write(code)
293+
281294
return code
282295

283296

@@ -287,7 +300,7 @@ def _exec_code(code: str, df_input):
287300
return globals_['df_output']
288301

289302

290-
def _exec_file(file: str, df_input):
303+
def _exec_file(file, df_input):
291304
with open(file, 'r', encoding='utf-8') as f:
292305
code = f.read()
293306
return _exec_code(code, df_input)
@@ -308,14 +321,15 @@ def codegen_exec(df: Optional[DataFrame],
308321
output_file: Union[str, TextIO, None] = None,
309322
run_file: Union[bool, str] = False,
310323
convert_xor: bool = False,
311-
style: Literal['pandas', 'polars_group', 'polars_over'] = 'polars_over', template_file: str = 'template.py.j2',
324+
style: Literal['pandas', 'polars_group', 'polars_over'] = 'polars_over',
325+
template_file: str = 'template.py.j2',
312326
date: str = 'date', asset: str = 'asset',
313327
) -> Optional[DataFrame]:
314328
"""快速转换源代码并执行
315329
316330
Parameters
317331
----------
318-
df: pl.DataFrame or pd.DataFrame
332+
df: pl.DataFrame, pd.DataFrame, pl.LazyFrame
319333
输入DataFrame
320334
codes:
321335
函数体。此部分中的表达式会被翻译成目标代码
@@ -350,12 +364,17 @@ def codegen_exec(df: Optional[DataFrame],
350364
if df is not None:
351365
if run_file is True:
352366
assert output_file is not None, 'output_file is required'
367+
output_file = pathlib.Path(output_file)
368+
logger.info(f'run file "{output_file.absolute()}"')
353369
return _exec_file(output_file, df)
354370
if run_file is not False:
355371
run_file = str(run_file)
356372
if run_file.endswith('.py'):
373+
run_file = pathlib.Path(run_file)
374+
logger.info(f'run file "{run_file.absolute()}"')
357375
return _exec_file(run_file, df)
358376
else:
377+
logger.info(f'run module "{run_file}"')
359378
return _exec_module(run_file, df) # 可断点调试
360379

361380
# 此代码来自于sympy.var
@@ -366,20 +385,14 @@ def codegen_exec(df: Optional[DataFrame],
366385
more_sources = [c if isinstance(c, str) else inspect.getsource(c) for c in codes]
367386

368387
code = _TOOL_._get_code(
369-
*more_sources, extra_codes=extra_codes,
388+
*more_sources,
389+
extra_codes=extra_codes,
390+
output_file=output_file,
370391
convert_xor=convert_xor,
371392
style=style, template_file=template_file,
372393
date=date, asset=asset,
373394
)
374395

375-
if isinstance(output_file, TextIOWrapper):
376-
# 输出到控制台
377-
output_file.write(code)
378-
elif output_file is not None:
379-
# 保存到文件
380-
with open(output_file, 'w', encoding='utf-8') as f:
381-
f.write(code)
382-
383396
if df is None:
384397
return None
385398
else:

0 commit comments

Comments
 (0)