.. currentmodule:: mxnet.autograd
The autograd
package enables automatic
differentiation of NDArray operations.
In machine learning applications,
autograd
is often used to calculate the gradients
of loss functions with respect to parameters.
autograd
records computation history on the fly to calculate gradients later.
This is only enabled inside a with autograd.record():
block.
A with auto_grad.pause()
block can be used inside a record()
block
to temporarily disable recording.
To compute gradient with respect to an NDArray
x
, first call x.attach_grad()
to allocate space for the gradient. Then, start a with autograd.record()
block,
and do some computation. Finally, call backward()
on the result:
>>> x = mx.nd.array([1,2,3,4])
>>> x.attach_grad()
>>> with mx.autograd.record():
... y = x * x + 1
>>> y.backward()
>>> print(x.grad)
[ 2. 4. 6. 8.]
<NDArray 4 @cpu(0)>
Some operators (Dropout, BatchNorm, etc) behave differently in
training and making predictions.
This can be controlled with train_mode
and predict_mode
scope.
By default, MXNet is in predict_mode
.
A with autograd.record()
block by default turns on train_mode
(equivalent to with autograd.record(train_mode=True)
).
To compute a gradient in prediction mode (as when generating adversarial examples),
call record with train_mode=False
and then call backward(train_mode=False)
Although training usually coincides with recording,
this isn't always the case.
To control training vs predict_mode without changing
recording vs not recording,
use a with autograd.train_mode():
or with autograd.predict_mode():
block.
Detailed tutorials are available in Part 1 of the MXNet gluon book.
.. autosummary::
:nosignatures:
record
pause
train_mode
predict_mode
backward
set_training
is_training
set_recording
is_recording
mark_variables
Function
.. automodule:: mxnet.autograd
:members:
Can you improve this documentation? These fine people already did:
Sheng Zha, Aaron Markham, Aston Zhang, Eric Junyuan Xie & Yao WangEdit on GitHub
cljdoc is a website building & hosting documentation for Clojure/Script libraries
× close