Skip to content

Commit ab53cbd

Browse files
committed
fix scope for function expressions and inner functions by capturing a clone
1 parent 86c4686 commit ab53cbd

File tree

7 files changed

+110
-10
lines changed

7 files changed

+110
-10
lines changed

activations/activations.go

+13
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,19 @@ func (a *Activation[T]) Set(name string, value T) {
107107
a.entries[name] = value
108108
}
109109

110+
func (a *Activation[T]) Clone() *Activation[T] {
111+
clone := NewActivation[T](a.MemoryGauge, a.Parent)
112+
113+
if a.entries != nil {
114+
clone.entries = make(map[string]T, len(a.entries))
115+
for name, value := range a.entries { //nolint:maprange
116+
clone.entries[name] = value
117+
}
118+
}
119+
120+
return clone
121+
}
122+
110123
// Activations is a stack of activation records.
111124
// Each entry represents a new activation record.
112125
//

ast/visitor.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ type Element interface {
3030

3131
type StatementDeclarationVisitor[T any] interface {
3232
VisitVariableDeclaration(*VariableDeclaration) T
33-
VisitFunctionDeclaration(*FunctionDeclaration) T
33+
VisitFunctionDeclaration(declaration *FunctionDeclaration, isStatement bool) T
3434
VisitSpecialFunctionDeclaration(*SpecialFunctionDeclaration) T
3535
VisitCompositeDeclaration(*CompositeDeclaration) T
3636
VisitAttachmentDeclaration(*AttachmentDeclaration) T
@@ -68,7 +68,7 @@ func AcceptDeclaration[T any](declaration Declaration, visitor DeclarationVisito
6868
return visitor.VisitVariableDeclaration(declaration.(*VariableDeclaration))
6969

7070
case ElementTypeFunctionDeclaration:
71-
return visitor.VisitFunctionDeclaration(declaration.(*FunctionDeclaration))
71+
return visitor.VisitFunctionDeclaration(declaration.(*FunctionDeclaration), false)
7272

7373
case ElementTypeSpecialFunctionDeclaration:
7474
return visitor.VisitSpecialFunctionDeclaration(declaration.(*SpecialFunctionDeclaration))
@@ -151,7 +151,7 @@ func AcceptStatement[T any](statement Statement, visitor StatementVisitor[T]) (_
151151
return visitor.VisitVariableDeclaration(statement.(*VariableDeclaration))
152152

153153
case ElementTypeFunctionDeclaration:
154-
return visitor.VisitFunctionDeclaration(statement.(*FunctionDeclaration))
154+
return visitor.VisitFunctionDeclaration(statement.(*FunctionDeclaration), true)
155155

156156
case ElementTypeSpecialFunctionDeclaration:
157157
return visitor.VisitSpecialFunctionDeclaration(statement.(*SpecialFunctionDeclaration))

interpreter/interpreter.go

+21-2
Original file line numberDiff line numberDiff line change
@@ -703,10 +703,10 @@ func (interpreter *Interpreter) VisitProgram(program *ast.Program) {
703703
}
704704

705705
func (interpreter *Interpreter) VisitSpecialFunctionDeclaration(declaration *ast.SpecialFunctionDeclaration) StatementResult {
706-
return interpreter.VisitFunctionDeclaration(declaration.FunctionDeclaration)
706+
return interpreter.VisitFunctionDeclaration(declaration.FunctionDeclaration, false)
707707
}
708708

709-
func (interpreter *Interpreter) VisitFunctionDeclaration(declaration *ast.FunctionDeclaration) StatementResult {
709+
func (interpreter *Interpreter) VisitFunctionDeclaration(declaration *ast.FunctionDeclaration, isStatement bool) StatementResult {
710710

711711
identifier := declaration.Identifier.Identifier
712712

@@ -717,6 +717,25 @@ func (interpreter *Interpreter) VisitFunctionDeclaration(declaration *ast.Functi
717717

718718
// lexical scope: variables in functions are bound to what is visible at declaration time
719719
lexicalScope := interpreter.activations.CurrentOrNew()
720+
if isStatement {
721+
// Cloning the current scope ensures that the function can access variables that are visible,
722+
// but not variables which are declared after the function
723+
// (variable declarations mutate the current activation in place).
724+
//
725+
// For example:
726+
//
727+
// fun foo(a: Int): Int {
728+
// fun bar(): Int {
729+
// return a
730+
// // ^ should refer to the `a` parameter of `foo`,
731+
// // not to the `a` variable declared after `bar`
732+
// }
733+
// let a = 2
734+
// return bar()
735+
// }
736+
//
737+
lexicalScope = lexicalScope.Clone()
738+
}
720739

721740
// make the function itself available inside the function
722741
lexicalScope.Set(identifier, variable)

interpreter/interpreter_expression.go

+18-2
Original file line numberDiff line numberDiff line change
@@ -1315,8 +1315,24 @@ func (interpreter *Interpreter) visitEntries(entries []ast.DictionaryEntry) []Di
13151315

13161316
func (interpreter *Interpreter) VisitFunctionExpression(expression *ast.FunctionExpression) Value {
13171317

1318-
// lexical scope: variables in functions are bound to what is visible at declaration time
1319-
lexicalScope := interpreter.activations.CurrentOrNew()
1318+
// lexical scope: variables in functions are bound to what is visible at declaration time.
1319+
// Cloning the current scope ensures that the function can access variables that are visible,
1320+
// but not variables which are declared after the function
1321+
// (variable declarations mutate the current activation in place).
1322+
//
1323+
// For example:
1324+
//
1325+
// fun foo(a: Int): Int {
1326+
// let bar = fun(): Int {
1327+
// return a
1328+
// // ^ should refer to the `a` parameter of `foo`,
1329+
// // not to the `a` variable declared after `bar`
1330+
// }
1331+
// let a = 2
1332+
// return bar()
1333+
// }
1334+
//
1335+
lexicalScope := interpreter.activations.CurrentOrNew().Clone()
13201336

13211337
functionType := interpreter.Program.Elaboration.FunctionExpressionFunctionType(expression)
13221338

interpreter/misc_test.go

+52
Original file line numberDiff line numberDiff line change
@@ -6864,6 +6864,58 @@ func TestInterpretClosure(t *testing.T) {
68646864
)
68656865
}
68666866

6867+
func TestInterpretClosureScopingFunctionExpression(t *testing.T) {
6868+
t.Parallel()
6869+
6870+
inter := parseCheckAndInterpret(t, `
6871+
fun test(a: Int): Int {
6872+
let bar = fun(): Int {
6873+
return a
6874+
}
6875+
let a = 2
6876+
return bar()
6877+
}
6878+
`)
6879+
6880+
value, err := inter.Invoke("test",
6881+
interpreter.NewUnmeteredIntValueFromInt64(1),
6882+
)
6883+
require.NoError(t, err)
6884+
6885+
AssertValuesEqual(
6886+
t,
6887+
inter,
6888+
interpreter.NewUnmeteredIntValueFromInt64(1),
6889+
value,
6890+
)
6891+
}
6892+
6893+
func TestInterpretClosureScopingInnerFunction(t *testing.T) {
6894+
t.Parallel()
6895+
6896+
inter := parseCheckAndInterpret(t, `
6897+
fun test(a: Int): Int {
6898+
fun bar(): Int {
6899+
return a
6900+
}
6901+
let a = 2
6902+
return bar()
6903+
}
6904+
`)
6905+
6906+
value, err := inter.Invoke("test",
6907+
interpreter.NewUnmeteredIntValueFromInt64(1),
6908+
)
6909+
require.NoError(t, err)
6910+
6911+
AssertValuesEqual(
6912+
t,
6913+
inter,
6914+
interpreter.NewUnmeteredIntValueFromInt64(1),
6915+
value,
6916+
)
6917+
}
6918+
68676919
// TestInterpretCompositeFunctionInvocationFromImportingProgram checks
68686920
// that member functions of imported composites can be invoked from an importing program.
68696921
// See https://github.com/dapperlabs/flow-go/issues/838

sema/check_function.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ func PurityFromAnnotation(purity ast.FunctionPurity) FunctionPurity {
3131

3232
}
3333

34-
func (checker *Checker) VisitFunctionDeclaration(declaration *ast.FunctionDeclaration) (_ struct{}) {
34+
func (checker *Checker) VisitFunctionDeclaration(declaration *ast.FunctionDeclaration, _ bool) (_ struct{}) {
3535
checker.visitFunctionDeclaration(
3636
declaration,
3737
functionDeclarationOptions{
@@ -46,7 +46,7 @@ func (checker *Checker) VisitFunctionDeclaration(declaration *ast.FunctionDeclar
4646
}
4747

4848
func (checker *Checker) VisitSpecialFunctionDeclaration(declaration *ast.SpecialFunctionDeclaration) struct{} {
49-
return checker.VisitFunctionDeclaration(declaration.FunctionDeclaration)
49+
return checker.VisitFunctionDeclaration(declaration.FunctionDeclaration, false)
5050
}
5151

5252
type functionDeclarationOptions struct {

sema/gen/main.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ func (*generator) VisitVariableDeclaration(_ *ast.VariableDeclaration) struct{}
185185
panic("variable declarations are not supported")
186186
}
187187

188-
func (g *generator) VisitFunctionDeclaration(decl *ast.FunctionDeclaration) (_ struct{}) {
188+
func (g *generator) VisitFunctionDeclaration(decl *ast.FunctionDeclaration, _ bool) (_ struct{}) {
189189
if len(g.typeStack) == 0 {
190190
panic("global function declarations are not supported")
191191
}

0 commit comments

Comments
 (0)