我想比较不同的机器学习算法.作为其中的一部分,我需要能够执行网格搜索以获得最佳超参数.但是,我并没有真正想到为每个固定算法和其超参数的固定子集编写单独的网格搜索实现.相反,我希望它看起来更像是在scikit-learn中,但可能没有那么多功能(例如我不需要多个网格)并且用MATLAB编写.
到目前为止,我试图理解尚未编写的逻辑 grid_search.m
function model = grid_search(algo, data, labels, varargin) p = inputParser; % here comes the list of all possible hyperparameters for all algorithms % I will just leave three for brevity addOptional(p, 'kernel_function', {'linear'}); addOptional(p, 'rbf_sigma', {1}); addOptional(p, 'C', {1}); parse(p, algo, data, labels, varargin{:}); names = fieldnames(p.Results); values = struct2cell(p.Results); % a cell array of cell arrays argsize = 2 * length(names); args = cell(1, argsize); args(1 : 2 : argsize) = names; % Now this is the stumbling point. end
对grid_search
函数的调用应如下所示:
m = grid_search('svm', data, labels, 'kernel_function', {'rbf'}, 'C', {[0.1], [1], [10]}, 'rbf_sigma', {[1], [2], [3]}) m = grid_search('knn', data, labels, 'NumNeighbors', {[1], [10]}, 'Distance', {'euclidean', 'cosine'})
然后第一个调用将尝试rbf内核与Constraints和Sigmas的所有组合:
{'rbf', 0.1, 1} {'rbf', 0.1, 2} {'rbf', 0.1, 3} {'rbf', 1, 1} {'rbf', 1, 2} {'rbf', 1, 3} {'rbf', 10, 1} {'rbf', 10, 2} {'rbf', 10, 3}
args
变量背后的想法是它是一个形式的单元格数组,{'name1', 'value1', 'name2', 'value2', ..., 'nameN', 'valueN'}
稍后将传递给相应的算法:algo(data, labels, args{:})
.它的{'name1', 'name2', ..., 'nameN'}
一部分很容易.问题是我不能不知道如何{'value1', 'value2', ..., 'valueN'}
在每一步上创建零件.
我知道每个人都不知道机器学习术语这就是为什么下面是一个自包含的例子:
假设TARDIS的船员可能包括以下类别的生物:
tardis_crew = {{'doctor'}, {'amy', 'clara'}, {'dalek', 'cyberman', 'master'}}
由于Timelord,Companion和Villain总有一个地方,请告诉我如何生成以下单元格数组:
{'Timelord', 'doctor', 'Companion', 'amy', 'Villain', 'dalek'} {'Timelord', 'doctor', 'Companion', 'amy', 'Villain', 'cyberman'} {'Timelord', 'doctor', 'Companion', 'amy', 'Villain', 'master'} {'Timelord', 'doctor', 'Companion', 'clara', 'Villain', 'dalek'} {'Timelord', 'doctor', 'Companion', 'clara', 'Villain', 'cyberman'} {'Timelord', 'doctor', 'Companion', 'clara', 'Villain', 'master'}
解决方案应该是通用的,即如果一个类中的生物数量发生变化或者添加了更多类生物,它应该仍然有效.我非常感谢一步一步的descritption而不是代码.
PS:原版的非剥离 github版本grid_search.m
可能会让你更好地了解我的意思.