Skip to content

Commit 82bbdc7

Browse files
committed
添加type hints
1 parent 413a43a commit 82bbdc7

File tree

2 files changed

+25
-13
lines changed

2 files changed

+25
-13
lines changed

expr_codegen/_version.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.10.0"
1+
__version__ = "0.10.1"

expr_codegen/tool.py

+24-12
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import inspect
22
from functools import lru_cache
33
from io import TextIOWrapper
4-
from typing import Sequence, Dict, Optional
4+
from typing import Sequence, Dict, Union, TextIO, TypeVar, Optional, Literal
55

66
from black import Mode, format_str
77
from sympy import simplify, cse, symbols, numbered_symbols
@@ -12,6 +12,18 @@
1212
from expr_codegen.expr import get_current_by_prefix, get_children, replace_exprs
1313
from expr_codegen.model import dag_start, dag_end, dag_middle
1414

15+
try:
16+
from pandas import DataFrame as _pd_DataFrame
17+
except ImportError:
18+
_pd_DataFrame = None
19+
20+
try:
21+
from polars import DataFrame as _pl_DataFrame
22+
except ImportError:
23+
_pl_DataFrame = None
24+
25+
DataFrame = TypeVar('DataFrame', _pl_DataFrame, _pd_DataFrame)
26+
1527
# ===============================
1628
# TypeError: expecting bool or Boolean, not `ts_delay(X, 3)`.
1729
# ts_delay(X, 3) & ts_delay(Y, 3)
@@ -172,7 +184,7 @@ def dag(self, merge: bool, date, asset):
172184
G = dag_middle(G, self.exprs_names, self.get_current_func, self.get_current_func_kwargs, date, asset)
173185
return dag_end(G)
174186

175-
def all(self, exprs_src, style: str = 'polars_over', template_file: str = 'template.py.j2',
187+
def all(self, exprs_src, style: Literal['pandas', 'polars_group', 'polars_over'] = 'polars_over', template_file: str = 'template.py.j2',
176188
replace: bool = True, regroup: bool = False, format: bool = True,
177189
date='date', asset='asset',
178190
alias: Dict[str, str] = {},
@@ -207,7 +219,7 @@ def all(self, exprs_src, style: str = 'polars_over', template_file: str = 'templ
207219
代码字符串
208220
209221
"""
210-
assert style in ('polars_group', 'polars_over', 'pandas')
222+
assert style in ('pandas', 'polars_group', 'polars_over')
211223

212224
if replace:
213225
exprs_src = replace_exprs(exprs_src)
@@ -260,8 +272,8 @@ def _get_code(self,
260272
source: str, *more_sources: str,
261273
extra_codes: str, output_file: str,
262274
convert_xor: bool,
263-
style='polars_over', template_file='template.py.j2',
264-
date='date', asset='asset') -> str:
275+
style: Literal['pandas', 'polars_group', 'polars_over'] = 'polars_over', template_file: str = 'template.py.j2',
276+
date: str = 'date', asset: str = 'asset') -> str:
265277
"""通过字符串生成代码, 加了缓存,多次调用不重复生成"""
266278
raw, exprs_dict = sources_to_exprs(self.globals_, source, *more_sources, convert_xor=convert_xor)
267279

@@ -286,19 +298,19 @@ def _get_code(self,
286298
_TOOL_ = ExprTool()
287299

288300

289-
def codegen_exec(df,
301+
def codegen_exec(df: Optional[DataFrame],
290302
*codes,
291303
extra_codes: str = r'CS_SW_L1 = r"^sw_l1_\d+$"',
292-
output_file: Optional[str] = None,
304+
output_file: Union[str, TextIO, None] = None,
293305
convert_xor: bool = False,
294-
style: str = 'polars_over', template_file: str = 'template.py.j2',
306+
style: Literal['pandas', 'polars_group', 'polars_over'] = 'polars_over', template_file: str = 'template.py.j2',
295307
date: str = 'date', asset: str = 'asset'
296-
):
308+
) -> Optional[DataFrame]:
297309
"""快速转换源代码并执行
298310
299311
Parameters
300312
----------
301-
df: pl.DataFrame
313+
df: pl.DataFrame or pd.DataFrame
302314
输入DataFrame
303315
codes:
304316
函数体。此部分中的表达式会被翻译成目标代码
@@ -309,7 +321,7 @@ def codegen_exec(df,
309321
convert_xor: bool
310322
^ 转成异或还是乘方
311323
style: str
312-
代码风格。可选值 ('polars_group', 'polars_over', 'pandas')
324+
代码风格。可选值 'pandas', 'polars_group', 'polars_over'
313325
template_file: str
314326
代码模板
315327
date: str
@@ -338,6 +350,6 @@ def codegen_exec(df,
338350
)
339351

340352
if df is None:
341-
return df
353+
return None
342354
else:
343355
return _TOOL_.exec(code, df)

0 commit comments

Comments
 (0)