东方耀AI技术分享

标题: 06、词表封装、类别封装与数据集的封装_视频 [打印本页]

作者: 东方耀    时间: 2019-3-22 09:53
标题: 06、词表封装、类别封装与数据集的封装_视频


06、词表封装、类别封装与数据集的封装_视频


高清视频下载地址【回复本帖可见】:



  1. # input file
  2. seg_train_file = './cnews_data/cnews.train.seg.txt'
  3. seg_test_file = './cnews_data/cnews.test.seg.txt'
  4. seg_val_file = './cnews_data/cnews.val.seg.txt'

  5. # 词表文件(只有训练集):词语到id的映射 字典
  6. vocab_file = './cnews_data/cnews.vocab.txt'
  7. # label 类别文件
  8. category_file = './cnews_data/cnews.category.txt'

  9. output_folder = './cnews_data/run_text_rnn'

  10. # if not os.path.exists(output_folder):
  11. #     os.mkdir(output_folder)

  12. if not gfile.Exists(output_folder):
  13.     gfile.MakeDirs(output_folder)


  14. class Vocab:
  15.     """
  16.     词表数据的封装
  17.     """
  18.     def __init__(self, filename, num_word_threshould):
  19.         # _开头的 私有成员变量 面向对象的封装 只有用函数去调用
  20.         self._word_to_id = {}
  21.         self._num_word_threshould = num_word_threshould
  22.         self._unk_id = -1
  23.         self._read_dict(filename)

  24.     def _read_dict(self, filename):
  25.         with open(filename, 'r', encoding='utf-8') as f:
  26.             lines = f.readlines()
  27.         for line in lines:
  28.             word, frequency = line.strip('\r\n').split('\t')
  29.             frequency = int(frequency)
  30.             if frequency < self._num_word_threshould:
  31.                 continue
  32.             # idx是递增的 是当前字典的长度
  33.             idx = len(self._word_to_id)
  34.             if word == '<UNK>':
  35.                 self._unk_id = idx
  36.             self._word_to_id[word] = idx

  37.     def word_to_id(self, word):
  38.         # return self._word_to_id(k=word, default=self._unk_id)
  39.         return self._word_to_id.get(word, self._unk_id)

  40.     def sentence_to_id(self, sentence):
  41.         # 有个问题 假如cur_word在字典中不存在 怎么办?
  42.         # word_ids = [self._word_to_id[cur_word] for cur_word in sentence.split()]
  43.         # word_ids = [self._word_to_id.get(cur_word, self._unk_id) for cur_word in sentence.split()]
  44.         word_ids = [self.word_to_id(cur_word) for cur_word in sentence.split()]
  45.         return word_ids

  46.     def size(self):
  47.         return len(self._word_to_id)

  48.     @property
  49.     def get_unk_id(self):
  50.         return self._unk_id


  51. vocab = Vocab(vocab_file, hps.num_word_threshould)
  52. vocab_size = vocab.size()
  53. # print(vocab.size())
  54. logging.info('vocab.size:%d' % vocab_size)

  55. # test_str = '的 在 了 是 东方耀'
  56. # print(vocab.sentence_to_id(test_str))


  57. class CategoryDict:
  58.     def __init__(self, filename):
  59.         self._category_to_id = {}
  60.         with open(filename, 'r', encoding='utf-8') as f:
  61.             lines = f.readlines()
  62.         for line in lines:
  63.             category = line.strip('\r\n')
  64.             idx = len(self._category_to_id)
  65.             self._category_to_id[category] = idx

  66.     def category_to_id(self, category_name):
  67.         if not category_name in self._category_to_id:
  68.             raise Exception('%s is not in category list' % category_name)
  69.         return self._category_to_id[category_name]

  70.     def size(self):
  71.         return len(self._category_to_id)


  72. category_vocab = CategoryDict(category_file)
  73. test_str = '娱乐'
  74. logging.info('label: %s, id: %d' % (test_str, category_vocab.category_to_id(test_str)))
  75. num_classes = category_vocab.size()
  76. tf.logging.info('num_classes: %d' % num_classes)


  77. class TextDataSet:
  78.     """文本数据集的封装"""
  79.     def __init__(self, filename, vocab, category_vocab, num_timesteps):
  80.         # filename 分词后的文件
  81.         self._vocab = vocab
  82.         self._category_vocab = category_vocab
  83.         self._num_timesteps = num_timesteps
  84.         # matrix
  85.         self._inputs = []
  86.         # vector
  87.         self._outputs = []
  88.         self._indicator = 0
  89.         # 解析文件 把inputs与outputs填充的过程
  90.         self._parse_file(filename)

  91.     def _parse_file(self, filename):
  92.         logging.info('Loading data from %s.' % filename)
  93.         with open(filename, 'r', encoding='utf-8') as f:
  94.             lines = f.readlines()
  95.         for line in lines:
  96.             label, content = line.strip('\r\n').split('\t')
  97.             id_label = self._category_vocab.category_to_id(label)
  98.             id_words = self._vocab.sentence_to_id(content)
  99.             # 一行多少个id 与 num_timesteps比较 有多就截断 少了就补齐
  100.             # 对齐操作 id_words长度与_num_timesteps之间比较
  101.             id_words = id_words[0: self._num_timesteps]
  102.             padding_num = self._num_timesteps - len(id_words)
  103.             # get_unk_id加了@property 相当于调成员变量 而不是函数
  104.             id_words = id_words + [self._vocab.get_unk_id for i in range(padding_num)]
  105.             self._inputs.append(id_words)
  106.             self._outputs.append(id_label)
  107.         # 变成矩阵
  108.         self._inputs = np.asarray(self._inputs, dtype=np.int32)
  109.         self._outputs = np.asarray(self._outputs, dtype=np.int32)
  110.         self._random_shuffle()

  111.     def _random_shuffle(self):
  112.         p = np.random.permutation(len(self._inputs))
  113.         self._inputs = self._inputs[p]
  114.         self._outputs = self._outputs[p]

  115.     def next_batch(self, batch_size):
  116.         end_indicator = self._indicator + batch_size
  117.         if end_indicator > len(self._inputs):
  118.             # 已经到了数据集的后面了 不足以取一次batch_size大小的数据了 指针初始化为0
  119.             self._random_shuffle()
  120.             self._indicator = 0
  121.             end_indicator = self._indicator + batch_size
  122.         if end_indicator > len(self._inputs):
  123.             raise Exception('batch_size: %d is too large' % batch_size)
  124.         # 下面是正常的取batch_size大小的数据
  125.         batch_inputs = self._inputs[self._indicator: end_indicator]
  126.         batch_outputs = self._outputs[self._indicator: end_indicator]
  127.         # 不要忘记
  128.         self._indicator = end_indicator
  129.         return batch_inputs, batch_outputs


  130. train_dataset = TextDataSet(seg_train_file, vocab, category_vocab, hps.num_timesteps)
  131. test_dataset = TextDataSet(seg_test_file, vocab, category_vocab, hps.num_timesteps)
  132. val_dataset = TextDataSet(seg_val_file, vocab, category_vocab, hps.num_timesteps)
  133. #
  134. #
  135. print(train_dataset.next_batch(2))
  136. print(test_dataset.next_batch(2))
  137. print(val_dataset.next_batch(2))
复制代码



作者: lovestudy    时间: 2019-3-22 14:08
6

作者: weiguangnixia    时间: 2019-3-22 20:58
发送到发送到
作者: zjshuai    时间: 2019-4-11 09:04
ai111.vipai111.vipai111.vipai111.vipai111.vip
作者: m379896771    时间: 2019-5-15 09:33
06、词表封装、类别封装与数据集的封装_视频
作者: 阳阳11    时间: 2019-5-26 18:09
谢谢楼主分享
作者: jlbu    时间: 2019-5-29 15:12
1111111
作者: Sy赵小赵    时间: 2019-6-16 21:03
1
作者: 何马    时间: 2019-7-17 13:59
11111
作者: jhcao23    时间: 2019-9-6 10:58
好视频看看
作者: malcody9    时间: 2019-9-19 15:19
2333
作者: 清远    时间: 2019-9-29 08:08
1
作者: 吧卟啊吧卟    时间: 2019-10-30 16:14
666
作者: fglbee    时间: 2019-12-7 10:56
this is good idea
作者: p2020    时间: 2019-12-18 20:15
11111111111111111111111
作者: arnold    时间: 2020-1-20 15:21
6666666666
作者: xsoft    时间: 2020-2-3 15:29
谢谢老师提供的资料。
作者: judson    时间: 2020-4-4 12:01
1
作者: luotuo    时间: 2020-4-23 16:30
谢谢资料
作者: 周末下雨的机场    时间: 2020-5-21 14:13
dddddd
作者: ice_spring    时间: 2020-5-30 09:51
分词技术
作者: Lewis    时间: 2020-7-21 16:24
6666666666666666
作者: zkq0519    时间: 2020-7-26 17:48
1
作者: hu007    时间: 2021-2-6 04:02
kxsp'cj sj cuwihduiwcdnjcjndso;jco;sdjc;sd mco;sd
作者: fengxiaoyang    时间: 2022-3-7 21:00
11
作者: 大水牛    时间: 2022-9-23 06:47
1




欢迎光临 东方耀AI技术分享 (http://www.ai111.vip/) Powered by Discuz! X3.4