Java:微优化数组操作

前端之家收集整理的这篇文章主要介绍了Java:微优化数组操作前端之家小编觉得挺不错的,现在分享给大家,也给大家做个参考。
我正在尝试制作一个简单的前馈神经网络的 Java端口.
这显然涉及大量的数值计算,所以我试图尽可能优化我的中心循环.结果应在浮点数据类型的范围内正确.

我当前的代码如下(删除错误处理和初始化):

/**
 * Simple implementation of a Feedforward neural network. The network supports
 * including a bias neuron with a constant output of 1.0 and weighted synapses
 * to hidden and output layers.
 * 
 * @author Martin Wiboe
 */
public class FeedForwardNetwork {
private final int outputNeurons;    // No of neurons in output layer
private final int inputNeurons;     // No of neurons in input layer
private int largestLayerNeurons;    // No of neurons in largest layer
private final int numberLayers;     // No of layers
private final int[] neuronCounts;   // Neuron count in each layer,0 is input
                                // layer.
private final float[][][] fWeights; // Weights between neurons.
                                    // fWeight[fromLayer][fromNeuron][toNeuron]
                                    // is the weight from fromNeuron in
                                    // fromLayer to toNeuron in layer
                                    // fromLayer+1.
private float[][] neuronOutput;     // Temporary storage of output from prevIoUs layer


public float[] compute(float[] input) {
    // Copy input values to input layer output
    for (int i = 0; i < inputNeurons; i++) {
        neuronOutput[0][i] = input[i];
    }

    // Loop through layers
    for (int layer = 1; layer < numberLayers; layer++) {

        // Loop over neurons in the layer and determine weighted input sum
        for (int neuron = 0; neuron < neuronCounts[layer]; neuron++) {
            // Bias neuron is the last neuron in the prevIoUs layer
            int biasNeuron = neuronCounts[layer - 1];

            // Get weighted input from bias neuron - output is always 1.0
            float activation = 1.0F * fWeights[layer - 1][biasNeuron][neuron];

            // Get weighted inputs from rest of neurons in prevIoUs layer
            for (int inputNeuron = 0; inputNeuron < biasNeuron; inputNeuron++) {
                activation += neuronOutput[layer-1][inputNeuron] * fWeights[layer - 1][inputNeuron][neuron];
            }

            // Store neuron output for next round of computation
            neuronOutput[layer][neuron] = sigmoid(activation);
        }
    }

    // Return output from network = output from last layer
    float[] result = new float[outputNeurons];
    for (int i = 0; i < outputNeurons; i++)
        result[i] = neuronOutput[numberLayers - 1][i];

    return result;
}

private final static float sigmoid(final float input) {
    return (float) (1.0F / (1.0F + Math.exp(-1.0F * input)));
}
}

我使用-server选项运行JVM,到目前为止,我的代码比类似的C代码慢25%到50%.我可以做些什么来改善这种情况?

谢谢,

马丁·维博

编辑#1:看到大量的回应后,我应该在我们的场景中澄清数字.在典型的运行期间,该方法将被称为约50.000次,具有不同的输入.典型的网络将分别具有数字层= 3层,分别为190,2和1个神经元.因此,最内循环将具有约2 * 191 3 = 385次迭代(当计算层0和1中添加的偏置神经元时)

编辑#1:在实现这个线程中的各种建议之后,我们的实现与C版本一样快(在〜2%以内).感谢所有的帮助!所有的建议都是有帮助的,但由于我只能将一个答案标记为正确的答案,我将把它提供给@Durandal,用于提示数组优化,并且是唯一一个可以预先计算for循环头的值.

解决方法

不考虑实际的数学,Java中的数组索引本身可以是一个性能指标.考虑到Java没有真正的多维数组,而是将它们实现为数组.在最内层循环中,您可以访问多个索引,其中一些索引实际上在该循环中是不变的.阵列访问的一部分可以移动到循环外:
final int[] neuronOutputSlice = neuronOutput[layer - 1];
final int[][] fWeightSlice = fWeights[layer - 1];
for (int inputNeuron = 0; inputNeuron < biasNeuron; inputNeuron++) {
    activation += neuronOutputSlice[inputNeuron] * fWeightsSlice[inputNeuron][neuron];
}

服务器JIT可能执行类似的代码不变运动,唯一的方法是改变和配置它.在客户端JIT这个应该提高性能无论什么.
您可以尝试的另一件事是预先计算循环退出条件,如下所示:

for (int neuron = 0; neuron < neuronCounts[layer]; neuron++) { ... }
// transform to precalculated exit condition (move invariant array access outside loop)
for (int neuron = 0,neuronCount = neuronCounts[layer]; neuron < neuronCount; neuron++) { ... }

再次,JIT可能已经为您做这个,所以如果有帮助的话.

有没有一点可以与1.0F相乘,这让我无法理解?

float activation = 1.0F * fWeights[layer - 1][biasNeuron][neuron];

其他可能会以可读性为代价可能提高速度的东西:手动内联sigmoid()函数(JIT对内联有非常严格的限制,函数可能更大).运行循环(当然不改变结果)可以稍快一点,因为将循环索引与零测试相比,检查一个局部变量便宜一些(最内层循环是一个强大的候选者,但是不要期望输出在所有情况下都是100%相同,因为添加浮点abc可能与acb不同).

原文链接:https://www.f2er.com/java/124232.html

猜你在找的Java相关文章