|
06、词表封装、类别封装与数据集的封装_视频
高清视频下载地址【回复本帖可见】:
# Input files: pre-segmented (word-tokenized) train/test/val data.
seg_train_file = './cnews_data/cnews.train.seg.txt'
seg_test_file = './cnews_data/cnews.test.seg.txt'
seg_val_file = './cnews_data/cnews.val.seg.txt'
# Vocab file (built from the training set only): word -> id mapping.
vocab_file = './cnews_data/cnews.vocab.txt'
# Label (category) file: one category name per line.
category_file = './cnews_data/cnews.category.txt'
# Output folder for this run's artifacts.
output_folder = './cnews_data/run_text_rnn'
# gfile works like os/os.path but also supports non-local filesystems.
if not gfile.Exists(output_folder):
    gfile.MakeDirs(output_folder)
class Vocab:
    """Vocabulary wrapper: maps words to dense integer ids.

    Built from a "word<TAB>frequency" file (one entry per line). Words whose
    frequency falls below ``num_word_threshould`` are dropped and later
    resolve to the <UNK> id via :meth:`word_to_id`.
    """

    def __init__(self, filename, num_word_threshould):
        # Leading-underscore members are private; access goes through methods.
        self._word_to_id = {}
        self._num_word_threshould = num_word_threshould
        # Stays -1 only if the file contains no <UNK> entry at all.
        self._unk_id = -1
        self._read_dict(filename)

    def _read_dict(self, filename):
        """Populate the word -> id mapping from the vocab file."""
        with open(filename, 'r', encoding='utf-8') as f:
            for line in f:
                word, frequency = line.strip('\r\n').split('\t')
                # Always keep <UNK> regardless of its recorded frequency;
                # otherwise _unk_id would stay -1 and every unknown word
                # would map to the invalid index -1.
                if word != '<UNK>' and int(frequency) < self._num_word_threshould:
                    continue
                # Ids are assigned densely, in file order.
                idx = len(self._word_to_id)
                if word == '<UNK>':
                    self._unk_id = idx
                self._word_to_id[word] = idx

    def word_to_id(self, word):
        """Return the id of ``word``, or the <UNK> id if it is unknown."""
        return self._word_to_id.get(word, self._unk_id)

    def sentence_to_id(self, sentence):
        """Convert a whitespace-separated sentence into a list of word ids."""
        return [self.word_to_id(cur_word) for cur_word in sentence.split()]

    def size(self):
        """Number of words in the vocabulary (including <UNK>)."""
        return len(self._word_to_id)

    @property
    def get_unk_id(self):
        # @property so callers read it like an attribute: vocab.get_unk_id
        return self._unk_id
# Build the vocabulary from the training-set vocab file.
vocab = Vocab(vocab_file, hps.num_word_threshould)
vocab_size = vocab.size()
# Lazy %-style args: the string is only formatted if the record is emitted.
logging.info('vocab.size:%d', vocab_size)
class CategoryDict:
    """Label vocabulary: maps category names to dense integer ids.

    Built from a file with one category name per line; ids follow file order.
    """

    def __init__(self, filename):
        self._category_to_id = {}
        with open(filename, 'r', encoding='utf-8') as f:
            for line in f:
                category = line.strip('\r\n')
                # Ids are assigned densely, in file order.
                idx = len(self._category_to_id)
                self._category_to_id[category] = idx

    def category_to_id(self, category_name):
        """Return the id for ``category_name``.

        Raises:
            Exception: if the name is not a known category.
        """
        if category_name not in self._category_to_id:
            raise Exception('%s is not in category list' % category_name)
        return self._category_to_id[category_name]

    def size(self):
        """Number of distinct categories."""
        return len(self._category_to_id)
# Build the label mapping and sanity-check one known label.
category_vocab = CategoryDict(category_file)
test_str = '娱乐'
logging.info('label: %s, id: %d', test_str, category_vocab.category_to_id(test_str))
num_classes = category_vocab.size()
# Use plain logging, consistent with the rest of the file (tf.logging is
# deprecated and mixing the two loggers here was inconsistent).
logging.info('num_classes: %d', num_classes)
class TextDataSet:
    """Dataset wrapper for a segmented text file.

    Each line of ``filename`` is "label<TAB>segmented content". Sentences are
    encoded with ``vocab``, truncated/padded to ``num_timesteps`` ids
    (padding uses the <UNK> id), labels are encoded with ``category_vocab``,
    and :meth:`next_batch` serves shuffled mini-batches.
    """

    def __init__(self, filename, vocab, category_vocab, num_timesteps):
        # filename: path to the pre-segmented data file.
        self._vocab = vocab
        self._category_vocab = category_vocab
        self._num_timesteps = num_timesteps
        # Becomes an int32 matrix [num_examples, num_timesteps] after parsing.
        self._inputs = []
        # Becomes an int32 vector [num_examples] after parsing.
        self._outputs = []
        # Cursor into the (shuffled) dataset, used by next_batch.
        self._indicator = 0
        # Parse the file and fill inputs/outputs.
        self._parse_file(filename)

    def _parse_file(self, filename):
        """Read the file, encode every example, then shuffle once."""
        logging.info('Loading data from %s.', filename)
        with open(filename, 'r', encoding='utf-8') as f:
            for line in f:
                label, content = line.strip('\r\n').split('\t')
                id_label = self._category_vocab.category_to_id(label)
                id_words = self._vocab.sentence_to_id(content)
                # Align every sentence to num_timesteps: truncate long ones,
                # pad short ones with the <UNK> id.
                id_words = id_words[0: self._num_timesteps]
                padding_num = self._num_timesteps - len(id_words)
                # get_unk_id is a @property, so it reads like an attribute.
                id_words = id_words + [self._vocab.get_unk_id] * padding_num
                self._inputs.append(id_words)
                self._outputs.append(id_label)
        # Turn the accumulated lists into matrices.
        self._inputs = np.asarray(self._inputs, dtype=np.int32)
        self._outputs = np.asarray(self._outputs, dtype=np.int32)
        self._random_shuffle()

    def _random_shuffle(self):
        """Shuffle inputs and outputs with the same permutation."""
        p = np.random.permutation(len(self._inputs))
        self._inputs = self._inputs[p]
        self._outputs = self._outputs[p]

    def next_batch(self, batch_size):
        """Return the next (inputs, outputs) batch, reshuffling at epoch end.

        Raises:
            Exception: if batch_size exceeds the dataset size.
        """
        end_indicator = self._indicator + batch_size
        if end_indicator > len(self._inputs):
            # Not enough examples left for a full batch: reshuffle and
            # reset the cursor to the start of the new permutation.
            self._random_shuffle()
            self._indicator = 0
            end_indicator = self._indicator + batch_size
        if end_indicator > len(self._inputs):
            raise Exception('batch_size: %d is too large' % batch_size)
        # Normal case: slice one batch and advance the cursor.
        batch_inputs = self._inputs[self._indicator: end_indicator]
        batch_outputs = self._outputs[self._indicator: end_indicator]
        self._indicator = end_indicator
        return batch_inputs, batch_outputs
# Build the three dataset splits, sharing one vocab and one label mapping.
# NOTE: each constructor consumes the global numpy RNG (shuffle), so the
# construction order affects reproducibility — do not reorder.
train_dataset = TextDataSet(seg_train_file, vocab, category_vocab, hps.num_timesteps)
test_dataset = TextDataSet(seg_test_file, vocab, category_vocab, hps.num_timesteps)
val_dataset = TextDataSet(seg_val_file, vocab, category_vocab, hps.num_timesteps)
# Smoke test: print one small batch from each split.
print(train_dataset.next_batch(2))
print(test_dataset.next_batch(2))
print(val_dataset.next_batch(2))
复制代码
|
|