Classifier parameter setting #1
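This is a function that searches for the best regularization factor C.

// Assumes the liblinear-java port (package de.bwaldvogel.liblinear)
import de.bwaldvogel.liblinear.InvalidInputDataException;
import de.bwaldvogel.liblinear.Linear;
import de.bwaldvogel.liblinear.Parameter;
import de.bwaldvogel.liblinear.Problem;
import de.bwaldvogel.liblinear.SolverType;
import de.bwaldvogel.liblinear.Train;

import java.io.File;
import java.io.IOException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * Automatic parameter search for Liblinear
 * @author hankcs
 */
public class grid
{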
    public static double find_parameters(final Problem prob, double from, double end, double step)
    {
        if (from > end)
        {
            // Swap so that from <= end
            double x = from;
            from = end;
            end = x;
        }
        if (step < 0) step = -step;
        if (step == 0) throw new IllegalArgumentException("step must be non-zero");
        // One candidate C per step, including both end points
        final double[] cs = new double[(int) ((end - from) / step) + 1];
        final double[] as = new double[cs.length];
        Linear.setDebugOutput(null); // silence liblinear's training log
        ExecutorService fixedThreadPool = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors());
        final AtomicInteger finished = new AtomicInteger(0);
        for (int i = 0; i < cs.length; i++)
        {
            cs[i] = from + step * i;
            final int index = i;
            fixedThreadPool.execute(new Runnable()
            {
                public void run()
                {
                    as[index] = validate(prob, cs[index]);
                    // Count the task only after its cross-validation has finished
                    int n = finished.incrementAndGet();
                    System.out.printf("%.2f%%...%n", n / (double) cs.length * 100.);
                }
            });
        }
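        fixedThreadPool.shutdown();
        try
        {
            // Block until every cross-validation task has finished
            fixedThreadPool.awaitTermination(Long.MAX_VALUE, TimeUnit.DAYS);
        }
        catch (InterruptedException e)
        {
            // Restore the interrupt flag instead of continuing with partial results
            Thread.currentThread().interrupt();
            throw new IllegalStateException("Grid search interrupted", e);
        }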
        int p = 0;
        double max = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < as.length; i++)
        {
            if (as[i] > max)
            {
                max = as[i];
                p = i;
            }
        }
        System.out.printf("Best Cross Validation Accuracy = %g%%, C = %f%n", max * 100, cs[p]);
        return cs[p];
    }
    /** Returns the 5-fold cross-validation accuracy (0 to 1) for cost parameter C. */
    private static double validate(Problem prob, double C)
    {
        double[] target = new double[prob.l];
        // L1-regularized logistic regression, stopping tolerance eps = 0.01
        Parameter param = new Parameter(SolverType.L1R_LR, C, 0.01);
        Linear.crossValidation(prob, param, 5, target);
        int total_correct = 0;
        for (int i = 0; i < prob.l; i++)
        {
            if (target[i] == prob.y[i]) ++total_correct;
        }
        return total_correct / (double) prob.l;
    }
    public static void main(String[] args) throws IOException, InvalidInputDataException
    {
        // bias = -1: no bias feature is appended to each instance
        Problem problem = Train.readProblem(new File("libsvm/dataset/heart_scale.txt"), -1);
        System.out.println(find_parameters(problem, 1., 1000., 1.));
    }
}
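The same search can be written without the shared result array and the manual progress counter: submit one `Callable` per candidate C, let `invokeAll` wait for all of them, and read the `Future`s back in input order. The sketch below uses only the JDK. The class `ParallelGridSearch`, its methods, and the toy scoring function are hypothetical; in the real code the score would be `validate(prob, C)`, which needs liblinear.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.function.DoubleUnaryOperator;

// Hypothetical helper, not part of liblinear or the original project
public class ParallelGridSearch {
    /** Scores every candidate in parallel and returns the best one (first one on ties). */
    public static double search(double[] candidates, DoubleUnaryOperator score)
            throws InterruptedException, ExecutionException {
        if (candidates.length == 0) throw new IllegalArgumentException("no candidates");
        ExecutorService pool = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors());
        try {
            List<Callable<Double>> tasks = new ArrayList<>();
            for (double c : candidates) {
                tasks.add(() -> score.applyAsDouble(c)); // e.g. cross-validation accuracy for C = c
            }
            // invokeAll blocks until every task is done; futures keep the input order
            List<Future<Double>> results = pool.invokeAll(tasks);
            int best = 0;
            double bestScore = Double.NEGATIVE_INFINITY;
            for (int i = 0; i < results.size(); i++) {
                double s = results.get(i).get();
                if (s > bestScore) {
                    bestScore = s;
                    best = i;
                }
            }
            return candidates[best];
        } finally {
            pool.shutdown();
        }
    }

    /** Candidates from..end inclusive, same argument convention as find_parameters. */
    public static double[] linearGrid(double from, double end, double step) {
        if (step == 0) throw new IllegalArgumentException("step must be non-zero");
        step = Math.abs(step);
        if (from > end) {
            double t = from;
            from = end;
            end = t;
        }
        // Small tolerance so that e.g. 0.1 steps still reach the end point
        int n = (int) Math.floor((end - from) / step + 1e-9) + 1;
        double[] cs = new double[n];
        for (int i = 0; i < n; i++) cs[i] = from + step * i;
        return cs;
    }

    public static void main(String[] args) throws Exception {
        // Toy score that peaks at c = 42, standing in for validate(prob, c)
        double best = search(linearGrid(1, 100, 1), c -> -Math.abs(c - 42));
        System.out.println(best);
    }
}
```

Running `main` prints `42.0`. Because every task finishes before `invokeAll` returns, there is no need for `awaitTermination` or for catching a partially filled array.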
Thanks @hankcs. How should this grid search be used?
text-classification-svm/src/main/java/com/hankcs/hanlp/classification/classifiers/LinearSVMClassifier.java
Line 96 in be8c654
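In the snippet above, `find_parameters` is called once on the training `Problem` (see `main`), and the C it returns would then be passed to the `Parameter` used for the final training run. With a step of 1 from 1 to 1000 that is about a thousand 5-fold cross-validations. A cheaper option, which the LIBSVM authors suggest in their practical guide, is a coarse, exponentially spaced grid (for example C = 2^-5, 2^-3, ..., 2^15) followed by a finer grid around the best exponent. The sketch below only generates such candidates; `LogGrid` and `powers` are hypothetical names, and each value would still have to be scored with something like `validate(prob, C)`.

```java
import java.util.Arrays;

// Hypothetical helper for building exponentially spaced candidate C values
public class LogGrid {
    /** Returns base^e for e = fromExp, fromExp + stepExp, ... up to toExp (inclusive). */
    public static double[] powers(double base, double fromExp, double toExp, double stepExp) {
        if (stepExp <= 0 || fromExp > toExp) throw new IllegalArgumentException("bad exponent range");
        // Small tolerance so fractional steps still include toExp
        int n = (int) Math.floor((toExp - fromExp) / stepExp + 1e-9) + 1;
        double[] cs = new double[n];
        for (int i = 0; i < n; i++) cs[i] = Math.pow(base, fromExp + i * stepExp);
        return cs;
    }

    public static void main(String[] args) {
        // Coarse pass: 2^-5, 2^-3, ..., 2^15 (11 candidates)
        double[] coarse = powers(2, -5, 15, 2);
        System.out.println(Arrays.toString(coarse));
        // Fine pass, assuming the coarse winner was 2^3: exponents 2.0, 2.25, ..., 4.0
        double[] fine = powers(2, 2, 4, 0.25);
        System.out.println(Arrays.toString(fine));
    }
}
```

The two passes together need 20 cross-validations instead of roughly 1000, at the cost of only searching the neighbourhood of the coarse optimum.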