Classifier parameter setting #1

Open
Opdoop opened this issue Feb 27, 2019 · 1 comment

Opdoop commented Feb 27, 2019

Thanks @hankcs. How do I use this grid search code?

hankcs (Owner) commented Mar 5, 2019

This is a function that searches for the best regularization parameter C.
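For context (this comes from the liblinear documentation, not from this thread): with `SolverType.L1R_LR`, which the code passes to liblinear, the solver minimizes the L1-regularized logistic regression objective below. Here $l$ is the number of training instances, $(x_i, y_i)$ are the instance/label pairs and $C > 0$ is the parameter being searched. A larger $C$ puts more weight on the training loss, so it means weaker regularization.

$$
\min_{w} \; \|w\|_1 + C \sum_{i=1}^{l} \log\left(1 + e^{-y_i w^\top x_i}\right)
$$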

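import de.bwaldvogel.liblinear.InvalidInputDataException;
import de.bwaldvogel.liblinear.Linear;
import de.bwaldvogel.liblinear.Parameter;
import de.bwaldvogel.liblinear.Problem;
import de.bwaldvogel.liblinear.SolverType;
import de.bwaldvogel.liblinear.Train;

import java.io.File;
import java.io.IOException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

/**
 *
 * Automatic parameter search for Liblinear
 * @author hankcs
 */
public class grid
{
    /**
     * Tries every C in [from, end] (both ends inclusive) with the given step and
     * returns the one with the highest 5-fold cross-validation accuracy.
     */
    public static double find_parameters(final Problem prob, double from, double end, double step)
    {
        // Swap the bounds if they were passed in reverse order
        if (from > end)
        {
            double x = from;
            from = end;
            end = x;
        }
        if (step < 0) step = -step;
        if (step == 0) throw new IllegalArgumentException("step must be non-zero");
        // +1 so the end point is part of the grid; the small epsilon absorbs rounding in the division
        final double[] cs = new double[(int) ((end - from) / step + 1e-9) + 1];
        final double[] as = new double[cs.length];
        Linear.setDebugOutput(null); // silence liblinear's per-fold training log
        ExecutorService fixedThreadPool = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors());
        final AtomicInteger finished = new AtomicInteger(0);
        for (int i = 0; i < cs.length; i++)
        {
            cs[i] = from + step * i;
            final int index = i;
            // Each candidate C is cross-validated on a worker thread
            fixedThreadPool.execute(new Runnable()
            {
                public void run()
                {
                    as[index] = validate(prob, cs[index]);
                    // Count progress only after this candidate has actually been evaluated
                    int n = finished.incrementAndGet();
                    System.out.printf("%.2f%%...%n", n / (double) cs.length * 100.);
                }
            });
        }
        fixedThreadPool.shutdown();
        try
        {
            // Block until every submitted task has finished
            fixedThreadPool.awaitTermination(Long.MAX_VALUE, TimeUnit.DAYS);
        }
        catch (InterruptedException e)
        {
            Thread.currentThread().interrupt(); // preserve the interrupt status
            e.printStackTrace();
        }
        // Pick the C with the highest cross-validation accuracy
        int p = 0;
        double max = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < as.length; i++)
        {
            if (as[i] > max)
            {
                max = as[i];
                p = i;
            }
        }
        System.out.printf("Best Cross Validation Accuracy = %g%%, C = %f%n", max * 100, cs[p]);

        return cs[p];
    }

    /** Returns the 5-fold cross-validation accuracy of L1-regularized logistic regression with the given C. */
    private static double validate(Problem prob, double C)
    {
        double[] target = new double[prob.l];
        // eps = 0.01 is the solver's stopping tolerance
        Parameter param = new Parameter(SolverType.L1R_LR, C, 0.01);
        Linear.crossValidation(prob, param, 5, target);
        int total_correct = 0;
        for (int i = 0; i < prob.l; i++)
            if (target[i] == prob.y[i]) ++total_correct;
        return total_correct / (double) prob.l;
    }

    public static void main(String[] args) throws IOException, InvalidInputDataException
    {
        // Read a LIBSVM-format data set; a negative bias (-1) means no bias term is added
        Problem problem = Train.readProblem(new File("libsvm/dataset/heart_scale.txt"), -1);
        System.out.println(find_parameters(problem, 1., 1000., 1.));
    }
}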
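The search above walks C linearly from 1 to 1000. A common alternative, used for example by LIBSVM's `grid.py`, is a log-scale grid of powers of two, which covers several orders of magnitude with far fewer candidates. The sketch below is only an illustration under assumptions: the class `GridSketch`, its methods and the toy scoring function are hypothetical, it uses only the JDK so it runs without liblinear, and in real use the scorer would be something like `validate(prob, C)` above. It keeps the same idea (score every candidate in parallel, return the argmax) but collects results through `invokeAll` futures instead of a shared array.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.function.DoubleUnaryOperator;

// Hypothetical helper class, not part of liblinear or of the code above
public class GridSketch
{
    /** Candidates 2^fromExp, 2^(fromExp + stepExp), ..., up to 2^toExp (inclusive when reachable). */
    public static double[] logGrid(int fromExp, int toExp, int stepExp)
    {
        if (stepExp <= 0 || fromExp > toExp) throw new IllegalArgumentException("invalid grid");
        int n = (toExp - fromExp) / stepExp + 1;
        double[] cs = new double[n];
        for (int i = 0; i < n; i++) cs[i] = Math.pow(2, fromExp + i * stepExp);
        return cs;
    }

    /** Scores every candidate on a thread pool and returns the candidate with the highest score. */
    public static double bestParameter(double[] cs, DoubleUnaryOperator score, int threads)
    {
        ExecutorService pool = Executors.newFixedThreadPool(threads);
        try
        {
            List<Callable<Double>> tasks = new ArrayList<>();
            for (double c : cs) tasks.add(() -> score.applyAsDouble(c));
            // invokeAll blocks until every task has completed
            List<Future<Double>> results = pool.invokeAll(tasks);
            int best = 0;
            double bestScore = Double.NEGATIVE_INFINITY;
            for (int i = 0; i < results.size(); i++)
            {
                double s = results.get(i).get(); // rethrows any exception thrown by the task
                if (s > bestScore)
                {
                    bestScore = s;
                    best = i;
                }
            }
            return cs[best];
        }
        catch (InterruptedException e)
        {
            Thread.currentThread().interrupt();
            throw new IllegalStateException("grid search interrupted", e);
        }
        catch (ExecutionException e)
        {
            throw new IllegalStateException("scoring failed", e.getCause());
        }
        finally
        {
            pool.shutdown();
        }
    }

    public static void main(String[] args)
    {
        double[] cs = logGrid(-5, 15, 2); // 2^-5, 2^-3, ..., 2^15 (11 candidates)
        // Toy stand-in for cross-validation accuracy that peaks at C = 2^3 = 8
        DoubleUnaryOperator toyScore = c -> -Math.pow(Math.log(c) / Math.log(2) - 3, 2);
        System.out.println(bestParameter(cs, toyScore, Runtime.getRuntime().availableProcessors()));
    }
}
```

With the toy scorer, `main` prints `8.0`. A coarse log-scale pass like this can be followed by a finer linear pass (such as `find_parameters`) around the best value it finds.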

hankcs added the question label Mar 5, 2019