package deepboof.impl.backward.standard;

import com.google.firebase.remoteconfig.FirebaseRemoteConfig;
import deepboof.backward.DSpatialBatchNorm;
import deepboof.misc.TensorOps;
import deepboof.tensors.Tensor_F64;
import java.util.List;

/**
 * Spatial batch normalization over {@link Tensor_F64} with both forward and backward passes.
 *
 * <p>The input is walked as {@code (miniBatch, channel, pixels...)}: axis 1 is the channel and every
 * axis from 2 onward is flattened into one contiguous run of {@link #numPixels} values per
 * (mini-batch entry, channel). Mean and standard deviation are computed per channel over all
 * mini-batch entries and all pixels of that channel. NOTE(review): the loops use the inherited
 * {@code miniBatchSize} field rather than {@code input.length(0)}; presumably the base class keeps
 * them equal - confirm.
 *
 * <p>When gamma/beta are used, {@code params} is read as interleaved per-channel pairs
 * {@code [gamma_0, beta_0, gamma_1, beta_1, ...]}.
 */
public class DSpatialBatchNorm_F64 extends BaseDBatchNorm_F64 implements DSpatialBatchNorm<Tensor_F64> {
    /** Number of values contributing to each channel's mean: {@code miniBatchSize * numPixels}. */
    double M;
    /** Divisor applied to the summed squared deviations, {@code M - 1} (unbiased variance). */
    double M_var;
    /** Size of axis 1 of the most recent {@code _forward} input. */
    int numChannels;
    /** Product of all axes from 2 onward of the most recent {@code _forward} input. */
    int numPixels;

    /**
     * @param requiresGammaBeta forwarded unchanged to the base-class constructor; presumably selects
     *     whether learnable gamma/beta parameters are used - confirm against BaseDBatchNorm_F64
     */
    public DSpatialBatchNorm_F64(boolean requiresGammaBeta) {
        super(requiresGammaBeta);
    }

    /**
     * Writes {@code xhat * gamma + beta} into {@code output}, reading the normalized values cached in
     * {@code tensorXhat} and the per-channel gamma/beta pairs from {@code params}.
     */
    private void applyGammaBeta(Tensor_F64 output) {
        int indexOut = output.startIndex;
        int indexXhat = 0;
        for (int batch = 0; batch < miniBatchSize; batch++) {
            for (int channel = 0; channel < numChannels; channel++) {
                // Offset by startIndex, consistent with forwardEvaluate's access to params.
                int indexParam = params.startIndex + channel * 2;
                double gamma = params.d[indexParam];
                double beta = params.d[indexParam + 1];
                for (int pixel = 0; pixel < numPixels; pixel++) {
                    output.d[indexOut++] = tensorXhat.d[indexXhat++] * gamma + beta;
                }
            }
        }
    }

    /**
     * Computes per-channel mean and standard deviation of {@code input}, caches {@code x - mean} in
     * {@code tensorDiffX}, and writes {@code (x - mean) / std} into {@code tensorXhat}.
     *
     * <p>The stored deviation is {@code sqrt(sum((x - mean)^2) / M_var + EPS)}.
     */
    private void computeStatisticsAndNormalize(Tensor_F64 input) {
        tensorMean.zero();
        tensorStd.zero();
        tensorXhat.zero();

        // Pass 1: per-channel sum, then divide by the element count to get the mean.
        int indexIn = input.startIndex;
        for (int batch = 0; batch < miniBatchSize; batch++) {
            for (int channel = 0; channel < numChannels; channel++) {
                double sum = 0.0;
                for (int pixel = 0; pixel < numPixels; pixel++) {
                    sum += input.d[indexIn++];
                }
                tensorMean.d[channel] += sum;
            }
        }
        for (int channel = 0; channel < numChannels; channel++) {
            tensorMean.d[channel] /= M;
        }

        // Pass 2: cache x - mean (reused by the backward pass) and accumulate squared deviations.
        indexIn = input.startIndex;
        int indexTensor = 0;
        for (int batch = 0; batch < miniBatchSize; batch++) {
            for (int channel = 0; channel < numChannels; channel++) {
                double mean = tensorMean.d[channel];
                double sumSquares = 0.0;
                for (int pixel = 0; pixel < numPixels; pixel++) {
                    double diff = input.d[indexIn++] - mean;
                    tensorDiffX.d[indexTensor++] = diff;
                    sumSquares += diff * diff;
                }
                tensorStd.d[channel] += sumSquares;
            }
        }
        for (int channel = 0; channel < numChannels; channel++) {
            // EPS keeps the deviation strictly positive so the division below is safe.
            tensorStd.d[channel] = Math.sqrt(tensorStd.d[channel] / M_var + EPS);
        }

        // Pass 3: normalize the cached deviations.
        indexTensor = 0;
        for (int batch = 0; batch < miniBatchSize; batch++) {
            for (int channel = 0; channel < numChannels; channel++) {
                double std = tensorStd.d[channel];
                for (int pixel = 0; pixel < numPixels; pixel++) {
                    tensorXhat.d[indexTensor] = tensorDiffX.d[indexTensor] / std;
                    indexTensor++;
                }
            }
        }
    }

    /**
     * Learning-mode forward pass: recomputes batch statistics from {@code input}, then writes either
     * the scaled/shifted or the plain normalized values into {@code output}.
     */
    private void forwardLearning(Tensor_F64 input, Tensor_F64 output) {
        computeStatisticsAndNormalize(input);
        if (requiresGammaBeta) {
            applyGammaBeta(output);
        } else {
            output.setTo(tensorXhat);
        }
    }

    /**
     * Computes the gradient with respect to each channel's mean into {@code tensorDMean}:
     * {@code -sum(dXhat) / std - 2 * dVar * sum(x - mean) / M_var}.
     *
     * <p>Requires {@code tensorDVar} to have been filled by {@link #partialVariance()}.
     * {@code tensorTmp} is used as scratch space for {@code sum(x - mean)}.
     */
    private void partialMean() {
        tensorDMean.zero();
        tensorTmp.zero();
        int index = 0;
        for (int batch = 0; batch < miniBatchSize; batch++) {
            for (int channel = 0; channel < numChannels; channel++) {
                double sumDiff = 0.0;
                double negSumDXhat = 0.0;
                for (int pixel = 0; pixel < numPixels; pixel++) {
                    sumDiff += tensorDiffX.d[index];
                    negSumDXhat -= tensorDXhat.d[index];
                    index++;
                }
                tensorTmp.d[channel] += sumDiff;
                tensorDMean.d[channel] += negSumDXhat;
            }
        }
        for (int channel = 0; channel < numChannels; channel++) {
            tensorDMean.d[channel] /= tensorStd.d[channel];
            tensorDMean.d[channel] -= tensorDVar.d[channel] * 2.0 * tensorTmp.d[channel] / M_var;
        }
    }

    /**
     * Accumulates the gamma/beta gradients into {@code dParams} as interleaved pairs:
     * {@code dGamma = sum(dOutput * xhat)} and {@code dBeta = sum(dOutput)} per channel.
     *
     * @param dParams destination for the parameter gradient; zeroed before accumulation
     * @param dOutput gradient with respect to this layer's output
     */
    private void partialParameters(Tensor_F64 dParams, Tensor_F64 dOutput) {
        dParams.zero();
        int indexOut = dOutput.startIndex;
        int indexTensor = 0;
        for (int batch = 0; batch < miniBatchSize; batch++) {
            for (int channel = 0; channel < numChannels; channel++) {
                double dGamma = 0.0;
                double dBeta = 0.0;
                for (int pixel = 0; pixel < numPixels; pixel++) {
                    double grad = dOutput.d[indexOut++];
                    dGamma += tensorXhat.d[indexTensor++] * grad;
                    dBeta += grad;
                }
                // Honor the destination's offset, the same way the other tensors are indexed.
                int indexParam = dParams.startIndex + channel * 2;
                dParams.d[indexParam] += dGamma;
                dParams.d[indexParam + 1] += dBeta;
            }
        }
    }

    /**
     * Computes the gradient with respect to each channel's variance into {@code tensorDVar}:
     * {@code sum(dXhat * (x - mean)) / (-2 * std^3)}.
     */
    private void partialVariance() {
        tensorDVar.zero();
        int index = 0;
        for (int batch = 0; batch < miniBatchSize; batch++) {
            for (int channel = 0; channel < numChannels; channel++) {
                double sum = 0.0;
                for (int pixel = 0; pixel < numPixels; pixel++) {
                    sum += tensorDXhat.d[index] * tensorDiffX.d[index];
                    index++;
                }
                tensorDVar.d[channel] += sum;
            }
        }
        for (int channel = 0; channel < numChannels; channel++) {
            double std = tensorStd.d[channel];
            tensorDVar.d[channel] /= std * std * std * (-2.0);
        }
    }

    /**
     * Writes the gradient with respect to the input into {@code dInput}:
     * {@code dXhat / std + 2 * dVar * (x - mean) / M_var + dMean / M}.
     */
    private void partialX(Tensor_F64 dInput) {
        int indexOut = dInput.startIndex;
        int indexTensor = 0;
        for (int batch = 0; batch < miniBatchSize; batch++) {
            for (int channel = 0; channel < numChannels; channel++) {
                double std = tensorStd.d[channel];
                double dVar = tensorDVar.d[channel];
                double meanTerm = tensorDMean.d[channel] / M; // constant across the channel's pixels
                for (int pixel = 0; pixel < numPixels; pixel++) {
                    dInput.d[indexOut++] = tensorDXhat.d[indexTensor] / std
                            + 2.0 * dVar * tensorDiffX.d[indexTensor] / M_var
                            + meanTerm;
                    indexTensor++;
                }
            }
        }
    }

    /**
     * Writes {@code dOutput * gamma} into {@code tensorDXhat}, i.e. the gradient with respect to the
     * normalized values when gamma/beta are applied.
     */
    private void partialXHat(Tensor_F64 dOutput) {
        int indexIn = dOutput.startIndex;
        int indexTensor = 0;
        for (int batch = 0; batch < miniBatchSize; batch++) {
            for (int channel = 0; channel < numChannels; channel++) {
                double gamma = params.d[params.startIndex + channel * 2];
                for (int pixel = 0; pixel < numPixels; pixel++) {
                    tensorDXhat.d[indexTensor++] = dOutput.d[indexIn++] * gamma;
                }
            }
        }
    }

    /**
     * Backward pass. Uses {@code tensorDiffX}, {@code tensorStd} and {@code tensorXhat} as written
     * by the most recent learning-mode {@code _forward}.
     *
     * @param input the forward input; only its shape is read here
     * @param dout gradient with respect to this layer's output
     * @param gradientInput receives the gradient with respect to {@code input}
     * @param gradientParameters when gamma/beta are required, element 0 receives their gradient
     */
    @Override // deepboof.impl.backward.standard.BaseDFunction
    public void _backwards(Tensor_F64 input, Tensor_F64 dout, Tensor_F64 gradientInput, List<Tensor_F64> gradientParameters) {
        tensorDXhat.reshape(input.shape);
        if (requiresGammaBeta) {
            partialXHat(dout);
        } else {
            tensorDXhat.setTo(dout);
        }
        // Order matters: partialMean reads tensorDVar, partialX reads tensorDVar and tensorDMean.
        partialVariance();
        partialMean();
        partialX(gradientInput);
        if (requiresGammaBeta) {
            partialParameters(gradientParameters.get(0), dout);
        }
    }

    /**
     * Forward pass. Records the channel/pixel layout of {@code input}, then dispatches to the learning
     * or evaluation path based on {@code learningMode}.
     *
     * @throws IllegalArgumentException if axis 0 of {@code input} has length 1 or less
     */
    @Override // deepboof.impl.forward.standard.BaseFunction
    public void _forward(Tensor_F64 input, Tensor_F64 output) {
        if (input.length(0) <= 1) {
            throw new IllegalArgumentException("There must be more than 1 minibatch");
        }
        tensorDiffX.reshape(input.shape);
        tensorXhat.reshape(input.shape);
        numChannels = input.length(1);
        numPixels = TensorOps.outerLength(input.shape, 2);
        M = miniBatchSize * numPixels;
        M_var = M - 1.0;
        if (learningMode) {
            forwardLearning(input, output);
        } else {
            forwardEvaluate(input, output);
        }
    }

    /**
     * Shape of the per-channel statistic tensors: a single axis whose length is {@code shapeInput[0]}.
     * NOTE(review): presumably {@code shapeInput} excludes the mini-batch axis, so index 0 is the
     * channel count - confirm against BaseDBatchNorm_F64.
     */
    @Override // deepboof.impl.backward.standard.BaseDBatchNorm_F64
    protected int[] createShapeVariables(int[] shapeInput) {
        return new int[]{shapeInput[0]};
    }

    /**
     * Evaluation-mode forward pass using the mean/std already stored in {@code tensorMean} and
     * {@code tensorStd}: writes {@code (x - mean) * (gamma / std) + beta} when {@code hasGammaBeta()}
     * is true, otherwise {@code (x - mean) / std}.
     *
     * <p>All axes from 2 onward are flattened into the per-channel pixel run, matching
     * {@code _forward}, so inputs of any rank of at least 2 are handled, not only 4-D ones.
     */
    public void forwardEvaluate(Tensor_F64 input, Tensor_F64 output) {
        int channels = input.length(1);
        int pixels = TensorOps.outerLength(input.shape, 2);
        boolean useGammaBeta = hasGammaBeta();

        int indexIn = input.startIndex;
        int indexOut = output.startIndex;
        for (int batch = 0; batch < miniBatchSize; batch++) {
            for (int channel = 0; channel < channels; channel++) {
                double mean = tensorMean.d[channel];
                double std = tensorStd.d[channel];
                int end = indexIn + pixels;
                if (useGammaBeta) {
                    int indexParam = params.startIndex + channel * 2;
                    double gamma = params.d[indexParam];
                    double beta = params.d[indexParam + 1];
                    double scale = gamma / std;
                    while (indexIn < end) {
                        output.d[indexOut++] = (input.d[indexIn++] - mean) * scale + beta;
                    }
                } else {
                    while (indexIn < end) {
                        output.d[indexOut++] = (input.d[indexIn++] - mean) / std;
                    }
                }
            }
        }
    }
}
