package symjava.math; import java.util.List; import symjava.matrix.SymVector; import symjava.symbolic.Derivative; import symjava.symbolic.Expr; import symjava.symbolic.Func; import symjava.symbolic.utils.Utils; /** * Gradient of a function or expression * */ public class Grad extends SymVector { public Expr[] args = null; protected Func func = null; /** * Construct an instance directly from data and args * * @param data * @param args */ public Grad(SymVector data, Expr... args) { for(Expr e : data) this.data.add(e); this.args = args; } public Grad(Expr f) { if(f instanceof Func) { if(f.isAbstract()) { this.func = (Func)f; for(Expr x : this.func.args) { data.add(f.diff(x)); } } else { for(Expr x : ((Func)f).args) { data.add(f.diff(x)); } } } else { List args = Utils.extractSymbols(f); for(Expr x : args) { data.add(f.diff(x)); } } } public Grad(Expr f, Expr[] args) { if(f instanceof Func) { if(f.isAbstract()) { this.func = (Func)f; for(Expr x : this.func.args) { data.add(f.diff(x)); } } else { for(Expr x : ((Func)f).args) { data.add(f.diff(x)); } } } else { for(Expr x : args) { data.add(f.diff(x)); } } } /** * Functional Gradient * * @param F * @param fs * @param dfs */ public Grad(Expr F, Expr[] fs, Expr[] dfs) { if(fs.length != dfs.length) throw new IllegalArgumentException(); if(F instanceof Func) { this.func = (Func)F; } for(int i=0; i