lua – Torch,为什么我的人工神经网络总是预测零?

前端之家收集整理的这篇文章主要介绍了lua – Torch,为什么我的人工神经网络总是预测零?前端之家小编觉得挺不错的,现在分享给大家,也给大家做个参考。
我在 Linux CentOS 7机器上使用Torch7.
我正在尝试将人工神经网络(ANN)应用于我的数据集,以解决二进制分类问题.我正在使用一个简单的多层感知器.

我正在使用以下火炬包:optim,torch.

问题是我的感知器总是预测零值(被归类为零的元素),我无法理解为什么……

这是我的数据集(“dataset_file.csv”).有34个功能和1个标签目标(最后一列,可能是0或1):

  1. 0.55,1,0.29,0.46,0.67,0.37,0.41,0.08,0.47,0.23,0.13,0.82,0.25,0.04,0.52,0.33,0
  2. 0.65,0.64,0.02,0.32,0.18,0.2,0.38,0.24,0
  3. 0.34,0.5,0.55,0.06,0.15,0.51,0.22,0.6,0.42,1
  4. 0.46,0.14,0.17,0.1,0.94,0.65,0.75,0.3,0
  5. 0.55,0.03,0.16,0.12,0.73,0.54,0.44,0.35,0.11,0
  6. 0.67,0.71,0.74,0.69,0.27,0.61,0.48,1
  7. 0.52,0.21,0.01,0.34,0.85,0.05,0.36,0
  8. 0.58,0.57,0.19,0
  9. 0.66,0.07,0.45,0.92,0
  10. 0.39,0.31,0.81,0
  11. 0.26,0.26,0.43,0
  12. 0.96,0.63,0.86,0.72,0.53,0.4,0.09,0.8,0.28,0
  13. 0.6,0
  14. 0.72,0.78,0.68,0
  15. 0.56,0.56,0.49,0.62,0.76,0.88,1
  16. 0.61,0.58,0
  17. 0.59,0.87,0
  18. 0.74,0.93,0
  19. 0.64,1
  20. 0.36,0.79,0.59,0.7,1

这是我的Torch Lua代码

  1. -- add comma to separate thousands
  2. function comma_value(amount)
  3. local formatted = amount
  4. while true do
  5. formatted,k = string.gsub(formatted,"^(-?%d+)(%d%d%d)",'%1,%2')
  6. if (k==0) then
  7. break
  8. end
  9. end
  10. return formatted
  11. end
  12.  
  13. -- function that computes the confusion matrix
  14. function confusion_matrix(predictionTestVect,truthVect,threshold,printValues)
  15.  
  16. local tp = 0
  17. local tn = 0
  18. local fp = 0
  19. local fn = 0
  20. local MatthewsCC = -2
  21. local accuracy = -2
  22. local arrayFPindices = {}
  23. local arrayFPvalues = {}
  24. local arrayTPvalues = {}
  25. local areaRoc = 0
  26.  
  27. local fpRateVett = {}
  28. local tpRateVett = {}
  29. local precisionVett = {}
  30. local recallVett = {}
  31.  
  32. for i=1,#predictionTestVect do
  33.  
  34. if printValues == true then
  35. io.write("predictionTestVect["..i.."] = ".. round(predictionTestVect[i],4).."\ttruthVect["..i.."] = "..truthVect[i].." ");
  36. io.flush();
  37. end
  38.  
  39. if predictionTestVect[i] >= threshold and truthVect[i] >= threshold then
  40. tp = tp + 1
  41. arrayTPvalues[#arrayTPvalues+1] = predictionTestVect[i]
  42. if printValues == true then print(" TP ") end
  43. elseif predictionTestVect[i] < threshold and truthVect[i] >= threshold then
  44. fn = fn + 1
  45. if printValues == true then print(" FN ") end
  46. elseif predictionTestVect[i] >= threshold and truthVect[i] < threshold then
  47. fp = fp + 1
  48. if printValues == true then print(" FP ") end
  49. arrayFPindices[#arrayFPindices+1] = i;
  50. arrayFPvalues[#arrayFPvalues+1] = predictionTestVect[i]
  51. elseif predictionTestVect[i] < threshold and truthVect[i] < threshold then
  52. tn = tn + 1
  53. if printValues == true then print(" TN ") end
  54. end
  55. end
  56.  
  57. print("TOTAL:")
  58. print(" FN = "..comma_value(fn).." / "..comma_value(tonumber(fn+tp)).."\t (truth == 1) & (prediction < threshold)");
  59. print(" TP = "..comma_value(tp).." / "..comma_value(tonumber(fn+tp)).."\t (truth == 1) & (prediction >= threshold)\n");
  60.  
  61. print(" FP = "..comma_value(fp).." / "..comma_value(tonumber(fp+tn)).."\t (truth == 0) & (prediction >= threshold)");
  62. print(" TN = "..comma_value(tn).." / "..comma_value(tonumber(fp+tn)).."\t (truth == 0) & (prediction < threshold)\n");
  63.  
  64. local continueLabel = true
  65.  
  66. if continueLabel then
  67. upperMCC = (tp*tn) - (fp*fn)
  68. innerSquare = (tp+fp)*(tp+fn)*(tn+fp)*(tn+fn)
  69. lowerMCC = math.sqrt(innerSquare)
  70.  
  71. MatthewsCC = -2
  72. if lowerMCC>0 then MatthewsCC = upperMCC/lowerMCC end
  73. local signedMCC = MatthewsCC
  74. print("signedMCC = "..signedMCC)
  75.  
  76. if MatthewsCC > -2 then print("\n::::\tMatthews correlation coefficient = "..signedMCC.."\t::::\n");
  77. else print("Matthews correlation coefficient = NOT computable"); end
  78.  
  79. accuracy = (tp + tn)/(tp + tn +fn + fp)
  80. print("accuracy = "..round(accuracy,2).. " = (tp + tn) / (tp + tn +fn + fp) \t \t [worst = -1,best = +1]");
  81.  
  82. local f1_score = -2
  83. if (tp+fp+fn)>0 then
  84. f1_score = (2*tp) / (2*tp+fp+fn)
  85. print("f1_score = "..round(f1_score,2).." = (2*tp) / (2*tp+fp+fn) \t [worst = 0,best = 1]");
  86. else
  87. print("f1_score CANNOT be computed because (tp+fp+fn)==0")
  88. end
  89.  
  90. local totalRate = 0
  91. if MatthewsCC > -2 and f1_score > -2 then
  92. totalRate = MatthewsCC + accuracy + f1_score
  93. print("total rate = "..round(totalRate,2).." in [-1,+3] that is "..round((totalRate+1)*100/4,2).."% of possible correctness");
  94. end
  95.  
  96. local numberOfPredictedOnes = tp + fp;
  97. print("numberOfPredictedOnes = (TP + FP) = "..comma_value(numberOfPredictedOnes).." = "..round(numberOfPredictedOnes*100/(tp + tn + fn + fp),2).."%");
  98.  
  99. io.write("\nDiagnosis: ");
  100. if (fn >= tp and (fn+tp)>0) then print("too many FN false negatives"); end
  101. if (fp >= tn and (fp+tn)>0) then print("too many FP false positives"); end
  102.  
  103.  
  104. if (tn > (10*fp) and tp > (10*fn)) then print("Excellent ! ! !");
  105. elseif (tn > (5*fp) and tp > (5*fn)) then print("Very good ! !");
  106. elseif (tn > (2*fp) and tp > (2*fn)) then print("Good !");
  107. elseif (tn >= fp and tp >= fn) then print("Alright");
  108. else print("Baaaad"); end
  109. end
  110.  
  111. return {accuracy,arrayFPindices,arrayFPvalues,MatthewsCC};
  112. end
  113.  
  114.  
  115. -- Permutations
  116. -- tab = {1,2,3,4,5,6,7,8,9,10}
  117. -- permute(tab,10,10)
  118. function permute(tab,n,count)
  119. n = n or #tab
  120. for i = 1,count or n do
  121. local j = math.random(i,n)
  122. tab[i],tab[j] = tab[j],tab[i]
  123. end
  124. return tab
  125. end
  126.  
  127. -- round a real value
  128. function round(num,idp)
  129. local mult = 10^(idp or 0)
  130. return math.floor(num * mult + 0.5) / mult
  131. end
  132.  
  133.  
  134.  
  135. -- ##############################3
  136.  
  137. local profile_vett = {}
  138. local csv = require("csv")
  139. local fileName = "dataset_file.csv"
  140.  
  141. print("Readin' "..tostring(fileName))
  142. local f = csv.open(fileName)
  143. local column_names = {}
  144.  
  145. local j = 0
  146. for fields in f:lines() do
  147.  
  148. if j>0 then
  149. profile_vett[j] = {}
  150. for i,v in ipairs(fields) do
  151. profile_vett[j][i] = tonumber(v);
  152. end
  153. j = j + 1
  154. else
  155. for i,v in ipairs(fields) do
  156. column_names[i] = v
  157. end
  158. j = j + 1
  159. end
  160. end
  161.  
  162. OPTIM_PACKAGE = true
  163. local output_number = 1
  164. THRESHOLD = 0.5 -- ORIGINAL
  165. DROPOUT_FLAG = false
  166. MOMENTUM = false
  167. MOMENTUM_ALPHA = 0.5
  168. MAX_MSE = 4
  169. LEARN_RATE = 0.001
  170. ITERATIONS = 100
  171.  
  172. local hidden_units = 2000
  173. local hidden_layers = 1
  174.  
  175. local hiddenUnitVect = {2000,4000,6000,8000,10000}
  176. -- local hiddenLayerVect = {1,5}
  177. local hiddenLayerVect = {1}
  178.  
  179. local profile_vett_data = {}
  180. local label_vett = {}
  181.  
  182. for i=1,#profile_vett do
  183. profile_vett_data[i] = {}
  184.  
  185. for j=1,#(profile_vett[1]) do
  186. if j<#(profile_vett[1]) then
  187. profile_vett_data[i][j] = profile_vett[i][j]
  188. else
  189. label_vett[i] = profile_vett[i][j]
  190. end
  191. end
  192. end
  193.  
  194. print("Number of value profiles (rows) = "..#profile_vett_data);
  195. print("Number features (columns) = "..#(profile_vett_data[1]));
  196. print("Number of targets (rows) = "..#label_vett);
  197.  
  198. local table_row_outcome = label_vett
  199. local table_rows_vett = profile_vett
  200.  
  201. -- ########################################################
  202.  
  203. -- START
  204.  
  205. local indexVect = {};
  206. for i=1,#table_rows_vett do indexVect[i] = i; end
  207. permutedIndexVect = permute(indexVect,#indexVect,#indexVect);
  208.  
  209. TEST_SET_PERC = 20
  210. local test_set_size = round((TEST_SET_PERC*#table_rows_vett)/100)
  211.  
  212. print("training_set_size = "..(#table_rows_vett-test_set_size).." elements");
  213. print("test_set_size = "..test_set_size.." elements\n");
  214.  
  215. local train_table_row_profile = {}
  216. local test_table_row_profile = {}
  217. local original_test_indexes = {}
  218.  
  219. for i=1,#table_rows_vett do
  220. if i<=(tonumber(#table_rows_vett)-test_set_size) then
  221. train_table_row_profile[#train_table_row_profile+1] = {torch.Tensor(table_rows_vett[permutedIndexVect[i]]),torch.Tensor{table_row_outcome[permutedIndexVect[i]]}}
  222. else
  223.  
  224. original_test_indexes[#original_test_indexes+1] = permutedIndexVect[i];
  225.  
  226. test_table_row_profile[#test_table_row_profile+1] = {torch.Tensor(table_rows_vett[permutedIndexVect[i]]),torch.Tensor{table_row_outcome[permutedIndexVect[i]]}}
  227. end
  228. end
  229.  
  230. require 'nn'
  231. perceptron = nn.Sequential()
  232. input_number = #table_rows_vett[1]
  233.  
  234. perceptron:add(nn.Linear(input_number,hidden_units))
  235. perceptron:add(nn.Sigmoid())
  236. if DROPOUT_FLAG==true then perceptron:add(nn.Dropout()) end
  237.  
  238. for w=1,hidden_layers do
  239. perceptron:add(nn.Linear(hidden_units,hidden_units))
  240. perceptron:add(nn.Sigmoid())
  241. if DROPOUT_FLAG==true then perceptron:add(nn.Dropout()) end
  242. end
  243. perceptron:add(nn.Linear(hidden_units,output_number))
  244.  
  245.  
  246. function train_table_row_profile:size() return #train_table_row_profile end
  247. function test_table_row_profile:size() return #test_table_row_profile end
  248.  
  249.  
  250. -- OPTIMIZATION LOOPS
  251. local MCC_vect = {}
  252.  
  253. for a=1,#hiddenUnitVect do
  254. for b=1,#hiddenLayerVect do
  255.  
  256. local hidden_units = hiddenUnitVect[a]
  257. local hidden_layers = hiddenLayerVect[b]
  258. print("hidden_units = "..hidden_units.."\t output_number = "..output_number.." hidden_layers = "..hidden_layers)
  259.  
  260.  
  261. local criterion = nn.MSECriterion()
  262. local lossSum = 0
  263. local error_progress = 0
  264.  
  265. require 'optim'
  266. local params,gradParams = perceptron:getParameters()
  267. local optimState = nil
  268.  
  269. if MOMENTUM==true then
  270. optimState = {learningRate = LEARN_RATE}
  271. else
  272. optimState = {learningRate = LEARN_RATE,momentum = MOMENTUM_ALPHA }
  273. end
  274.  
  275. local total_runs = ITERATIONS*#train_table_row_profile
  276.  
  277. local loopIterations = 1
  278. for epoch=1,ITERATIONS do
  279. for k=1,#train_table_row_profile do
  280.  
  281. -- Function feval
  282. local function feval(params)
  283. gradParams:zero()
  284.  
  285. local thisProfile = train_table_row_profile[k][1]
  286. local thisLabel = train_table_row_profile[k][2]
  287.  
  288. local thisPrediction = perceptron:forward(thisProfile)
  289. local loss = criterion:forward(thisPrediction,thisLabel)
  290.  
  291. -- print("thisPrediction = "..round(thisPrediction[1],2).." thisLabel = "..thisLabel[1])
  292.  
  293. lossSum = lossSum + loss
  294. error_progress = lossSum*100 / (loopIterations*MAX_MSE)
  295.  
  296. if ((loopIterations*100/total_runs)*10)%10==0 then
  297. io.write("completion: ",round((loopIterations*100/total_runs),2).."%" )
  298. io.write(" (epoch="..epoch..")(element="..k..") loss = "..round(loss,2).." ")
  299. io.write("\terror progress = "..round(error_progress,5).."%\n")
  300. end
  301.  
  302. local dloss_doutput = criterion:backward(thisPrediction,thisLabel)
  303.  
  304. perceptron:backward(thisProfile,dloss_doutput)
  305.  
  306. return loss,gradParams
  307. end
  308. optim.sgd(feval,params,optimState)
  309. loopIterations = loopIterations+1
  310. end
  311. end
  312.  
  313.  
  314. local correctPredictions = 0
  315. local atleastOneTrue = false
  316. local atleastOneFalse = false
  317. local predictionTestVect = {}
  318. local truthVect = {}
  319.  
  320. for i=1,#test_table_row_profile do
  321. local current_label = test_table_row_profile[i][2][1]
  322. local prediction = perceptron:forward(test_table_row_profile[i][1])[1]
  323.  
  324. predictionTestVect[i] = prediction
  325. truthVect[i] = current_label
  326.  
  327. local labelResult = false
  328.  
  329. if current_label >= THRESHOLD and prediction >= THRESHOLD then
  330. labelResult = true
  331. elseif current_label < THRESHOLD and prediction < THRESHOLD then
  332. labelResult = true
  333. end
  334.  
  335. if labelResult==true then correctPredictions = correctPredictions + 1; end
  336.  
  337. print("\nCorrect predictions = "..round(correctPredictions*100/#test_table_row_profile,2).."%")
  338.  
  339. local printValues = false
  340. local output_confusion_matrix = confusion_matrix(predictionTestVect,THRESHOLD,printValues)
  341. end
  342. end

有没有人知道为什么我的脚本只预测零元素?

编辑:我用原始数据集替换了我在脚本中使用的规范化版本

解决方法

当我运行您的原始代码时,我有时会预测所有零,我有时会获得完美的性能.这表明您的原始模型对参数值的初始化非常敏感.

如果我使用种子值torch.manualSeed(0)(所以我们总是有相同的初始化),我每次都会得到完美的表现.但这不是一般的解决方案.

为了获得更全面的改进,我做了以下更改:

>减少隐藏单位的数量.在原始代码中你有一个
单个隐藏层的2000个单位.但是你只有34个输入和
1输出通常你只需要隐藏单位的数量
输入和输出数量之间.我减少了它
50.
>标签是不对称的,只有5/27(19%)的标签是1,所以你应该真正划分列车|测试集,以保持1与0的比率.目前我只是将测试集大小增加到’50’%.
>我也将学习率提高到’0.01′,开启MOMENTUM,并将ITERATIONS增加到200.

当我运行这个模型20次(未播种)时,我获得了19次优异的表现.为了进一步改进,您可以进一步调整超参数.并且还应该使用单独的验证集来查看多个初始化,以选择“最佳”模型(尽管这将进一步细分已经非常小的数据集).

  1. -- add comma to separate thousands
  2. function comma_value(amount)
  3. local formatted = amount
  4. while true do
  5. formatted,v in ipairs(fields) do
  6. column_names[i] = v
  7. end
  8. j = j + 1
  9. end
  10. end
  11.  
  12. OPTIM_PACKAGE = true
  13. local output_number = 1
  14. THRESHOLD = 0.5 -- ORIGINAL
  15. DROPOUT_FLAG = false
  16. MOMENTUM_ALPHA = 0.5
  17. MAX_MSE = 4
  18.  
  19. -- CHANGE: increased learn_rate to 0.01,reduced hidden units to 50,turned momentum on,increased iterations to 200
  20. LEARN_RATE = 0.01
  21. local hidden_units = 50
  22. MOMENTUM = true
  23. ITERATIONS = 200
  24. -------------------------------------
  25.  
  26. local hidden_layers = 1
  27.  
  28. local hiddenUnitVect = {2000,#(profile_vett[1]) do
  29. if j<#(profile_vett[1]) then
  30. profile_vett_data[i][j] = profile_vett[i][j]
  31. else
  32. label_vett[i] = profile_vett[i][j]
  33. end
  34. end
  35. end
  36.  
  37. print("Number of value profiles (rows) = "..#profile_vett_data);
  38. print("Number features (columns) = "..#(profile_vett_data[1]));
  39. print("Number of targets (rows) = "..#label_vett);
  40.  
  41. local table_row_outcome = label_vett
  42. local table_rows_vett = profile_vett
  43.  
  44. -- ########################################################
  45.  
  46. -- START
  47.  
  48. -- Seed random number generator
  49. -- torch.manualSeed(0)
  50.  
  51. local indexVect = {};
  52. for i=1,#indexVect);
  53.  
  54. -- CHANGE: increase test_set to 50%
  55. TEST_SET_PERC = 50
  56. ---------------------------
  57.  
  58. local test_set_size = round((TEST_SET_PERC*#table_rows_vett)/100)
  59.  
  60. print("training_set_size = "..(#table_rows_vett-test_set_size).." elements");
  61. print("test_set_size = "..test_set_size.." elements\n");
  62.  
  63. local train_table_row_profile = {}
  64. local test_table_row_profile = {}
  65. local original_test_indexes = {}
  66.  
  67. for i=1,printValues)
  68. end
  69. end
  70. end

下面粘贴的是20次运行中的1次输出

  1. Correct predictions = 100%
  2. TOTAL:
  3. FN = 0 / 4 (truth == 1) & (prediction < threshold)
  4. TP = 4 / 4 (truth == 1) & (prediction >= threshold)
  5.  
  6. FP = 0 / 9 (truth == 0) & (prediction >= threshold)
  7. TN = 9 / 9 (truth == 0) & (prediction < threshold)
  8.  
  9. signedMCC = 1
  10.  
  11. :::: Matthews correlation coefficient = 1 ::::
  12.  
  13. accuracy = 1 = (tp + tn) / (tp + tn +fn + fp) [worst = -1,best = +1]
  14. f1_score = 1 = (2*tp) / (2*tp+fp+fn) [worst = 0,best = 1]
  15. total rate = 3 in [-1,+3] that is 100% of possible correctness
  16. numberOfPredictedOnes = (TP + FP) = 4 = 30.77%
  17.  
  18. Diagnosis: Excellent ! ! !

猜你在找的Lua相关文章