Skip to content

Commit a11c1bb

Browse files
ezyangpytorchmergebot
authored andcommitted
Run Black on all of tools/
Signed-off-by: Edward Z. Yang <ezyangfb.com> Pull Request resolved: pytorch#76089 Approved by: https://github.com/albanD
1 parent ae864d4 commit a11c1bb

File tree

79 files changed

+6179
-3741
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

79 files changed

+6179
-3741
lines changed

tools/actions_local_runner.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def git(args: List[str]) -> List[str]:
5959
return [line.strip() for line in lines]
6060

6161

62-
def find_changed_files(ref_branch : str = "origin/master") -> List[str]:
62+
def find_changed_files(ref_branch: str = "origin/master") -> List[str]:
6363
untracked = []
6464

6565
for line in git(["status", "--porcelain"]):
@@ -334,7 +334,7 @@ async def full(self) -> CommandResult:
334334
return await shell_cmd(script, env=env)
335335

336336

337-
def changed_files(ref_branch : str = "origin/master") -> Optional[List[str]]:
337+
def changed_files(ref_branch: str = "origin/master") -> Optional[List[str]]:
338338
changed_files: Optional[List[str]] = None
339339
try:
340340
changed_files = sorted(find_changed_files(ref_branch))
@@ -381,9 +381,11 @@ def main() -> None:
381381
"--no-quiet", help="output commands", action="store_true", default=False
382382
)
383383
parser.add_argument("--step", action="append", help="steps to run (in order)")
384-
parser.add_argument("--ref_branch",
385-
default="origin/master",
386-
help="remote/branch used during comparison for --changed-only (default=origin/master")
384+
parser.add_argument(
385+
"--ref_branch",
386+
default="origin/master",
387+
help="remote/branch used during comparison for --changed-only (default=origin/master",
388+
)
387389
args = parser.parse_args()
388390

389391
quiet = not args.no_quiet

tools/amd_build/build_amd.py

+38-29
Original file line numberDiff line numberDiff line change
@@ -4,43 +4,50 @@
44
import os
55
import argparse
66
import sys
7-
sys.path.append(os.path.realpath(os.path.join(
8-
__file__,
9-
os.path.pardir,
10-
os.path.pardir,
11-
os.path.pardir,
12-
'torch',
13-
'utils')))
7+
8+
sys.path.append(
9+
os.path.realpath(
10+
os.path.join(
11+
__file__, os.path.pardir, os.path.pardir, os.path.pardir, "torch", "utils"
12+
)
13+
)
14+
)
1415

1516
from hipify import hipify_python # type: ignore[import]
1617

17-
parser = argparse.ArgumentParser(description='Top-level script for HIPifying, filling in most common parameters')
18+
parser = argparse.ArgumentParser(
19+
description="Top-level script for HIPifying, filling in most common parameters"
20+
)
1821
parser.add_argument(
19-
'--out-of-place-only',
20-
action='store_true',
21-
help="Whether to only run hipify out-of-place on source files")
22+
"--out-of-place-only",
23+
action="store_true",
24+
help="Whether to only run hipify out-of-place on source files",
25+
)
2226

2327
parser.add_argument(
24-
'--project-directory',
28+
"--project-directory",
2529
type=str,
26-
default='',
30+
default="",
2731
help="The root of the project.",
28-
required=False)
32+
required=False,
33+
)
2934

3035
parser.add_argument(
31-
'--output-directory',
36+
"--output-directory",
3237
type=str,
33-
default='',
38+
default="",
3439
help="The directory to store the hipified project",
35-
required=False)
40+
required=False,
41+
)
3642

3743
parser.add_argument(
38-
'--extra-include-dir',
44+
"--extra-include-dir",
3945
type=str,
4046
default=[],
41-
nargs='+',
47+
nargs="+",
4248
help="The list of extra directories in caffe2 to hipify",
43-
required=False)
49+
required=False,
50+
)
4451

4552
args = parser.parse_args()
4653

@@ -93,13 +100,13 @@
93100
for new_dir in args.extra_include_dir:
94101
abs_new_dir = os.path.join(proj_dir, new_dir)
95102
if os.path.exists(abs_new_dir):
96-
new_dir = os.path.join(new_dir, '**/*')
103+
new_dir = os.path.join(new_dir, "**/*")
97104
includes.append(new_dir)
98105

99106
ignores = [
100107
"caffe2/operators/depthwise_3x3_conv_op_cudnn.cu",
101108
"caffe2/operators/pool_op_cudnn.cu",
102-
'*/hip/*',
109+
"*/hip/*",
103110
# These files are compatible with both cuda and hip
104111
"aten/src/ATen/core/*",
105112
"torch/csrc/jit/codegen/cuda/codegen.cpp",
@@ -116,20 +123,21 @@
116123
# Check if the compiler is hip-clang.
117124
def is_hip_clang() -> bool:
118125
try:
119-
hip_path = os.getenv('HIP_PATH', '/opt/rocm/hip')
120-
with open(hip_path + '/lib/.hipInfo') as f:
121-
return 'HIP_COMPILER=clang' in f.read()
126+
hip_path = os.getenv("HIP_PATH", "/opt/rocm/hip")
127+
with open(hip_path + "/lib/.hipInfo") as f:
128+
return "HIP_COMPILER=clang" in f.read()
122129
except IOError:
123130
return False
124131

132+
125133
# TODO Remove once gloo submodule is recent enough to contain upstream fix.
126134
if is_hip_clang():
127135
gloo_cmake_file = "third_party/gloo/cmake/Hip.cmake"
128136
do_write = False
129137
if os.path.exists(gloo_cmake_file):
130138
with open(gloo_cmake_file, "r") as sources:
131139
lines = sources.readlines()
132-
newlines = [line.replace(' hip_hcc ', ' amdhip64 ') for line in lines]
140+
newlines = [line.replace(" hip_hcc ", " amdhip64 ") for line in lines]
133141
if lines == newlines:
134142
print("%s skipped" % gloo_cmake_file)
135143
else:
@@ -143,7 +151,7 @@ def is_hip_clang() -> bool:
143151
do_write = False
144152
with open(gloo_cmake_file, "r") as sources:
145153
lines = sources.readlines()
146-
newlines = [line.replace('RCCL_LIBRARY', 'RCCL_LIBRARY_PATH') for line in lines]
154+
newlines = [line.replace("RCCL_LIBRARY", "RCCL_LIBRARY_PATH") for line in lines]
147155
if lines == newlines:
148156
print("%s skipped" % gloo_cmake_file)
149157
else:
@@ -159,7 +167,7 @@ def is_hip_clang() -> bool:
159167
if os.path.exists(gloo_cmake_file):
160168
with open(gloo_cmake_file, "r") as sources:
161169
lines = sources.readlines()
162-
newlines = [line.replace('HIP_HCC_FLAGS', 'HIP_CLANG_FLAGS') for line in lines]
170+
newlines = [line.replace("HIP_HCC_FLAGS", "HIP_CLANG_FLAGS") for line in lines]
163171
if lines == newlines:
164172
print("%s skipped" % gloo_cmake_file)
165173
else:
@@ -174,4 +182,5 @@ def is_hip_clang() -> bool:
174182
includes=includes,
175183
ignores=ignores,
176184
out_of_place_only=args.out_of_place_only,
177-
hip_clang_launch=is_hip_clang())
185+
hip_clang_launch=is_hip_clang(),
186+
)

tools/autograd/context.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,12 @@
77

88
# Like tools.api.context.with_native_function, but for
99
# NativeFunctionWithDifferentiabilityInfo.
10-
def with_native_function_with_differentiability_info(func: Callable[[NFWDI], T]) -> Callable[[NFWDI], T]:
10+
def with_native_function_with_differentiability_info(
11+
func: Callable[[NFWDI], T]
12+
) -> Callable[[NFWDI], T]:
1113
@functools.wraps(func)
1214
def wrapper(f: NFWDI) -> T:
1315
with native_function_manager(f.func):
1416
return func(f)
17+
1518
return wrapper

tools/autograd/gen_annotated_fn_args.py

+41-27
Original file line numberDiff line numberDiff line change
@@ -25,19 +25,26 @@
2525
from tools.codegen.context import with_native_function
2626
from tools.codegen.model import BaseOperatorName, NativeFunction
2727
import tools.codegen.api.python as python
28-
from .gen_python_functions import should_generate_py_binding, is_py_torch_function, \
29-
is_py_nn_function, is_py_linalg_function, is_py_variable_method, is_py_special_function, \
30-
is_py_fft_function
28+
from .gen_python_functions import (
29+
should_generate_py_binding,
30+
is_py_torch_function,
31+
is_py_nn_function,
32+
is_py_linalg_function,
33+
is_py_variable_method,
34+
is_py_special_function,
35+
is_py_fft_function,
36+
)
37+
3138

3239
def gen_annotated(native_yaml_path: str, out: str, autograd_dir: str) -> None:
3340
native_functions = parse_native_yaml(native_yaml_path).native_functions
3441
mappings = (
35-
(is_py_torch_function, 'torch._C._VariableFunctions'),
36-
(is_py_nn_function, 'torch._C._nn'),
37-
(is_py_linalg_function, 'torch._C._linalg'),
38-
(is_py_special_function, 'torch._C._special'),
39-
(is_py_fft_function, 'torch._C._fft'),
40-
(is_py_variable_method, 'torch.Tensor'),
42+
(is_py_torch_function, "torch._C._VariableFunctions"),
43+
(is_py_nn_function, "torch._C._nn"),
44+
(is_py_linalg_function, "torch._C._linalg"),
45+
(is_py_special_function, "torch._C._special"),
46+
(is_py_fft_function, "torch._C._fft"),
47+
(is_py_variable_method, "torch.Tensor"),
4148
)
4249
annotated_args: List[str] = []
4350
for pred, namespace in mappings:
@@ -48,13 +55,18 @@ def gen_annotated(native_yaml_path: str, out: str, autograd_dir: str) -> None:
4855
groups[f.func.name.name].append(f)
4956
for group in groups.values():
5057
for f in group:
51-
annotated_args.append(f'{namespace}.{gen_annotated_args(f)}')
58+
annotated_args.append(f"{namespace}.{gen_annotated_args(f)}")
5259

53-
template_path = os.path.join(autograd_dir, 'templates')
60+
template_path = os.path.join(autograd_dir, "templates")
5461
fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
55-
fm.write_with_template('annotated_fn_args.py', 'annotated_fn_args.py.in', lambda: {
56-
'annotated_args': textwrap.indent('\n'.join(annotated_args), ' '),
57-
})
62+
fm.write_with_template(
63+
"annotated_fn_args.py",
64+
"annotated_fn_args.py.in",
65+
lambda: {
66+
"annotated_args": textwrap.indent("\n".join(annotated_args), " "),
67+
},
68+
)
69+
5870

5971
@with_native_function
6072
def gen_annotated_args(f: NativeFunction) -> str:
@@ -63,26 +75,28 @@ def gen_annotated_args(f: NativeFunction) -> str:
6375
if arg.default is not None:
6476
continue
6577
out_arg: Dict[str, Any] = {}
66-
out_arg['name'] = arg.name
67-
out_arg['simple_type'] = python.argument_type_str(arg.type, simple_type=True)
78+
out_arg["name"] = arg.name
79+
out_arg["simple_type"] = python.argument_type_str(arg.type, simple_type=True)
6880
size = python.argument_type_size(arg.type)
6981
if size:
70-
out_arg['size'] = size
82+
out_arg["size"] = size
7183
out_args.append(out_arg)
7284

73-
return f'{f.func.name.name}: {repr(out_args)},'
85+
return f"{f.func.name.name}: {repr(out_args)},"
86+
7487

7588
def main() -> None:
76-
parser = argparse.ArgumentParser(
77-
description='Generate annotated_fn_args script')
78-
parser.add_argument('native_functions', metavar='NATIVE',
79-
help='path to native_functions.yaml')
80-
parser.add_argument('out', metavar='OUT',
81-
help='path to output directory')
82-
parser.add_argument('autograd', metavar='AUTOGRAD',
83-
help='path to template directory')
89+
parser = argparse.ArgumentParser(description="Generate annotated_fn_args script")
90+
parser.add_argument(
91+
"native_functions", metavar="NATIVE", help="path to native_functions.yaml"
92+
)
93+
parser.add_argument("out", metavar="OUT", help="path to output directory")
94+
parser.add_argument(
95+
"autograd", metavar="AUTOGRAD", help="path to template directory"
96+
)
8497
args = parser.parse_args()
8598
gen_annotated(args.native_functions, args.out, args.autograd)
8699

87-
if __name__ == '__main__':
100+
101+
if __name__ == "__main__":
88102
main()

0 commit comments

Comments
 (0)