1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108
import os
from typing import Any, Optional, Tuple

import pandas as pd
import torch.utils.data as data
class Dataset(data.Dataset):
    """Toutiao news classification dataset yielding ``(class_index, text)`` items.

    On first use the raw ``{root}/toutiao_cat_data.txt`` (fields separated by
    ``_!_``) is parsed, cached as ``toutiao_cat_data.csv``, and split in file
    order (no shuffling) into 80% / 10% / 10% train / valid / test csv files
    in ``root``. Later constructions read the cached split csv directly.
    """

    # Each entry: [numeric category code, Chinese name(s), English tag].
    classes = [
        [100, '民生 故事', 'news_story'],
        [101, '文化 文化', 'news_culture'],
        [102, '娱乐 娱乐', 'news_entertainment'],
        [103, '体育 体育', 'news_sports'],
        [104, '财经 财经', 'news_finance'],
        [106, '房产 房产', 'news_house'],
        [107, '汽车 汽车', 'news_car'],
        [108, '教育 教育', 'news_edu'],
        [109, '科技 科技', 'news_tech'],
        [110, '军事 军事', 'news_military'],
        [112, '旅游 旅游', 'news_travel'],
        [113, '国际 国际', 'news_world'],
        [114, '证券 股票', 'stock'],
        [115, '农业 三农', 'news_agriculture'],
        [116, '电竞 游戏', 'news_game'],
    ]
    # Category code (as a string) -> contiguous class index 0..14, in the order
    # of ``classes``. Derived from ``classes`` so the two cannot drift apart
    # (codes 105 and 111 are absent from ``classes`` and therefore unmapped).
    classes_map = {str(code): idx for idx, (code, _, _) in enumerate(classes)}

    def __init__(
        self,
        root: str,
        train: bool = True,
        *,
        max_rows: Optional[int] = 5000,
    ) -> None:
        """Load one split of the dataset from *root*.

        Args:
            root: directory holding ``toutiao_cat_data.txt`` or the cached
                ``toutiao_cat_data.csv`` and split csv files.
            train: if True, load the train split (first 80% of rows);
                otherwise load the validation split (rows 80%-90%). The last
                10% is written to ``toutiao_cat_data_test.csv`` but is not
                loaded by this class.
            max_rows: keep only the first ``max_rows`` rows of the loaded
                split; ``None`` keeps every row. Defaults to 5000.
        """
        super().__init__()
        self.root = root
        self.train = train
        self.df: pd.DataFrame = pd.DataFrame(columns=['label', 'text'])
        self._load_data()
        if max_rows is not None:
            self.df = self.df.head(max_rows)
        print(f"is_train_data={self.train}, total_len={len(self.df)}")

    def _load_data(self) -> None:
        """Set ``self.df`` to the requested split, building cache files if needed.

        If ``toutiao_cat_data.csv`` is missing, the raw txt is parsed, cached
        as csv, and the train/valid/test split csv files are written. If the
        cached csv exists but the requested split file does not, the splits
        are rebuilt from the cached csv.
        """
        full_csv = os.path.join(self.root, 'toutiao_cat_data.csv')
        split_name = (
            'toutiao_cat_data_train.csv' if self.train else 'toutiao_cat_data_valid.csv'
        )
        split_path = os.path.join(self.root, split_name)

        if not os.path.exists(full_csv):
            full_df = self._parse_txt(os.path.join(self.root, 'toutiao_cat_data.txt'))
            full_df.to_csv(full_csv, index=False, header=True)
            self._write_splits(full_df)
        elif not os.path.exists(split_path):
            # Cached csv present but split files missing: rebuild them.
            self._write_splits(pd.read_csv(full_csv))

        self.df = pd.read_csv(split_path)

    @staticmethod
    def _parse_txt(txt_path: str) -> pd.DataFrame:
        """Parse the raw ``_!_``-separated txt file into a ``label``/``text`` frame.

        Field 1 of each line becomes ``label``; fields 3 and 4 joined by ","
        become ``text``. Blank lines are skipped.
        """
        rows = []
        count = 0
        # The data is Chinese text; don't depend on the platform default encoding.
        with open(txt_path, 'r', encoding='utf-8') as file:
            for count, line in enumerate(file, start=1):
                # Strip the line terminator so it doesn't end up inside `text`.
                line = line.rstrip('\r\n')
                if not line:
                    continue
                items = line.split('_!_')
                rows.append({'label': items[1], 'text': items[3] + "," + items[4]})
                if count % 100 == 0:
                    print(f'line:{count}, items={items}')
        print(f"txt total count={count}")
        # Build the frame once; appending row-by-row is quadratic.
        return pd.DataFrame(rows, columns=['label', 'text'])

    def _write_splits(self, full_df: pd.DataFrame) -> None:
        """Write the 80% / 10% / 10% train/valid/test csv files (file order) into root."""
        total_len = len(full_df)
        train_len = int(total_len * 0.8)
        valid_len = int(total_len * 0.1)
        splits = {
            'toutiao_cat_data_train.csv': full_df.iloc[:train_len],
            'toutiao_cat_data_valid.csv': full_df.iloc[train_len:train_len + valid_len],
            'toutiao_cat_data_test.csv': full_df.iloc[train_len + valid_len:],
        }
        for file_name, part in splits.items():
            part.to_csv(os.path.join(self.root, file_name), index=False, header=True)

    def __len__(self) -> int:
        """Return the number of rows in the loaded split."""
        return len(self.df)

    def __getitem__(self, index) -> Tuple[Any, Any]:
        """Return ``(class_index, text)`` for row *index* of the loaded split."""
        row = self.df.iloc[index]
        return self.classes_map[str(row["label"])], row["text"]
if __name__ == "__main__":
    # Bind to `dataset`, not `data`, so the torch.utils.data module alias isn't shadowed.
    dataset = Dataset("toutiao-text-classfication-dataset")
    print(f'len={len(dataset)}')
|