/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.training.loss;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.training.loss.Loss;

public class ElasticNetWeightDecay
extends Loss {
    private float lambda1;
    private float lambda2;
    private NDList parameters;

    public ElasticNetWeightDecay(NDList parameters) {
        this("ElasticNetWeightDecay", parameters);
    }

    public ElasticNetWeightDecay(String name, NDList parameters) {
        this(name, parameters, 1.0f);
    }

    public ElasticNetWeightDecay(String name, NDList parameters, float lambda) {
        super(name);
        this.lambda1 = lambda;
        this.lambda2 = lambda;
        this.parameters = parameters;
    }

    public ElasticNetWeightDecay(String name, NDList parameters, float lambda1, float lambda2) {
        super(name);
        this.lambda1 = lambda1;
        this.lambda2 = lambda2;
        this.parameters = parameters;
    }

    private NDArray l1(NDArray w) {
        return w.abs().sum();
    }

    private NDArray l2(NDArray w) {
        return w.square().sum();
    }

    @Override
    public NDArray evaluate(NDList label, NDList prediction) {
        NDManager manager = this.parameters.getManager();
        NDArray sum1 = manager.create(0.0f);
        NDArray sum2 = manager.create(0.0f);
        for (NDArray wi : this.parameters) {
            sum1.addi(this.l1(wi));
            sum2.addi(this.l2(wi));
        }
        return sum1.muli(Float.valueOf(this.lambda1)).addi(sum2.muli(Float.valueOf(this.lambda2)));
    }
}

