本公開涉及協作學習,尤其涉及一種高效協作學習的訓練方法及裝置。
背景技術:
1、隨著人工智能技術逐漸成為人們日常生活的重要組成部分,推動其向前發展的數據資源逐漸成為社會生產的基礎性資源。在人工智能模型的訓練中,豐富而多樣化的數據可以為模型注入更多的信息,從而促使產出的模型具備更全面、更強大的事務處理能力。然而,當前的數據資源分布在個人、企業以及政府單位等不同的社會主體中,其形成的數據孤島效應限制了數據的數量以及數據的多樣性,從根本上限制了人工智能算法的能力提升。為了解決這一問題,人們開始引入多方協作的訓練方法,作為數據要素流通的重要手段之一,協作學習試圖打破數據孤島,進一步釋放數據要素的潛在價值。
2、但是,打破數據孤島需要將多方數據信息進行整理、傳輸和聚合,因此會給互聯網傳輸資源的使用帶來巨大的壓力。特別是參與節點數量增加時,這種傳輸需求巨大與傳輸資源稀缺的矛盾會變得更加嚴峻,這進一步降低了高效協作學習的訓練效率。
技術實現思路
1、本公開旨在至少在一定程度上解決相關技術中的技術問題之一。
2、為此,本公開第一方面實施例提出了一種高效協作學習的訓練方法,所述協作學習的參與方包括聚合節點和多個本地節點,所述方法包括以下步驟:
3、s1,通過所述聚合節點將全局模型的初始模型參數下發至每個所述本地節點,以確定每個所述本地節點的本地模型,所述本地模型的參數為所述初始模型參數;
4、s2,利用各個所述本地節點上的本地樣本數據計算對應本地模型的梯度均值向量和梯度向量的協方差矩陣;
5、s3,根據多個所述本地模型的梯度均值向量和梯度向量的協方差矩陣計算全局梯度均值向量和全局協方差矩陣;
6、s4,根據所述全局梯度均值向量、所述全局協方差矩陣以及各個所述本地模型的梯度均值向量和梯度向量的協方差矩陣,確定各個所述本地節點的最優本地更新次數;
7、s5,每個所述本地節點根據各自的所述最優本地更新次數進行本地更新,并將每個所述本地節點更新后的模型參數上傳至所述聚合節點;
8、s6,所述聚合節點根據所述多個本地節點更新后的模型參數確定所述全局模型的模型參數。
9、在本公開一些實施例中,根據所述全局梯度均值向量、所述全局協方差矩陣以及各個所述本地模型的梯度均值向量和梯度向量的協方差矩陣,通過以下公式確定各個所述本地節點的最優本地更新次數:
10、
11、
12、其中,為第個本地節點的最優本地次數比,為全局梯度均值向量,為全局協方差矩陣,為第個本地節點上本地模型的梯度均值向量,為第個本地節點上本地模型梯度向量的協方差矩陣,為批大小,為全局更新步長,為第個本地節點的最優本地更新次數。
13、在本公開一些實施例中,所述聚合節點根據所述多個本地節點更新后的模型參數確定所述全局模型的模型參數,包括:
14、s61,所述聚合節點利用聚合算法對所述多個本地節點更新后的模型參數進行聚合運算,獲得所述全局模型的第一模型參數;
15、s62,確定當前所述聚合節點與多個所述本地節點的交互輪次;
16、s63,響應于所述交互輪次達到最大交互輪次,將所述第一模型參數確定為所述全局模型的模型參數;或者,
17、s64,響應于所述交互輪次未達到所述最大交互輪次,將所述第一模型參數作為新的初始模型參數,重復執行步驟s1-s5和步驟s61進行訓練,直至所述交互輪次達到所述最大交互輪次,將所述第一模型參數確定為所述全局模型的模型參數。
18、在本公開一些實施例中,該方法還包括:將每個所述本地節點的最優本地更新次數與質量評估閾值進行對比;將所述最優本地更新次數大于或等于所述質量評估閾值的本地節點確認為優質節點,將所述最優本地更新次數小于所述質量評估閾值的本地節點確認為非優質節點。
19、本公開第二方面實施例提出了一種高效協作學習的訓練裝置,所述協作學習的參與方包括聚合節點和多個本地節點,包括:
20、第一確定模塊,用于通過所述聚合節點將全局模型的初始模型參數下發至每個所述本地節點,以確定每個所述本地節點的本地模型,所述本地模型的參數為所述初始模型參數;
21、第二確定模塊,用于利用各個所述本地節點上的本地樣本數據計算對應本地模型的梯度均值向量和梯度向量的協方差矩陣;
22、第三確定模塊,用于根據多個所述本地模型的梯度均值向量和梯度向量的協方差矩陣計算全局梯度均值向量和全局協方差矩陣;
23、第四確定模塊,用于根據所述全局梯度均值向量、所述全局協方差矩陣以及各個所述本地模型的梯度均值向量和梯度向量的協方差矩陣,確定各個所述本地節點的最優本地更新次數;
24、更新模塊,用于每個所述本地節點根據各自的所述最優本地更新次數進行本地更新,并將每個所述本地節點更新后的模型參數上傳至所述聚合節點;
25、第五確定模塊,用于所述聚合節點根據所述多個本地節點更新后的模型參數確定所述全局模型的模型參數。
26、在本公開一些實施例中,所述第四確定模塊具體用于:根據所述全局梯度均值向量、所述全局協方差矩陣以及各個所述本地模型的梯度均值向量和梯度向量的協方差矩陣,通過以下公式確定各個所述本地節點的最優本地更新次數:
27、
28、
29、其中,為第個本地節點的最優本地次數比,為全局梯度均值向量,為全局協方差矩陣,為第個本地節點上本地模型的梯度均值向量,為第個本地節點上本地模型梯度向量的協方差矩陣,為批大小,為全局更新步長,為第個本地節點的最優本地更新次數。
30、在本公開一些實施例中,所述第五確定模塊包括:
31、獲取單元,用于在所述聚合節點利用聚合算法對所述多個本地節點更新后的模型參數進行聚合運算,獲得所述全局模型的第一模型參數;
32、第一確定單元,用于確定當前所述聚合節點與多個所述本地節點的交互輪次;
33、第二確定單元,用于響應于所述交互輪次達到最大交互輪次,將所述第一模型參數確定為所述全局模型的模型參數。
34、在本公開一些實施例中,該裝置還包括質量評估模塊;所述質量評估模塊用于:將每個所述本地節點的最優本地更新次數與質量評估閾值進行對比;將所述最優本地更新次數大于或等于所述質量評估閾值的本地節點確認為優質節點,將所述最優本地更新次數小于所述質量評估閾值的本地節點確認為非優質節點。
35、本公開第三方面實施例提出了一種電子設備,包括:處理器,以及與所述處理器通信連接的存儲器;
36、所述存儲器存儲計算機執行指令;
37、所述處理器執行所述存儲器存儲的計算機執行指令,以實現前述第一方面所述的方法。
38、本公開第四方面實施例提出了一種計算機可讀存儲介質,其特征在于,所述計算機可讀存儲介質中存儲有計算機執行指令,所述計算機執行指令被處理器執行時用于實現前述第一方面所述的方法。
39、本公開提供的高效協作學習的訓練方法,利用全局梯度均值向量、全局協方差矩陣以及各個本地模型的梯度均值向量和梯度向量的協方差矩陣,為不同信息豐富程度的本地節點分配最優本地更新次數。利用最優本地更新次數進行更新的本地模型能夠攜帶更多有效信息上傳給聚合節點,由此提升信息傳輸效率,優化梯度信噪比,提高互聯網傳輸資源的使用效率,加速協作學習的模型收斂速度。
40、本公開附加的方面和優點將在下面的描述中部分給出,部分將從下面的描述中變得明顯,或通過本公開的實踐了解到。