trait Op extends AnyRef
Represents an operation in the computational graph
Short outline of reverse-mode autograd for scalar values:
y = f1 ∘ f2 ∘ ... ∘ fn
One of these subexpressions, f_i, has value w2 and argument w1. By the chain rule we can write dy/dw1 = dy/dw2 * dw2/dw1. Here dw2/dw1 is the Jacobian of f_i evaluated at the current value of w1, and dy/dw2 is the Jacobian of y with respect to w2, evaluated at the current value of w2.
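To make the chain rule concrete, here is a small plain-Scala sketch (not part of lamp) for y = sin(x * x), i.e. f1 = sin and f2(x) = x * x. The names `forward`, `gradient` and `numericGradient` are illustrative only. The central finite difference is there solely to confirm the hand-derived Jacobians.

```scala
// Chain rule for y = f1(f2(w1)) with f2(w1) = w1 * w1 and f1(w2) = sin(w2).
// Forward pass: compute the intermediate value w2 and the output y.
def forward(w1: Double): (Double, Double) = {
  val w2 = w1 * w1
  val y = math.sin(w2)
  (w2, y)
}

// Backward pass: dy/dw1 = dy/dw2 * dw2/dw1, with both Jacobians
// evaluated at the values computed in the forward pass.
def gradient(w1: Double): Double = {
  val (w2, _) = forward(w1)
  val dyDw2 = math.cos(w2) // Jacobian of f1 at the current w2
  val dw2Dw1 = 2.0 * w1    // Jacobian of f2 at the current w1
  dyDw2 * dw2Dw1
}

// Central finite difference, used only to check the analytic result.
def numericGradient(w1: Double, h: Double = 1e-6): Double =
  (forward(w1 + h)._2 - forward(w1 - h)._2) / (2 * h)

val x = 0.7
println(s"analytic = ${gradient(x)}, numeric = ${numericGradient(x)}")
```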
The current values of w1 and w2 are computed in the forward pass. The value dy/dy is 1, and starting from it dy/dw2 is computed recursively in the backward pass. The Jacobian function dw2/dw1 is derived symbolically and hard-coded.
The anonymous function which Ops must implement is dy/dw2 => dy/dw2 * dw2/dw1. The argument of that function, dy/dw2, comes down from the backward pass; the Op must compute dy/dw2 * dw2/dw1.
The shape of dy/dw2 is the shape of the value of the operation (w2).
The shape of dy/dw2 * dw2/dw1 is the shape of the parameter variable with respect to which the derivative is taken, i.e. w1, since we are computing dy/dw1.
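For non-scalar operations the hard-coded function typically computes the product with the Jacobian directly. As an illustration in plain Scala (hypothetical names `Mat`, `matVec`, `matVecBackward`; not lamp code), take w2 = A w1 with A of size 2 x 3. Then dw2/dw1 = A, so dy/dw2 * dw2/dw1 is the row vector dy/dw2 times A. The incoming gradient has the shape of w2 (length 2) and the result has the shape of w1 (length 3).

```scala
type Mat = Array[Array[Double]]

// Forward pass: w2 = A w1, with A of size m x n and w1 of length n (w2 has length m).
def matVec(a: Mat, w1: Array[Double]): Array[Double] =
  a.map(row => row.indices.map(j => row(j) * w1(j)).sum)

// Backward closure dy/dw2 => dy/dw2 * dw2/dw1, hard-coded from the symbolic
// Jacobian dw2/dw1 = A: entry j of the result is sum_i g(i) * A(i)(j).
def matVecBackward(a: Mat): Array[Double] => Array[Double] = { g =>
  Array.tabulate(a.head.length)(j => a.indices.map(i => g(i) * a(i)(j)).sum)
}

val a: Mat = Array(Array(1.0, 2.0, 3.0), Array(4.0, 5.0, 6.0))
val w2 = matVec(a, Array(1.0, 1.0, 1.0))        // length 2: the shape of w2
val dyDw1 = matVecBackward(a)(Array(1.0, -1.0)) // length 3: the shape of w1
```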
How to implement an operation
import lamp.Scope
import lamp.autograd.{Op, Variable}

// Each concrete realization of the operation corresponds to an instance of an Op
// The Op instance holds handles to the input variables (here a, b), to be used in the backward pass
// The forward pass is effectively done in the constructor of the Op
// The backward pass is triggered and orchestrated by [[lamp.autograd.Variable.backward]]
case class Mult(scope: Scope, a: Variable, b: Variable) extends Op {

  // List all parameters which support partial derivatives, here both a and b
  val params = List(
    // partial derivative of the first argument
    a.zipBackward { (p, out) =>
      // p is the incoming partial derivative, out is the tensor the result is accumulated into
      // Intermediate tensors are released due to the enclosing Scope.root
      Scope.root { implicit scope => out += (p * b.value).unbroadcast(a.sizes) }
    },
    // partial derivative of the second argument
    b.zipBackward { (p, out) =>
      Scope.root { implicit scope => out += (p * a.value).unbroadcast(b.sizes) }
    }
  )

  // The value of this operation, i.e. the forward pass
  val value = Variable(this, a.value.*(b.value)(scope))(scope)
}
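The `unbroadcast` calls in `Mult` enforce the shape rule when `a` and `b` have different but broadcast-compatible shapes: `p * b.value` has the broadcast shape of the output and must be summed back to the shape of the input. The plain-Scala sketch below mimics this for a 2 x 3 matrix `a` times a length-3 vector `b` broadcast along rows. It does not use lamp; `mulBroadcast`, `sumRows`, `gradA` and `gradB` are hypothetical helpers.

```scala
type Mat = Array[Array[Double]]

// Forward pass: out(i)(j) = a(i)(j) * b(j); b is broadcast along the rows of a.
def mulBroadcast(a: Mat, b: Array[Double]): Mat =
  a.map(row => Array.tabulate(row.length)(j => row(j) * b(j)))

// "Unbroadcast": sum over the broadcast row dimension, leaving one entry per column.
def sumRows(m: Mat): Array[Double] =
  Array.tabulate(m.head.length)(j => m.map(_(j)).sum)

// Partial derivative for a: p * b already has a's shape (2 x 3), nothing to sum.
def gradA(p: Mat, b: Array[Double]): Mat = mulBroadcast(p, b)

// Partial derivative for b: p * a has shape 2 x 3 and is reduced to b's shape (3).
def gradB(p: Mat, a: Mat): Array[Double] =
  sumRows(p.zip(a).map { case (pRow, aRow) => pRow.zip(aRow).map { case (x, y) => x * y } })

val a: Mat = Array(Array(1.0, 2.0, 3.0), Array(4.0, 5.0, 6.0))
val b = Array(10.0, 20.0, 30.0)
val p: Mat = Array.fill(2, 3)(1.0) // incoming gradient, shaped like the output
val dA = gradA(p, b)
val dB = gradB(p, a)
```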
Abstract Value Members
- abstract val params: List[(Variable, (STen, STen) => Unit)]
Implementation of the backward pass
A list of input variables, each paired with an anonymous function computing the respective partial derivative. In the notation of the documentation of the trait lamp.autograd.Op, this function is dy/dw2 => dy/dw2 * dw2/dw1. Its first argument is the incoming partial derivative (dy/dw2); its second argument is the output tensor into which the result (dy/dw2 * dw2/dw1) is accumulated (added). If the operation does not support computing the partial derivative for some of its arguments, do not include those arguments in this list.
- See also
The documentation of the trait lamp.autograd.Op for more details and an example.
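The contract of `params` can be mimicked without lamp to show why results are added into the output tensor rather than assigned. In this hypothetical sketch, `Var`, `Backward`, `ElemMult`, `Masked` and `runBackward` are stand-ins, and tensors are plain `Array[Double]`. When the same variable appears twice, as in x * x, both entries write into the same gradient buffer and the contributions sum to 2x. `Masked` treats its mask as a constant and therefore leaves it out of `params`.

```scala
// Hypothetical stand-in for a variable: a value plus a gradient buffer of the same shape.
final class Var(val value: Array[Double]) {
  val grad: Array[Double] = Array.fill(value.length)(0.0)
}

// (incoming partial derivative p, output buffer out) => Unit, adding into out.
type Backward = (Array[Double], Array[Double]) => Unit

// Elementwise product: both inputs support partial derivatives, so both are listed.
final class ElemMult(a: Var, b: Var) {
  val params: List[(Var, Backward)] = List(
    (a, (p: Array[Double], out: Array[Double]) =>
      for (i <- out.indices) out(i) += p(i) * b.value(i)),
    (b, (p: Array[Double], out: Array[Double]) =>
      for (i <- out.indices) out(i) += p(i) * a.value(i))
  )
  val value: Array[Double] = Array.tabulate(a.value.length)(i => a.value(i) * b.value(i))
}

// Only differentiable in x: the mask is treated as a constant and is not listed.
final class Masked(x: Var, mask: Var) {
  val params: List[(Var, Backward)] = List(
    (x, (p: Array[Double], out: Array[Double]) =>
      for (i <- out.indices) out(i) += p(i) * mask.value(i))
  )
  val value: Array[Double] = Array.tabulate(x.value.length)(i => x.value(i) * mask.value(i))
}

// Hand the incoming partial derivative p to the backward function of every listed input.
def runBackward(params: List[(Var, Backward)], p: Array[Double]): Unit =
  params.foreach { case (v, f) => f(p, v.grad) }

val x = new Var(Array(2.0, 3.0))
runBackward(new ElemMult(x, x).params, Array(1.0, 1.0))
println(x.grad.mkString(", ")) // 4.0, 6.0

val m = new Var(Array(1.0, 0.0))
val z = new Var(Array(5.0, 7.0))
runBackward(new Masked(z, m).params, Array(1.0, 1.0))
```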
- abstract val value: Variable
The value of this operation
Concrete Value Members
- final def !=(arg0: Any): Boolean
- Definition Classes
- AnyRef → Any
- final def ##: Int
- Definition Classes
- AnyRef → Any
- final def ==(arg0: Any): Boolean
- Definition Classes
- AnyRef → Any
- final def asInstanceOf[T0]: T0
- Definition Classes
- Any
- def clone(): AnyRef
- Attributes
- protected[lang]
- Definition Classes
- AnyRef
- Annotations
- @throws(classOf[java.lang.CloneNotSupportedException]) @IntrinsicCandidate() @native()
- final def eq(arg0: AnyRef): Boolean
- Definition Classes
- AnyRef
- def equals(arg0: AnyRef): Boolean
- Definition Classes
- AnyRef → Any
- final def getClass(): Class[_ <: AnyRef]
- Definition Classes
- AnyRef → Any
- Annotations
- @IntrinsicCandidate() @native()
- def hashCode(): Int
- Definition Classes
- AnyRef → Any
- Annotations
- @IntrinsicCandidate() @native()
- final def isInstanceOf[T0]: Boolean
- Definition Classes
- Any
- val joinedBackward: Option[(STen) => Unit]
- final def ne(arg0: AnyRef): Boolean
- Definition Classes
- AnyRef
- final def notify(): Unit
- Definition Classes
- AnyRef
- Annotations
- @IntrinsicCandidate() @native()
- final def notifyAll(): Unit
- Definition Classes
- AnyRef
- Annotations
- @IntrinsicCandidate() @native()
- final def synchronized[T0](arg0: => T0): T0
- Definition Classes
- AnyRef
- def toString(): String
- Definition Classes
- AnyRef → Any
- final def wait(arg0: Long, arg1: Int): Unit
- Definition Classes
- AnyRef
- Annotations
- @throws(classOf[java.lang.InterruptedException])
- final def wait(arg0: Long): Unit
- Definition Classes
- AnyRef
- Annotations
- @throws(classOf[java.lang.InterruptedException]) @native()
- final def wait(): Unit
- Definition Classes
- AnyRef
- Annotations
- @throws(classOf[java.lang.InterruptedException])
Deprecated Value Members
- def finalize(): Unit
- Attributes
- protected[lang]
- Definition Classes
- AnyRef
- Annotations
- @throws(classOf[java.lang.Throwable]) @Deprecated
- Deprecated
(Since version 9)