Skip to content

Commit f70844b

Browse files
ezyangpytorchmergebot
authored andcommittedJul 27, 2023
Enable UFMT on a bunch of low traffic Python files outside of main files (pytorch#106052)
Signed-off-by: Edward Z. Yang <[email protected]> Pull Request resolved: pytorch#106052 Approved by: https://github.com/albanD, https://github.com/Skylion007
1 parent 5a114f7 commit f70844b

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

56 files changed

+2319
-1661
lines changed
 

‎.ci/pytorch/create_test_cert.py

+82-54
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
from datetime import datetime, timedelta
22
from tempfile import mkdtemp
3-
from cryptography.hazmat.primitives import serialization
4-
from cryptography.hazmat.primitives.asymmetric import rsa
3+
54
from cryptography import x509
5+
from cryptography.hazmat.primitives import hashes, serialization
6+
from cryptography.hazmat.primitives.asymmetric import rsa
67
from cryptography.x509.oid import NameOID
7-
from cryptography.hazmat.primitives import hashes
88

99
temp_dir = mkdtemp()
1010
print(temp_dir)
@@ -16,81 +16,109 @@ def genrsa(path):
1616
key_size=2048,
1717
)
1818
with open(path, "wb") as f:
19-
f.write(key.private_bytes(
20-
encoding=serialization.Encoding.PEM,
21-
format=serialization.PrivateFormat.TraditionalOpenSSL,
22-
encryption_algorithm=serialization.NoEncryption(),
23-
))
19+
f.write(
20+
key.private_bytes(
21+
encoding=serialization.Encoding.PEM,
22+
format=serialization.PrivateFormat.TraditionalOpenSSL,
23+
encryption_algorithm=serialization.NoEncryption(),
24+
)
25+
)
2426
return key
2527

2628

2729
def create_cert(path, C, ST, L, O, key):
28-
subject = issuer = x509.Name([
29-
x509.NameAttribute(NameOID.COUNTRY_NAME, C),
30-
x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, ST),
31-
x509.NameAttribute(NameOID.LOCALITY_NAME, L),
32-
x509.NameAttribute(NameOID.ORGANIZATION_NAME, O),
33-
])
34-
cert = x509.CertificateBuilder().subject_name(
35-
subject
36-
).issuer_name(
37-
issuer
38-
).public_key(
39-
key.public_key()
40-
).serial_number(
41-
x509.random_serial_number()
42-
).not_valid_before(
43-
datetime.utcnow()
44-
).not_valid_after(
45-
# Our certificate will be valid for 10 days
46-
datetime.utcnow() + timedelta(days=10)
47-
).add_extension(
48-
x509.BasicConstraints(ca=True, path_length=None), critical=True,
49-
).sign(key, hashes.SHA256())
30+
subject = issuer = x509.Name(
31+
[
32+
x509.NameAttribute(NameOID.COUNTRY_NAME, C),
33+
x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, ST),
34+
x509.NameAttribute(NameOID.LOCALITY_NAME, L),
35+
x509.NameAttribute(NameOID.ORGANIZATION_NAME, O),
36+
]
37+
)
38+
cert = (
39+
x509.CertificateBuilder()
40+
.subject_name(subject)
41+
.issuer_name(issuer)
42+
.public_key(key.public_key())
43+
.serial_number(x509.random_serial_number())
44+
.not_valid_before(datetime.utcnow())
45+
.not_valid_after(
46+
# Our certificate will be valid for 10 days
47+
datetime.utcnow()
48+
+ timedelta(days=10)
49+
)
50+
.add_extension(
51+
x509.BasicConstraints(ca=True, path_length=None),
52+
critical=True,
53+
)
54+
.sign(key, hashes.SHA256())
55+
)
5056
# Write our certificate out to disk.
5157
with open(path, "wb") as f:
5258
f.write(cert.public_bytes(serialization.Encoding.PEM))
5359
return cert
5460

5561

5662
def create_req(path, C, ST, L, O, key):
57-
csr = x509.CertificateSigningRequestBuilder().subject_name(x509.Name([
58-
# Provide various details about who we are.
59-
x509.NameAttribute(NameOID.COUNTRY_NAME, C),
60-
x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, ST),
61-
x509.NameAttribute(NameOID.LOCALITY_NAME, L),
62-
x509.NameAttribute(NameOID.ORGANIZATION_NAME, O),
63-
])).sign(key, hashes.SHA256())
63+
csr = (
64+
x509.CertificateSigningRequestBuilder()
65+
.subject_name(
66+
x509.Name(
67+
[
68+
# Provide various details about who we are.
69+
x509.NameAttribute(NameOID.COUNTRY_NAME, C),
70+
x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, ST),
71+
x509.NameAttribute(NameOID.LOCALITY_NAME, L),
72+
x509.NameAttribute(NameOID.ORGANIZATION_NAME, O),
73+
]
74+
)
75+
)
76+
.sign(key, hashes.SHA256())
77+
)
6478
with open(path, "wb") as f:
6579
f.write(csr.public_bytes(serialization.Encoding.PEM))
6680
return csr
6781

6882

6983
def sign_certificate_request(path, csr_cert, ca_cert, private_ca_key):
70-
cert = x509.CertificateBuilder().subject_name(
71-
csr_cert.subject
72-
).issuer_name(
73-
ca_cert.subject
74-
).public_key(
75-
csr_cert.public_key()
76-
).serial_number(
77-
x509.random_serial_number()
78-
).not_valid_before(
79-
datetime.utcnow()
80-
).not_valid_after(
81-
# Our certificate will be valid for 10 days
82-
datetime.utcnow() + timedelta(days=10)
83-
# Sign our certificate with our private key
84-
).sign(private_ca_key, hashes.SHA256())
84+
cert = (
85+
x509.CertificateBuilder()
86+
.subject_name(csr_cert.subject)
87+
.issuer_name(ca_cert.subject)
88+
.public_key(csr_cert.public_key())
89+
.serial_number(x509.random_serial_number())
90+
.not_valid_before(datetime.utcnow())
91+
.not_valid_after(
92+
# Our certificate will be valid for 10 days
93+
datetime.utcnow()
94+
+ timedelta(days=10)
95+
# Sign our certificate with our private key
96+
)
97+
.sign(private_ca_key, hashes.SHA256())
98+
)
8599
with open(path, "wb") as f:
86100
f.write(cert.public_bytes(serialization.Encoding.PEM))
87101
return cert
88102

89103

90104
ca_key = genrsa(temp_dir + "/ca.key")
91-
ca_cert = create_cert(temp_dir + "/ca.pem", "US", "New York", "New York", "Gloo Certificate Authority", ca_key)
105+
ca_cert = create_cert(
106+
temp_dir + "/ca.pem",
107+
"US",
108+
"New York",
109+
"New York",
110+
"Gloo Certificate Authority",
111+
ca_key,
112+
)
92113

93114
pkey = genrsa(temp_dir + "/pkey.key")
94-
csr = create_req(temp_dir + "/csr.csr", "US", "California", "San Francisco", "Gloo Testing Company", pkey)
115+
csr = create_req(
116+
temp_dir + "/csr.csr",
117+
"US",
118+
"California",
119+
"San Francisco",
120+
"Gloo Testing Company",
121+
pkey,
122+
)
95123

96124
cert = sign_certificate_request(temp_dir + "/cert.pem", csr, ca_cert, ca_key)
+36-25
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,41 @@
1-
import sys
1+
import argparse
22
import json
33
import math
4-
import argparse
4+
import sys
55

66
parser = argparse.ArgumentParser()
7-
parser.add_argument('--test-name', dest='test_name', action='store',
8-
required=True, help='test name')
9-
parser.add_argument('--sample-stats', dest='sample_stats', action='store',
10-
required=True, help='stats from sample')
11-
parser.add_argument('--update', action='store_true',
12-
help='whether to update baseline using stats from sample')
7+
parser.add_argument(
8+
"--test-name", dest="test_name", action="store", required=True, help="test name"
9+
)
10+
parser.add_argument(
11+
"--sample-stats",
12+
dest="sample_stats",
13+
action="store",
14+
required=True,
15+
help="stats from sample",
16+
)
17+
parser.add_argument(
18+
"--update",
19+
action="store_true",
20+
help="whether to update baseline using stats from sample",
21+
)
1322
args = parser.parse_args()
1423

1524
test_name = args.test_name
1625

17-
if 'cpu' in test_name:
18-
backend = 'cpu'
19-
elif 'gpu' in test_name:
20-
backend = 'gpu'
26+
if "cpu" in test_name:
27+
backend = "cpu"
28+
elif "gpu" in test_name:
29+
backend = "gpu"
2130

22-
data_file_path = f'../{backend}_runtime.json'
31+
data_file_path = f"../{backend}_runtime.json"
2332

2433
with open(data_file_path) as data_file:
2534
data = json.load(data_file)
2635

2736
if test_name in data:
28-
mean = float(data[test_name]['mean'])
29-
sigma = float(data[test_name]['sigma'])
37+
mean = float(data[test_name]["mean"])
38+
sigma = float(data[test_name]["sigma"])
3039
else:
3140
# Let the test pass if baseline number doesn't exist
3241
mean = sys.maxsize
@@ -43,37 +52,39 @@
4352

4453
sample_stats_data = json.loads(args.sample_stats)
4554

46-
sample_mean = float(sample_stats_data['mean'])
47-
sample_sigma = float(sample_stats_data['sigma'])
55+
sample_mean = float(sample_stats_data["mean"])
56+
sample_sigma = float(sample_stats_data["sigma"])
4857

4958
print("sample mean: ", sample_mean)
5059
print("sample sigma: ", sample_sigma)
5160

5261
if math.isnan(sample_mean):
53-
raise Exception('''Error: sample mean is NaN''')
62+
raise Exception("""Error: sample mean is NaN""")
5463
elif math.isnan(sample_sigma):
55-
raise Exception('''Error: sample sigma is NaN''')
64+
raise Exception("""Error: sample sigma is NaN""")
5665

5766
z_value = (sample_mean - mean) / sigma
5867

5968
print("z-value: ", z_value)
6069

6170
if z_value >= 3:
62-
raise Exception(f'''\n
71+
raise Exception(
72+
f"""\n
6373
z-value >= 3, there is high chance of perf regression.\n
6474
To reproduce this regression, run
6575
`cd .ci/pytorch/perf_test/ && bash {test_name}.sh` on your local machine
6676
and compare the runtime before/after your code change.
67-
''')
77+
"""
78+
)
6879
else:
6980
print("z-value < 3, no perf regression detected.")
7081
if args.update:
7182
print("We will use these numbers as new baseline.")
72-
new_data_file_path = f'../new_{backend}_runtime.json'
83+
new_data_file_path = f"../new_{backend}_runtime.json"
7384
with open(new_data_file_path) as new_data_file:
7485
new_data = json.load(new_data_file)
7586
new_data[test_name] = {}
76-
new_data[test_name]['mean'] = sample_mean
77-
new_data[test_name]['sigma'] = max(sample_sigma, sample_mean * 0.1)
78-
with open(new_data_file_path, 'w') as new_data_file:
87+
new_data[test_name]["mean"] = sample_mean
88+
new_data[test_name]["sigma"] = max(sample_sigma, sample_mean * 0.1)
89+
with open(new_data_file_path, "w") as new_data_file:
7990
json.dump(new_data, new_data_file, indent=4)

0 commit comments

Comments
 (0)