1
1
package engin
2
2
3
3
import (
4
+ "context"
4
5
"errors"
5
6
"fmt"
6
7
"io"
@@ -15,7 +16,7 @@ import (
15
16
16
17
const prefix = "rscript"
17
18
18
- type internalFunc func (* ssh.Client , ... string ) (string , error )
19
+ type internalFunc func (context. Context , * ssh.Client , ... string ) (string , error )
19
20
20
21
var internalFuncMap = map [string ]internalFunc {
21
22
prefix + ".uploadFile" : UploadFileWithFs (nil ),
@@ -28,7 +29,7 @@ func SetInternalFunc(name string, f internalFunc) {
28
29
}
29
30
30
31
// remoteFilePath, localPath string
31
- func downloadFile (client * ssh.Client , args ... string ) (string , error ) {
32
+ func downloadFile (ctx context. Context , client * ssh.Client , args ... string ) (string , error ) {
32
33
if len (args ) != 2 {
33
34
return "" , errors .New ("args error" )
34
35
}
@@ -61,7 +62,7 @@ func downloadFile(client *ssh.Client, args ...string) (string, error) {
61
62
}
62
63
defer localFile .Close ()
63
64
64
- _ , err = io .Copy (localFile , remoteFile )
65
+ _ , err = io .Copy (localFile , NewInterruptibleReader ( ctx , remoteFile ) )
65
66
if err != nil {
66
67
return "" , fmt .Errorf ("io.Copy: %w" , err )
67
68
}
@@ -70,7 +71,7 @@ func downloadFile(client *ssh.Client, args ...string) (string, error) {
70
71
}
71
72
72
73
func UploadFileWithFs (f fs.FS ) internalFunc {
73
- return func (client * ssh.Client , args ... string ) (string , error ) {
74
+ return func (ctx context. Context , client * ssh.Client , args ... string ) (string , error ) {
74
75
if len (args ) != 2 {
75
76
return "" , errors .New ("args error" )
76
77
}
@@ -113,7 +114,7 @@ func UploadFileWithFs(f fs.FS) internalFunc {
113
114
}
114
115
defer remoteFile .Close ()
115
116
116
- _ , err = io .Copy (remoteFile , localFile )
117
+ _ , err = io .Copy (remoteFile , NewInterruptibleReader ( ctx , localFile ) )
117
118
if err != nil {
118
119
return "" , fmt .Errorf ("io.Copy: %w" , err )
119
120
}
@@ -122,7 +123,7 @@ func UploadFileWithFs(f fs.FS) internalFunc {
122
123
}
123
124
}
124
125
125
- func execLocalCommand (_ * ssh.Client , args ... string ) (string , error ) {
126
+ func execLocalCommand (ctx context. Context , _ * ssh.Client , args ... string ) (string , error ) {
126
127
var (
127
128
name string
128
129
cmdArgs []string
@@ -147,3 +148,20 @@ func execLocalCommand(_ *ssh.Client, args ...string) (string, error) {
147
148
}
148
149
return out .String (), nil
149
150
}
151
+
152
+ type InterruptibleReader func (p []byte ) (n int , err error )
153
+
154
+ func (r InterruptibleReader ) Read (p []byte ) (n int , err error ) {
155
+ return r (p )
156
+ }
157
+
158
+ func NewInterruptibleReader (ctx context.Context , r io.Reader ) io.Reader {
159
+ return InterruptibleReader (func (p []byte ) (n int , err error ) {
160
+ select {
161
+ case <- ctx .Done ():
162
+ return 0 , io .EOF
163
+ default :
164
+ return r .Read (p )
165
+ }
166
+ })
167
+ }
0 commit comments