FluxML/Zygote.jl

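Install Zygote with the Julia package manager (press ] at the julia> prompt to enter package mode):

] add Zygote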

Zygote provides source-to-source automatic differentiation (AD) in Julia and is the next-generation AD system for the Flux differentiable programming framework. For more details and benchmarks of Zygote's technique, see our paper. Flux offers more interesting examples of Zygote in use; the documentation here focuses on internals and advanced AD usage.

Zygote supports Julia 1.6 onwards, but we highly recommend using Julia 1.8 or later.

julia> using Zygote

julia> f(x) = 5x + 3

julia> f(10), f'(10)
(53, 5.0)

julia> @code_llvm f'(10)
define i64 @"julia_#625_38792"(i64) {
top:
  ret i64 5
}
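
The f' syntax behaves like taking the first element of gradient, which returns a tuple with one entry per argument. A small sketch, using only gradient and the adjoint syntax shown above; h is a hypothetical example function:

```julia
using Zygote

f(x) = 5x + 3
h(x) = 3x^2 + 2x + 1   # hypothetical example function

gradient(f, 10.0)                 # (5.0,): one entry per argument
h'(2.0)                           # 14.0, i.e. 6x + 2 evaluated at x = 2
f'(10.0) == gradient(f, 10.0)[1]  # true

# Several arguments give a tuple of partial derivatives
gradient((x, y) -> x * y + sin(y), 2.0, 0.0)  # (0.0, 3.0)
```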

"Source-to-source" means that Zygote hooks into Julia's compiler and generates the backward pass for you, as if you had written it by hand.
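
To look at the two passes separately, Zygote.pullback returns the function's value together with a closure (the pullback) that maps an output sensitivity back to input gradients. A minimal sketch reusing f from the example above:

```julia
using Zygote

f(x) = 5x + 3

# Forward pass: compute the value and keep what the backward pass needs
y, back = Zygote.pullback(f, 10.0)

y          # 53.0
back(1.0)  # (5.0,): seeding with 1.0 gives the ordinary gradient
back(2.0)  # (10.0,): the pullback is linear in the seed
```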

Zygote supports the flexibility and dynamism of the Julia language, including control flow, recursion, closures, structs, dictionaries, and more. Mutation and exception handling are currently not supported.

julia> fs = Dict("sin" => sin, "cos" => cos, "tan" => tan);

julia> gradient(x -> fs[readline()](x), 1)
sin
(0.5403023058681398,)
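
The same dynamism covers branching, recursion, closures and user-defined structs. The sketch below uses hypothetical example functions (piecewise, pow, make_scaled) and a hypothetical Affine struct; for a struct argument, Zygote returns the gradient as a NamedTuple of per-field gradients:

```julia
using Zygote

# Branching: only the branch that actually runs is differentiated
piecewise(x) = x > 0 ? x^2 : -x

# Recursion: x^n built up by repeated multiplication
function pow(x, n)
    n <= 0 && return one(x)
    return x * pow(x, n - 1)
end

# Closures: `scale` is captured from the enclosing scope
make_scaled(scale) = x -> scale * sin(x)

# Structs: a callable affine map with two fields
struct Affine
    a::Float64
    b::Float64
end
(m::Affine)(x) = m.a * x + m.b

gradient(piecewise, 3.0)                 # (6.0,)
gradient(piecewise, -3.0)                # (-1.0,)
gradient(x -> pow(x, 3), 2.0)            # (12.0,)
gradient(make_scaled(2.0), 0.0)          # (2.0,)
gradient(m -> m(3.0), Affine(2.0, 1.0))  # ((a = 3.0, b = 1.0),)
```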

Zygote benefits from the ChainRules.jl rule set. Custom gradients can be defined by extending ChainRulesCore.jl's rrule:

julia> using ChainRulesCore

julia> add(a, b) = a + b

julia> function ChainRulesCore.rrule(::typeof(add), a, b)
           add_pb(dy) = (NoTangent(), dy, dy)
           return add(a, b), add_pb
       end
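
Once an rrule is defined, gradient uses it instead of differentiating the function body. The sketch below repeats add and also defines a hypothetical clipgrad function that is the identity on the forward pass but whose rrule clamps the incoming gradient to [-1, 1], so the effect of the custom rule is visible in the result:

```julia
using Zygote, ChainRulesCore

add(a, b) = a + b
function ChainRulesCore.rrule(::typeof(add), a, b)
    add_pb(dy) = (NoTangent(), dy, dy)  # d(a + b)/da = d(a + b)/db = 1
    return add(a, b), add_pb
end

# Hypothetical helper: identity forward, clamped gradient backward
clipgrad(x) = x
function ChainRulesCore.rrule(::typeof(clipgrad), x)
    clipgrad_pb(dy) = (NoTangent(), clamp(dy, -1, 1))
    return clipgrad(x), clipgrad_pb
end

gradient(add, 1.0, 2.0)                      # (1.0, 1.0)
gradient((a, b) -> add(a * b, b), 3.0, 4.0)  # (4.0, 4.0)
gradient(x -> 10 * clipgrad(x), 2.0)         # (1.0,) rather than (10.0,)
```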

To support large machine learning models with many parameters, Zygote can differentiate with respect to implicitly used parameters, not just explicit function arguments.

julia> W, b = rand(2, 3), rand(2);

julia> predict(x) = W*x .+ b;

julia> g = gradient(Params([W, b])) do
         sum(predict([1,2,3]))
       end
Grads(...)

julia> g[W], g[b]
([1.0 2.0 3.0; 1.0 2.0 3.0], [1.0, 1.0])
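
The same gradients can also be taken with respect to explicit function arguments, without Params. A sketch that mirrors the Params example, assuming the same shapes for W and b and the same input [1, 2, 3]:

```julia
using Zygote

W, b = rand(2, 3), rand(2)
x = [1, 2, 3]

# Pass W and b as arguments instead of collecting them in Params
dW, db = gradient((W, b) -> sum(W * x .+ b), W, b)

dW  # same values as g[W]: [1.0 2.0 3.0; 1.0 2.0 3.0]
db  # same values as g[b]: [1.0, 1.0]
```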