Skip to content

Commit 09a0bf3

Browse files
committed
add interruptibleReader,can interrupt file upload or download
1 parent e2f0ccf commit 09a0bf3

File tree

2 files changed

+27
-9
lines changed

2 files changed

+27
-9
lines changed

internal/pkg/engin/engin.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -151,13 +151,13 @@ func (e *Engin) Close() error {
151151
return errs
152152
}
153153

154-
func execInternalCommand(client *ssh.Client, command string) (string, error) {
154+
func execInternalCommand(ctx context.Context, client *ssh.Client, command string) (string, error) {
155155
if client == nil {
156156
command = strings.Join([]string{"rscript.exec", command}, " ")
157157
}
158158
commands := strings.Split(strings.TrimSpace(command), " ")
159159
if iFunc, ok := internalFuncMap[commands[0]]; ok {
160-
output, err := iFunc(client, commands[1:]...)
160+
output, err := iFunc(ctx, client, commands[1:]...)
161161
if err != nil {
162162
return "", err
163163
}
@@ -180,7 +180,7 @@ func execExternalCommand(client *ssh.Client, command string) (string, error) {
180180

181181
func (e *Engin) execCommand(client *ssh.Client, command string) (string, error) {
182182
if strings.HasPrefix(command, prefix) {
183-
return execInternalCommand(client, command)
183+
return execInternalCommand(e.ctx, client, command)
184184
}
185185
return execExternalCommand(client, command)
186186
}

internal/pkg/engin/funcs.go

+24-6
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package engin
22

33
import (
4+
"context"
45
"errors"
56
"fmt"
67
"io"
@@ -15,7 +16,7 @@ import (
1516

1617
const prefix = "rscript"
1718

18-
type internalFunc func(*ssh.Client, ...string) (string, error)
19+
type internalFunc func(context.Context, *ssh.Client, ...string) (string, error)
1920

2021
var internalFuncMap = map[string]internalFunc{
2122
prefix + ".uploadFile": UploadFileWithFs(nil),
@@ -28,7 +29,7 @@ func SetInternalFunc(name string, f internalFunc) {
2829
}
2930

3031
// remoteFilePath, localPath string
31-
func downloadFile(client *ssh.Client, args ...string) (string, error) {
32+
func downloadFile(ctx context.Context, client *ssh.Client, args ...string) (string, error) {
3233
if len(args) != 2 {
3334
return "", errors.New("args error")
3435
}
@@ -61,7 +62,7 @@ func downloadFile(client *ssh.Client, args ...string) (string, error) {
6162
}
6263
defer localFile.Close()
6364

64-
_, err = io.Copy(localFile, remoteFile)
65+
_, err = io.Copy(localFile, NewInterruptibleReader(ctx, remoteFile))
6566
if err != nil {
6667
return "", fmt.Errorf("io.Copy: %w", err)
6768
}
@@ -70,7 +71,7 @@ func downloadFile(client *ssh.Client, args ...string) (string, error) {
7071
}
7172

7273
func UploadFileWithFs(f fs.FS) internalFunc {
73-
return func(client *ssh.Client, args ...string) (string, error) {
74+
return func(ctx context.Context, client *ssh.Client, args ...string) (string, error) {
7475
if len(args) != 2 {
7576
return "", errors.New("args error")
7677
}
@@ -113,7 +114,7 @@ func UploadFileWithFs(f fs.FS) internalFunc {
113114
}
114115
defer remoteFile.Close()
115116

116-
_, err = io.Copy(remoteFile, localFile)
117+
_, err = io.Copy(remoteFile, NewInterruptibleReader(ctx, localFile))
117118
if err != nil {
118119
return "", fmt.Errorf("io.Copy: %w", err)
119120
}
@@ -122,7 +123,7 @@ func UploadFileWithFs(f fs.FS) internalFunc {
122123
}
123124
}
124125

125-
func execLocalCommand(_ *ssh.Client, args ...string) (string, error) {
126+
func execLocalCommand(ctx context.Context, _ *ssh.Client, args ...string) (string, error) {
126127
var (
127128
name string
128129
cmdArgs []string
@@ -147,3 +148,20 @@ func execLocalCommand(_ *ssh.Client, args ...string) (string, error) {
147148
}
148149
return out.String(), nil
149150
}
151+
152+
type InterruptibleReader func(p []byte) (n int, err error)
153+
154+
func (r InterruptibleReader) Read(p []byte) (n int, err error) {
155+
return r(p)
156+
}
157+
158+
func NewInterruptibleReader(ctx context.Context, r io.Reader) io.Reader {
159+
return InterruptibleReader(func(p []byte) (n int, err error) {
160+
select {
161+
case <-ctx.Done():
162+
return 0, io.EOF
163+
default:
164+
return r.Read(p)
165+
}
166+
})
167+
}

0 commit comments

Comments
 (0)