
Commit 72d2dba

oulgen authored and pytorchmergebot committed
Add None return type to init (pytorch#132335)
Pull Request resolved: pytorch#132335
Approved by: https://github.com/albanD
1 parent 30d7f0b commit 72d2dba
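
For context: every hunk in this commit makes the same mechanical change, annotating `__init__` with an explicit `-> None` return type. The diff itself does not state the motivation, but the usual reason for this pattern is that type checkers such as mypy treat a function with no annotations at all as untyped and skip its body by default; adding `-> None` makes each `__init__` fully annotated and therefore checked. A minimal sketch of the before/after shape, using names drawn from the diff (the forward method is illustrative only, not taken from this commit):

import torch

class Repro(torch.nn.Module):
    # Before this commit the signature was `def __init__(self):`, which mypy
    # treats as untyped and skips unless --check-untyped-defs is enabled.
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)

    # Illustrative only; forward signatures are not touched by this commit.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)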

File tree

130 files changed: +295 -295 lines


test/dynamo/test_minifier.py

+3 -3

@@ -119,7 +119,7 @@ def test_cpu_cuda_module_after_dynamo(self):
         backend_name = "relu_compile_error_TESTING_ONLY"
         run_code = f"""\
 class CpuCudaModule(torch.nn.Module):
-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()
         self.m_x = torch.nn.Linear(20, 20).cuda()
         self.m_y = torch.nn.Linear(20, 20)
@@ -149,7 +149,7 @@ def inner(x1, y1):
             res.minifier_module(),
             """\
 class Repro(torch.nn.Module):
-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()
         self.G__mod___m_x = Linear(in_features=20, out_features=20, bias=True).cuda()
         self.G__mod___m_y = Linear(in_features=20, out_features=20, bias=True)
@@ -204,7 +204,7 @@ def inner(x):
             res.repro_module(),
             """\
 class Repro(torch.nn.Module):
-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()

     def forward(self, x_19):

test/inductor/test_minifier.py

+2 -2

@@ -122,7 +122,7 @@ def inner(x):
             res.repro_module(),
             """\
 class Repro(torch.nn.Module):
-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()

     def forward(self, arg0_1):
@@ -138,7 +138,7 @@ def forward(self, arg0_1):
             res.repro_module(),
             """\
 class Repro(torch.nn.Module):
-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()

     def forward(self, arg0_1):

torch/_classes.py

+1 -1

@@ -19,7 +19,7 @@ def __getattr__(self, attr):
 class _Classes(types.ModuleType):
     __file__ = "_classes.py"

-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__("torch.classes")

     def __getattr__(self, name):

torch/_decomp/decompositions_for_rng.py

+1 -1

@@ -71,7 +71,7 @@ class PhiloxState:
     trace time.
     """

-    def __init__(self):
+    def __init__(self) -> None:
         self.reset()

     def reset(self):

torch/_dynamo/backends/distributed.py

+1 -1

@@ -247,7 +247,7 @@ def run_node(self, n: Node) -> Any:
         # This gives us the appropriately strided outputs here which will reflect runtime strides.

         class FakeifyFirstAOTInvocationGuard:
-            def __init__(self):
+            def __init__(self) -> None:
                 self.tc = torch._guards.TracingContext.try_get()
                 assert self.tc
                 torch._guards.TracingContext.try_get().fakify_first_call = True

torch/_dynamo/code_context.py

+1 -1

@@ -5,7 +5,7 @@


 class CodeContextDict:
-    def __init__(self):
+    def __init__(self) -> None:
         self.code_context = ExactWeakKeyDictionary()

     def has_context(self, code: types.CodeType):

torch/_dynamo/debug_utils.py

+2 -2

@@ -170,7 +170,7 @@ def convert(gm):
            """
            from torch.nn import *
            class Repro(torch.nn.Module):
-               def __init__(self):
+               def __init__(self) -> None:
                    super().__init__()
            """
        )
@@ -491,7 +491,7 @@ def _mk_defaulter(d: T) -> Callable[[Optional[T]], T]:


 class NopInputReader:
-    def __init__(self):
+    def __init__(self) -> None:
         self.total = 0

     def storage(self, storage_hash, nbytes, *, device=None, dtype_hint=None):

torch/_dynamo/eval_frame.py

+3 -3

@@ -497,7 +497,7 @@ def _fn(*args, **kwargs):
     wrapper function.

     >> class CallableClass:
-    >>     def __init__(self):
+    >>     def __init__(self) -> None:
     >>         super().__init__()
     >>         self.relu = torch.nn.ReLU()
     >>
@@ -578,7 +578,7 @@ def __reduce__(self):


 class RunOnlyContext(_TorchDynamoContext):
-    def __init__(self):
+    def __init__(self) -> None:
         # cudagraph trees relies on generation increment
         def on_enter():
             torch._dynamo.mutation_guard.GenerationTracker.generation += 1
@@ -590,7 +590,7 @@ def __reduce__(self):


 class DisableContext(_TorchDynamoContext):
-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__(callback=None)

     def __call__(self, fn):

torch/_dynamo/exc.py

+1 -1

@@ -74,7 +74,7 @@ def __init__(self, name):


 class ResetRequired(TorchDynamoException):
-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__(
             textwrap.dedent(
                 """

torch/_dynamo/profiler.py

+1 -1

@@ -92,7 +92,7 @@ def print_missing(stack):
 class Profiler:
     unique_graphs = 0

-    def __init__(self):
+    def __init__(self) -> None:
         self.prof = torch.profiler.profile(
             activities=[torch.profiler.ProfilerActivity.CPU],
             with_stack=should_print_missing(),

torch/_dynamo/variables/base.py

+1 -1

@@ -70,7 +70,7 @@ class MutableLocal(MutableLocalBase):
     state.
     """

-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__(MutableLocalSource.Local)

     def __hash__(self):

torch/_dynamo/variables/builder.py

+2 -2

@@ -274,7 +274,7 @@ def __eq__(self, other):


 class BackwardStateGraphArg(GraphArg):
-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__(
             source=None,
             _example=BackwardState(),
@@ -2646,7 +2646,7 @@ class SourcelessBuilder:
     if/else type->VariableTracker trees that were cropping up all over dynamo.
     """

-    def __init__(self):
+    def __init__(self) -> None:
         raise AssertionError("Use SourcelessBuilder.create()")

     @staticmethod

torch/_export/db/examples/class_method.py

+1 -1

@@ -10,7 +10,7 @@ class ClassMethod(torch.nn.Module):
     def method(cls, x):
         return x + 1

-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()
         self.linear = torch.nn.Linear(4, 2)

torch/_export/db/examples/cond_branch_class_method.py

+1 -1

@@ -26,7 +26,7 @@ class CondBranchClassMethod(torch.nn.Module):
     NOTE: If the `pred` is test on a dim with batch size < 2, it will be specialized.
     """

-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()
         self.subm = MySubModule()

torch/_export/db/examples/model_attr_mutation.py

+1 -1

@@ -8,7 +8,7 @@ class ModelAttrMutation(torch.nn.Module):
     Attribute mutation is not supported.
     """

-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()
         self.attr_list = [torch.randn(3, 2), torch.randn(3, 2)]

torch/_export/db/examples/scalar_output.py

+1 -1

@@ -11,7 +11,7 @@ class ScalarOutput(torch.nn.Module):
     Returning scalar values from the graph is supported, in addition to Tensor
     outputs. Symbolic shapes are captured and rank is specialized.
     """
-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()

     def forward(self, x):

torch/_export/db/examples/specialized_attribute.py

+1 -1

@@ -11,7 +11,7 @@ class SpecializedAttribute(torch.nn.Module):
     Model attributes are specialized.
     """

-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()
         self.a = "moo"
         self.b = 4

torch/_export/passes/lift_constants_pass.py

+1 -1

@@ -24,7 +24,7 @@ class ConstantAttrMap(collections.abc.MutableMapping):
     if that's the case).
     """

-    def __init__(self):
+    def __init__(self) -> None:
         # Underlying dict that we use to implement this mapping.
         self._constant_attrs: Dict[
             Union[int, torch.Tensor, FakeScriptObject], List[Any]

torch/_export/serde/serialize.py

+1 -1

@@ -1413,7 +1413,7 @@ class Result:
         constants: Dict[str, Union[torch.Tensor, FakeScriptObject, torch.ScriptObject]]
         example_inputs: Optional[Tuple[Tuple[torch.Tensor, ...], Dict[str, Any]]]

-    def __init__(self):
+    def __init__(self) -> None:
         self.serialized_name_to_node: Dict[str, torch.fx.Node] = {}
         self.serialized_name_to_meta: Dict[str, MetaType] = {}
         self.graph = torch.fx.Graph()

torch/_functorch/_aot_autograd/schemas.py

+1 -1

@@ -602,7 +602,7 @@ class SubclassMeta:
     # Optional field because we don't compute for inference graphs
     grad_input_metas: Optional[List[Union[int, SubclassCreationMeta]]] = None

-    def __init__(self):
+    def __init__(self) -> None:
         # The fields in this class get set after its construction.
         pass

torch/_functorch/aot_autograd.py

+1 -1

@@ -878,7 +878,7 @@ def functional_call(named_params, named_buffers, *args, **kwargs):
     )

     class AOTModule(nn.Module):
-        def __init__(self):
+        def __init__(self) -> None:
             super().__init__()
             self.orig_module = mod

torch/_functorch/autograd_function.py

+2 -2

@@ -30,7 +30,7 @@
 # We do this by using creating a custom HigherOrderOperator that only functorch
 # dispatches specially.
 class CustomFunctionHigherOrderOperator(HigherOrderOperator):
-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__("custom_function_call")

     def __call__(self, autograd_function, *args, **kwargs):
@@ -713,7 +713,7 @@ def new_forward(ctx, *args, **kwargs):


 class AutogradFunctionApply(HigherOrderOperator):
-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__("autograd_function_apply")

     def __call__(self, fwd, bwd, *fwd_args, **fwd_kwargs):

torch/_guards.py

+3 -3

@@ -427,7 +427,7 @@ def __eq__(self, other):


 class ModuleContext(Checkpointable[ModuleContextCheckpointState]):
-    def __init__(self):
+    def __init__(self) -> None:
         self.nn_modules: Dict[str, Any] = {}

     def copy_graphstate(self):
@@ -476,7 +476,7 @@ class GlobalContext(Checkpointable[GlobalContextCheckpointState]):
         "autocast_cache_enabled",
     }

-    def __init__(self):
+    def __init__(self) -> None:
         self.global_state: Dict[str, Tuple[Callable, ...]] = {}

     def copy_graphstate(self):
@@ -544,7 +544,7 @@ def remove_guards_with_source(self, source):


 class GuardsContext(Checkpointable[GuardsCheckpointState]):
-    def __init__(self):
+    def __init__(self) -> None:
         self.dynamo_guards: GuardsSet = GuardsSet()
         self.aotautograd_guards: List[GuardEnvExpr] = []

torch/_higher_order_ops/auto_functionalize.py

+1 -1

@@ -54,7 +54,7 @@ class AutoFunctionalized(HigherOrderOperator):
     underscore is to prevent collisions with kwarg names in **kwargs.
     """

-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__("auto_functionalized")

     def __call__(

torch/_higher_order_ops/effects.py

+1 -1

@@ -55,7 +55,7 @@ class WithEffects(HigherOrderOperator):
     per "effect type", which are enumerated in the _EffectType enum.
     """

-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__("with_effects")

     def __call__(

torch/_higher_order_ops/flex_attention.py

+2 -2

@@ -38,7 +38,7 @@ def __torch_function__(self, func, types, args, kwargs=None):


 class FlexAttentionHOP(HigherOrderOperator):
-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__("flex_attention")

     def __call__(
@@ -74,7 +74,7 @@ def __call__(


 class FlexAttentionBackwardHOP(HigherOrderOperator):
-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__("flex_attention_backward")

     def __call__(

torch/_higher_order_ops/out_dtype.py

+1 -1

@@ -45,7 +45,7 @@ class OutDtypeOperator(HigherOrderOperator):
     3. Cast the output to `out_dtype`
     """

-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__("out_dtype")
         # TODO(ydwu4): Subclassing HigherOrderOperator causes __module__ to
         # become different (torch._higher_order_ops.out_dtype) which will result

torch/_higher_order_ops/triton_kernel_wrap.py

+2 -2

@@ -519,7 +519,7 @@ def identify_mutated_tensors(kernel, kwargs):

 # Used for wrapping a Triton Kernel
 class TritonKernelWrapperMutation(HigherOrderOperator):
-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__("triton_kernel_wrapper_mutation")

@@ -528,7 +528,7 @@ def __init__(self):

 # Used for wrapping a Triton Kernel in a functional manner
 class TritonKernelWrapperFunctional(HigherOrderOperator):
-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__("triton_kernel_wrapper_functional")

torch/_higher_order_ops/while_loop.py

+1 -1

@@ -18,7 +18,7 @@


 class WhileLoopOp(HigherOrderOperator):
-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__("while_loop")

     def __call__(
