The score calculation logic seems to be wrong, so n_sampling has no effect? #58

Open
@gantuo

Description


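As I understand the code below, for each problem it takes only the score of sample 0 out of the n samples, and reports the mean of those scores as acc. If that is right, setting n_sampling > 1 is pointless.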
evaluate.py#line78

```python
score_mat = []
for sample in samples:
    sample['score'] = scores[idx: idx+len(sample['pred'])]
    assert len(sample['score']) == len(sample['pred'])
    score_mat.append(sample['score'])
    idx += len(sample['pred'])

max_len = max([len(s) for s in score_mat])

for i, s in enumerate(score_mat):
    if len(s) < max_len:
        score_mat[i] = s + [s[-1]] * (max_len - len(s)) # pad

# output mean of each column of scores
col_means= np.array(score_mat).mean(axis=0)
mean_score = list(np.round(col_means * 100, decimals=1))

result_json = {
    "num_samples": len(samples),
    "num_scores": len(scores),
    "timeout_samples": timeout_cnt,
    "empty_samples": len([s for s in samples if not s['pred'][-1]]),
    "acc": mean_score[0]
}
```
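To make this reading concrete, here is a minimal sketch that reuses the padding and column-mean steps from the excerpt on a made-up score matrix (3 problems, n_sampling = 4). `column_scores`, `acc_all_samples` and `acc_any_correct` are hypothetical names introduced only for illustration. The two alternative aggregations are shown for comparison and are not something evaluate.py is known to compute.

```python
import numpy as np


def column_scores(score_mat):
    """Pad ragged rows with their last value, then return per-column accuracy (%).

    Mirrors the padding and column-mean steps in the evaluate.py excerpt.
    """
    max_len = max(len(s) for s in score_mat)
    padded = [list(s) + [s[-1]] * (max_len - len(s)) for s in score_mat]
    col_means = np.array(padded, dtype=float).mean(axis=0)
    return [float(x) for x in np.round(col_means * 100, decimals=1)]


# Hypothetical scores: 3 problems x 4 samples (True = sample judged correct).
score_mat = [
    [True, False, False, False],
    [False, True, True, True],
    [False, False, False, True],
]

mean_score = column_scores(score_mat)
# What result_json["acc"] reports: the mean of column 0 only.
acc_reported = mean_score[0]
# Hypothetical alternative: average over every sample of every problem.
acc_all_samples = float(np.round(np.mean(score_mat) * 100, 1))
# Hypothetical alternative: a problem counts if any of its samples is correct.
acc_any_correct = float(np.round(np.mean([any(s) for s in score_mat]) * 100, 1))

print(mean_score)
print(acc_reported, acc_all_samples, acc_any_correct)
```

Only column 0 feeds `acc`, so changing the scores of samples 1..n-1 leaves the reported number unchanged. In the excerpt, the other entries of `col_means` are computed but never written to `result_json`.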
