Skip to content
This repository has been archived by the owner on Jun 13, 2024. It is now read-only.

Commit

Permalink
🏷️ typing: enhance jit.to_static typing
Browse files Browse the repository at this point in the history
  • Loading branch information
SigureMo committed Jun 3, 2024
1 parent c9e96d6 commit 10a2d24
Show file tree
Hide file tree
Showing 5 changed files with 162 additions and 35 deletions.
1 change: 1 addition & 0 deletions paddle-stubs/_typing/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ from __future__ import annotations
# Basic
from .basic import IntSequence as IntSequence
from .basic import NestedNumbericSequence as NestedNumbericSequence
from .basic import NestedSequence as NestedSequence
from .basic import Numberic as Numberic
from .basic import NumbericSequence as NumbericSequence

Expand Down
34 changes: 5 additions & 29 deletions paddle-stubs/_typing/basic.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -8,32 +8,8 @@ from typing_extensions import TypeAlias

Numberic: TypeAlias = int | float | complex | np.number[Any]

_T = TypeVar("_T", bound=Numberic)
_SeqLevel1: TypeAlias = Sequence[_T]

_TL1 = TypeVar("_TL1", bound=_SeqLevel1[Numberic])
_SeqLevel2: TypeAlias = Sequence[_TL1]

_TL2 = TypeVar("_TL2", bound=_SeqLevel2[_SeqLevel1[Numberic]])
_SeqLevel3: TypeAlias = Sequence[_TL2]

_TL3 = TypeVar("_TL3", bound=_SeqLevel3[_SeqLevel2[_SeqLevel1[Numberic]]])
_SeqLevel4: TypeAlias = Sequence[_TL3]

_TL4 = TypeVar("_TL4", bound=_SeqLevel4[_SeqLevel3[_SeqLevel2[_SeqLevel1[Numberic]]]])
_SeqLevel5: TypeAlias = Sequence[_TL4]

_TL5 = TypeVar("_TL5", bound=_SeqLevel5[_SeqLevel4[_SeqLevel3[_SeqLevel2[_SeqLevel1[Numberic]]]]])
_SeqLevel6: TypeAlias = Sequence[_TL5]

IntSequence: TypeAlias = _SeqLevel1[int]
NumbericSequence: TypeAlias = _SeqLevel1[Numberic]
NestedNumbericSequence: TypeAlias = (
Numberic
| _SeqLevel1[Numberic]
| _SeqLevel2[_SeqLevel1[Numberic]]
| _SeqLevel3[_SeqLevel2[_SeqLevel1[Numberic]]]
| _SeqLevel4[_SeqLevel3[_SeqLevel2[_SeqLevel1[Numberic]]]]
| _SeqLevel5[_SeqLevel4[_SeqLevel3[_SeqLevel2[_SeqLevel1[Numberic]]]]]
| _SeqLevel6[_SeqLevel5[_SeqLevel4[_SeqLevel3[_SeqLevel2[_SeqLevel1[Numberic]]]]]]
)
_T = TypeVar("_T")  # element type carried through the nested-sequence alias
# Recursive alias: either a bare _T or arbitrarily deep sequences of it.
# Explicit `TypeAlias` (PEP 613) so type checkers treat the recursive
# definition as an alias rather than an ordinary assignment.
NestedSequence: TypeAlias = _T | Sequence["NestedSequence[_T]"]
IntSequence: TypeAlias = Sequence[int]
NumbericSequence: TypeAlias = Sequence[Numberic]  # NOTE: "Numberic" spelling is this project's established convention
NestedNumbericSequence: TypeAlias = NestedSequence[Numberic]
1 change: 1 addition & 0 deletions paddle-stubs/jit/__init__.pyi
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from .api import StaticFunction as StaticFunction
from .api import enable_to_static as enable_to_static
from .api import ignore_module as ignore_module
from .api import load as load
Expand Down
34 changes: 28 additions & 6 deletions paddle-stubs/jit/api.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,22 @@ from __future__ import annotations

from collections.abc import Callable
from types import ModuleType
from typing import Any, TypedDict, TypeVar, Generic, overload
from typing import Any, Generic, Protocol, TypedDict, TypeVar, overload

from typing_extensions import Literal, ParamSpec, TypeAlias, Unpack

from paddle.nn import Layer
from paddle.static import BuildStrategy, InputSpec, Program

from .._typing import NestedSequence
from .translated_layer import TranslatedLayer
from paddle.nn import Layer

_LayerT = TypeVar("_LayerT", bound=Layer)  # preserves the concrete Layer subclass through to_static
_RetT = TypeVar("_RetT")  # return type of the wrapped callable
_InputT = ParamSpec("_InputT")  # parameter specification of the wrapped callable
Backends: TypeAlias = Literal["CINN"]  # backends accepted by to_static's `backend` argument


class _SaveLoadConfig(TypedDict):
output_spec: Any
with_hook: Any
Expand All @@ -23,8 +27,10 @@ class _SaveLoadConfig(TypedDict):
input_names_after_prune: Any
skip_prune_program: Any


class ConcreteProgram: ...  # opaque placeholder; members not yet stubbed


class StaticFunction(Generic[_InputT, _RetT]):
def __init__(
self, function: Callable[_InputT, _RetT], input_spec: list[InputSpec] | None = None, **kwargs: Any
Expand Down Expand Up @@ -63,22 +69,38 @@ class StaticFunction(Generic[_InputT, _RetT]):
@property
def function_spec(self) -> Any: ...


class ToStaticDecorator(Protocol):
    """Callback protocol for the decorator returned by ``to_static()``
    when it is called without a target (``function=None``).

    Applying it to a ``Layer`` subclass instance returns that same subclass
    type; applying it to a plain callable returns a ``StaticFunction``
    preserving the callable's signature and return type.
    """

    @overload
    def __call__(self, function: _LayerT) -> _LayerT: ...
    @overload
    def __call__(self, function: Callable[_InputT, _RetT]) -> StaticFunction[_InputT, _RetT]: ...


# Overload 1: applied directly to a Layer instance/subclass — the concrete
# subclass type is preserved (TypeVar bound to Layer).
@overload
def to_static(
    function: _LayerT,
    input_spec: NestedSequence[InputSpec] | None = ...,
    build_strategy: BuildStrategy | None = ...,
    backend: Backends | None = ...,
    **kwargs: Any,
) -> _LayerT: ...
# Overload 2: applied directly to a plain callable — wraps it in a
# StaticFunction that keeps the original parameter spec and return type.
@overload
def to_static(
    function: Callable[_InputT, _RetT],
    input_spec: NestedSequence[InputSpec] | None = ...,
    build_strategy: BuildStrategy | None = ...,
    backend: Backends | None = ...,
    **kwargs: Any,
) -> StaticFunction[_InputT, _RetT]: ...
# Overload 3: called with no target (decorator-factory form,
# e.g. ``@to_static(input_spec=...)``) — returns a decorator protocol.
@overload
def to_static(
    function: None = ...,
    input_spec: NestedSequence[InputSpec] | None = ...,
    build_strategy: BuildStrategy | None = ...,
    backend: Backends | None = ...,
    **kwargs: Any,
) -> ToStaticDecorator: ...
# NOTE(review): ``func=None`` suggests not_to_static also supports a
# no-argument decorator-factory form; the declared return type does not
# reflect that — confirm against the runtime implementation.
def not_to_static(func: Callable[_InputT, _RetT] | None = None) -> Callable[_InputT, _RetT]: ...
def enable_to_static(enable_to_static_bool: bool) -> None: ...
def ignore_module(modules: list[ModuleType]) -> None: ...
Expand Down
127 changes: 127 additions & 0 deletions tests/test_jit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
# pyright: strict, reportUnusedVariable=false

from __future__ import annotations

from typing_extensions import assert_type

import paddle


def test_import():
    """`to_static` is reachable both as an attribute and as a direct import."""
    # Attribute access checks the symbol is exposed on the `paddle.jit` module.
    paddle.jit.to_static

    # Direct import checks the re-export in `paddle.jit.__init__`.
    from paddle.jit import (
        to_static, # pyright: ignore [reportUnusedImport]
    )


def test_static_net_without_params_1():
    """Bare @to_static on `forward` must keep the Layer subclass type."""

    class Model(paddle.nn.Layer):
        @paddle.jit.to_static
        def forward(self, inputs: paddle.Tensor) -> paddle.Tensor:  # type: ignore
            return inputs

    model = Model()
    assert_type(model, Model)


def test_static_net_without_params_2():
    """to_static(layer) must return the same Layer subclass type."""

    class Model(paddle.nn.Layer):
        def forward(self, inputs: paddle.Tensor) -> paddle.Tensor:  # type: ignore
            return inputs

    converted = paddle.jit.to_static(Model())
    assert_type(converted, Model)


def test_static_net_with_params_1():
    """to_static(layer, input_spec=...) must preserve the Layer subclass type."""

    class Model(paddle.nn.Layer):
        def __init__(self):
            super().__init__()
            self.linear = paddle.nn.Linear(10, 10)

        def forward(self, inputs: paddle.Tensor) -> paddle.Tensor:  # type: ignore
            return self.linear(inputs)

    spec = [paddle.static.InputSpec(shape=[None, 10])]
    converted = paddle.jit.to_static(Model(), input_spec=spec)
    assert_type(converted, Model)


def test_static_net_with_params_2():
    """Parametrized @to_static on `forward` keeps the enclosing class type."""

    class Model(paddle.nn.Layer):
        def __init__(self):
            super().__init__()
            self.linear = paddle.nn.Linear(10, 10)

        @paddle.jit.to_static(input_spec=[paddle.static.InputSpec(shape=[None, 10])])
        def forward(self, inputs: paddle.Tensor) -> paddle.Tensor:  # type: ignore
            return self.linear(inputs)

    model = Model()
    assert_type(model, Model)


def test_static_net_with_params_3():
    """Decorator-factory form applied manually to a layer instance."""

    class Model(paddle.nn.Layer):
        def __init__(self):
            super().__init__()
            self.linear = paddle.nn.Linear(10, 10)

        def forward(self, inputs: paddle.Tensor) -> paddle.Tensor:  # type: ignore
            return self.linear(inputs)

    decorator = paddle.jit.to_static(input_spec=[paddle.static.InputSpec(shape=[None, 10])])
    converted = decorator(Model())
    assert_type(converted, Model)


def test_static_fn_without_params_1():
    """Bare @to_static on a function keeps its signature and return type."""

    @paddle.jit.to_static
    def identity(tensor: paddle.Tensor) -> paddle.Tensor:  # type: ignore
        return tensor

    result = identity(paddle.randn([10, 10]))
    assert_type(result, paddle.Tensor)


def test_static_fn_without_params_2():
    """to_static(fn) keeps the wrapped function's signature and return type."""

    def identity(tensor: paddle.Tensor) -> paddle.Tensor:  # type: ignore
        return tensor

    wrapped = paddle.jit.to_static(identity)
    result = wrapped(paddle.randn([10, 10]))
    assert_type(result, paddle.Tensor)


def test_static_fn_with_params_1():
    """Parametrized @to_static on a function keeps signature and return type."""

    @paddle.jit.to_static(input_spec=[paddle.static.InputSpec(shape=[None, 10])])
    def identity(tensor: paddle.Tensor) -> paddle.Tensor:  # type: ignore
        return tensor

    result = identity(paddle.randn([10, 10]))
    assert_type(result, paddle.Tensor)


def test_static_fn_with_params_2():
    """Decorator-factory form applied manually to a plain function."""

    def identity(tensor: paddle.Tensor) -> paddle.Tensor:  # type: ignore
        return tensor

    decorator = paddle.jit.to_static(input_spec=[paddle.static.InputSpec(shape=[None, 10])])
    result = decorator(identity)(paddle.randn([10, 10]))
    assert_type(result, paddle.Tensor)


def test_static_fn_with_params_3():
    """to_static(fn, input_spec=...) keeps signature and return type."""

    def identity(tensor: paddle.Tensor) -> paddle.Tensor:  # type: ignore
        return tensor

    wrapped = paddle.jit.to_static(
        identity, input_spec=[paddle.static.InputSpec(shape=[None, 10])]
    )
    result = wrapped(paddle.randn([10, 10]))
    assert_type(result, paddle.Tensor)

0 comments on commit 10a2d24

Please sign in to comment.