diff --git a/datasets/download_text_classification.sh b/datasets/download_text_classification.sh
index 3f654d1..1d3e6f3 100755
--- a/datasets/download_text_classification.sh
+++ b/datasets/download_text_classification.sh
@@ -3,41 +3,46 @@
 DIR="./TextClassification"
 mkdir $DIR
 cd $DIR
-rm -rf mnli
-wget --content-disposition https://cloud.tsinghua.edu.cn/f/33182c22cb594e88b49b/?dl=1
-tar -zxvf mnli.tar.gz
-rm -rf mnli.tar.gz
-
-rm -rf agnews
-wget --content-disposition https://cloud.tsinghua.edu.cn/f/0fb6af2a1e6647b79098/?dl=1
-tar -zxvf agnews.tar.gz
-rm -rf agnews.tar.gz
-
-rm -rf dbpedia
-wget --content-disposition https://cloud.tsinghua.edu.cn/f/362d3cdaa63b4692bafb/?dl=1
-tar -zxvf dbpedia.tar.gz
-rm -rf dbpedia.tar.gz
-
-rm -rf imdb
-wget --content-disposition https://cloud.tsinghua.edu.cn/f/37bd6cb978d342db87ed/?dl=1
-tar -zxvf imdb.tar.gz
-rm -rf imdb.tar.gz
-
-rm -rf SST-2
-wget --content-disposition https://cloud.tsinghua.edu.cn/f/bccfdb243eca404f8bf3/?dl=1
-tar -zxvf SST-2.tar.gz
-rm -rf SST-2.tar.gz
-
-rm -rf amazon
-wget --content-disposition https://cloud.tsinghua.edu.cn/f/e00a4c44aaf844cdb6c9/?dl=1
-tar -zxvf amazon.tar.gz
-mv datasets/amazon/ amazon
-rm -rf ./datasets
-rm -rf amazon.tar.gz
-
-rm -rf yahoo_answers_topics
-wget --content-disposition https://cloud.tsinghua.edu.cn/f/79257038afaa4730a03f/?dl=1
-tar -zxvf yahoo_answers_topics.tar.gz
-rm -rf yahoo_answers_topics.tar.gz
+# rm -rf mnli
+# wget --content-disposition https://cloud.tsinghua.edu.cn/f/33182c22cb594e88b49b/?dl=1
+# tar -zxvf mnli.tar.gz
+# rm -rf mnli.tar.gz
+
+# rm -rf agnews
+# wget --content-disposition https://cloud.tsinghua.edu.cn/f/0fb6af2a1e6647b79098/?dl=1
+# tar -zxvf agnews.tar.gz
+# rm -rf agnews.tar.gz
+
+# rm -rf dbpedia
+# wget --content-disposition https://cloud.tsinghua.edu.cn/f/362d3cdaa63b4692bafb/?dl=1
+# tar -zxvf dbpedia.tar.gz
+# rm -rf dbpedia.tar.gz
+
+# rm -rf imdb
+# wget --content-disposition https://cloud.tsinghua.edu.cn/f/37bd6cb978d342db87ed/?dl=1
+# tar -zxvf imdb.tar.gz
+# rm -rf imdb.tar.gz
+
+# rm -rf SST-2
+# wget --content-disposition https://cloud.tsinghua.edu.cn/f/bccfdb243eca404f8bf3/?dl=1
+# tar -zxvf SST-2.tar.gz
+# rm -rf SST-2.tar.gz
+
+# rm -rf amazon
+# wget --content-disposition https://cloud.tsinghua.edu.cn/f/e00a4c44aaf844cdb6c9/?dl=1
+# tar -zxvf amazon.tar.gz
+# mv datasets/amazon/ amazon
+# rm -rf ./datasets
+# rm -rf amazon.tar.gz
+
+# rm -rf yahoo_answers_topics
+# wget --content-disposition https://cloud.tsinghua.edu.cn/f/79257038afaa4730a03f/?dl=1
+# tar -zxvf yahoo_answers_topics.tar.gz
+# rm -rf yahoo_answers_topics.tar.gz
+
+rm -rf dwmw17  # drop any stale copy before re-downloading
+wget --content-disposition https://raw.githubusercontent.com/t-davidson/hate-speech-and-offensive-language/master/data/labeled_data.csv
+mkdir -p dwmw17
+mv labeled_data.csv dwmw17
 
 cd ..
diff --git a/openprompt/data_utils/text_classification_dataset.py b/openprompt/data_utils/text_classification_dataset.py
index 1dde85c..f9ca7ce 100644
--- a/openprompt/data_utils/text_classification_dataset.py
+++ b/openprompt/data_utils/text_classification_dataset.py
@@ -17,6 +17,7 @@
 import os
 import json, csv
+import pandas as pd
 from abc import ABC, abstractmethod
 from collections import defaultdict, Counter
 from typing import List, Dict, Callable
 
@@ -27,6 +28,38 @@
 from openprompt.data_utils.data_processor import DataProcessor
 
 
+class Dwmw17Processor(DataProcessor):  # processor for the dwmw17 labeled_data.csv fetched by download_text_classification.sh
+    def __init__(self):
+        super().__init__()
+        self.labels = [ "hate speech", "offensive language", "neither" ]  # maps to values 0, 1, 2 of the CSV's 'class' column
+
+    def get_examples(self, data_dir, split):
+        df = pd.read_csv(os.path.join(data_dir, 'labeled_data.csv'))  # expects 'class' and 'tweet' columns (see usage below)
+        """
+        24783 rows in total, 0: 1430, 1: 19190, 2: 4163
+        I will take 50% as training and 50% as testing
+        """
+        train_splits = [ 715, 9595, 2081 ]  # 50% of each per-class count listed above
+        examples = []
+        for label_idx in range(len(self.labels)):
+            df_label = df[df['class'] == label_idx]
+            train_split = train_splits[label_idx]
+
+            tweets = df_label['tweet'].tolist()
+            indexs = df_label.iloc[:, 0].tolist()  # first CSV column -- presumably a row id, reused as guid below; TODO confirm
+            if split == 'train':  # deterministic head/tail split per class, no shuffling
+                tweets_split = tweets[:train_split]
+                indexs_split = indexs[:train_split]
+            else:
+                tweets_split = tweets[train_split:]
+                indexs_split = indexs[train_split:]
+            for tweet, index in zip(tweets_split, indexs_split):
+                examples.append(InputExample(
+                    guid=str(index), text_a=tweet, text_b="", label=label_idx
+                ))
+        return examples
+
+
 class MnliProcessor(DataProcessor):
     # TODO Test needed
     def __init__(self):