Skip to content

Commit

Permalink
Disable More Unsafe GlobalRef Uses (#236)
Browse files Browse the repository at this point in the history
* Preclude assigning to globals

* Integration testing for invalid global ref access

* Integration test cases for invalid global ref usage

* Use helper functionality to throw exception

* Remove redundant test

* Disable using globals with mutable data

* Improve comment

* Remove redundant function

* Permit GlobalRefs to types to be used

* Bump patch

* Revert "Bump patch"

This reverts commit 41eecf7.

* Revert "Permit GlobalRefs to types to be used"

This reverts commit 8ea49bf.

* Revert "Remove redundant function"

This reverts commit fd67497.

* Revert "Improve comment"

This reverts commit 99645fd.

* Revert "Disable using globals with mutable data"

This reverts commit ddbfcf9.

* Bump patch

* Improve docs on global refs

* Run primal for non-differentiable example
  • Loading branch information
willtebbutt authored Aug 29, 2024
1 parent db10fd9 commit ab56094
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 7 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Tapir"
uuid = "07d77754-e150-4737-8c94-cd238a1fb45b"
authors = ["Will Tebbutt, Hong Ge, and contributors"]
version = "0.2.40"
version = "0.2.41"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
39 changes: 33 additions & 6 deletions docs/src/known_limitations.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,39 @@

Tapir.jl has a number of known qualitative limitations, which we document here.

## Mutation of Global Variables

```@meta
DocTestSetup = quote
using Tapir
end
```

While great care is taken in this package to prevent silent errors, this is one edge case that we have yet to provide a satisfactory solution for.
Consider a function of the form:
```jldoctest bad_globalref
julia> const x = Ref(1.0);
julia> function foo(y::Float64)
x[] = y
return x[]
end
foo (generic function with 1 method)
```
`x` is a global variable (if you refer to it in your code, it appears as a `GlobalRef` in the AST or lowered code).
For some technical reasons that are beyond the scope of this section, this package cannot propagate gradient information through `x`.
`foo` is the identity function, so it should have gradient `1.0`.
However, if you differentiate this example, you'll see:
```jldoctest bad_globalref
julia> rule = Tapir.build_rrule(foo, 2.0);
julia> Tapir.value_and_gradient!!(rule, foo, 2.0)
(2.0, (NoTangent(), 0.0))
```
Observe that while it has correctly computed the identity function, the gradient is zero.

The takehome: do not attempt to differentiate functions which modify global state. Uses of globals which does not involve mutating them is fine though.

## Circular References

To a large extent, Tapir.jl does not presently support circular references in an automatic fashion.
Expand Down Expand Up @@ -78,12 +111,6 @@ If we gain evidence that this _is_ often a problem in practice, we'll look into

## Tangent Generation and Pointers

```@meta
DocTestSetup = quote
using Tapir
end
```

_**The Problem**_


Expand Down
7 changes: 7 additions & 0 deletions src/interpreter/s2s_reverse_mode_ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -588,6 +588,13 @@ function make_ad_stmts!(stmt::Expr, line::ID, info::ADInfo)
]
# Expressions which do not require any special treatment.
return ad_stmt_info(line, stmt, nothing)

elseif stmt.head == :(=) && stmt.args[1] isa GlobalRef
msg = "Encountered assignment to global variable: $(stmt.args[1]). " *
"Cannot differentiate through assignments to globals. " *
"Please refactor your code to avoid assigning to a global, for example by " *
"passing the variable in to the function as an argument."
unhandled_feature(msg)
else
# Encountered an expression that we've not seen before.
throw(error("Unrecognised expression $stmt"))
Expand Down
7 changes: 7 additions & 0 deletions src/test_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1482,6 +1482,13 @@ function inlinable_vararg_invoke_call(
return invoke(vararg_test_for_invoke, Tuple{typeof(rows), Vararg{N}}, rows, n1, ns...)
end

# build_rrule should error for this function, because it references a non-const global ref.
__x_for_non_const_global_ref::Float64 = 5.0
function non_const_global_ref(y::Float64)
global __x_for_non_const_global_ref = y
return __x_for_non_const_global_ref
end

function generate_test_functions()
return Any[
(false, :allocs, nothing, const_tester),
Expand Down
17 changes: 17 additions & 0 deletions test/interpreter/s2s_reverse_mode_ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ end
@testset "non-const" begin
global_ref = GlobalRef(S2SGlobals, :non_const_global)
stmt_info = make_ad_stmts!(global_ref, ID(), info)
@test Tapir.TestResources.non_const_global_ref(5.0) == 5.0 # run primal
@test stmt_info isa Tapir.ADStmtInfo
@test Meta.isexpr(last(stmt_info.fwds)[2].stmt, :call)
@test last(stmt_info.fwds)[2].stmt.args[1] == Tapir.__verify_const
Expand All @@ -166,6 +167,12 @@ end
)
end
@testset "Expr" begin
@testset "assignment to GlobalRef" begin
@test_throws(
Tapir.UnhandledLanguageFeatureException,
make_ad_stmts!(Expr(:(=), GlobalRef(Main, :a), 5.0), ID(), info)
)
end
@testset "copyast" begin
stmt = Expr(:copyast, QuoteNode(:(hi)))
ad_stmts = make_ad_stmts!(stmt, ID(), info)
Expand Down Expand Up @@ -243,6 +250,16 @@ end
# @profview(run_many_times(500, f, rule, fwds_args, out))
end

@testset "integration testing for invalid global ref errors" begin
@test_throws(
Tapir.UnhandledLanguageFeatureException,
Tapir.build_rrule(
Tapir.TapirInterpreter(),
Tuple{typeof(Tapir.TestResources.non_const_global_ref), Float64},
)
)
end

# Tests designed to prevent accidentally re-introducing issues which we have fixed.
@testset "regression tests" begin

Expand Down

2 comments on commit ab56094

@willtebbutt
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/114084

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.2.41 -m "<description of version>" ab56094b27a33091dc5179d437657b3dc7378efd
git push origin v0.2.41

Please sign in to comment.