东方耀AI技术分享

 找回密码
 立即注册

QQ登录

只需一步,快速开始

搜索
热搜: 活动 交友 discuz
查看: 16148|回复: 25
打印 上一主题 下一主题

[视频教程] 06、词表封装、类别封装与数据集的封装_视频

[复制链接]

1365

主题

1856

帖子

1万

积分

管理员

Rank: 10Rank: 10Rank: 10

积分
14429
QQ
跳转到指定楼层
楼主
发表于 2019-3-22 09:53:17 | 只看该作者 回帖奖励 |倒序浏览 |阅读模式


06、词表封装、类别封装与数据集的封装_视频


高清视频下载地址【回复本帖可见】:
游客,如果您要查看本帖隐藏内容请回复




  1. # input file
  2. seg_train_file = './cnews_data/cnews.train.seg.txt'
  3. seg_test_file = './cnews_data/cnews.test.seg.txt'
  4. seg_val_file = './cnews_data/cnews.val.seg.txt'

  5. # 词表文件(只有训练集):词语到id的映射 字典
  6. vocab_file = './cnews_data/cnews.vocab.txt'
  7. # label 类别文件
  8. category_file = './cnews_data/cnews.category.txt'

  9. output_folder = './cnews_data/run_text_rnn'

  10. # if not os.path.exists(output_folder):
  11. #     os.mkdir(output_folder)

  12. if not gfile.Exists(output_folder):
  13.     gfile.MakeDirs(output_folder)


  14. class Vocab:
  15.     """
  16.     词表数据的封装
  17.     """
  18.     def __init__(self, filename, num_word_threshould):
  19.         # _开头的 私有成员变量 面向对象的封装 只有用函数去调用
  20.         self._word_to_id = {}
  21.         self._num_word_threshould = num_word_threshould
  22.         self._unk_id = -1
  23.         self._read_dict(filename)

  24.     def _read_dict(self, filename):
  25.         with open(filename, 'r', encoding='utf-8') as f:
  26.             lines = f.readlines()
  27.         for line in lines:
  28.             word, frequency = line.strip('\r\n').split('\t')
  29.             frequency = int(frequency)
  30.             if frequency < self._num_word_threshould:
  31.                 continue
  32.             # idx是递增的 是当前字典的长度
  33.             idx = len(self._word_to_id)
  34.             if word == '<UNK>':
  35.                 self._unk_id = idx
  36.             self._word_to_id[word] = idx

  37.     def word_to_id(self, word):
  38.         # return self._word_to_id(k=word, default=self._unk_id)
  39.         return self._word_to_id.get(word, self._unk_id)

  40.     def sentence_to_id(self, sentence):
  41.         # 有个问题 假如cur_word在字典中不存在 怎么办?
  42.         # word_ids = [self._word_to_id[cur_word] for cur_word in sentence.split()]
  43.         # word_ids = [self._word_to_id.get(cur_word, self._unk_id) for cur_word in sentence.split()]
  44.         word_ids = [self.word_to_id(cur_word) for cur_word in sentence.split()]
  45.         return word_ids

  46.     def size(self):
  47.         return len(self._word_to_id)

  48.     @property
  49.     def get_unk_id(self):
  50.         return self._unk_id


  51. vocab = Vocab(vocab_file, hps.num_word_threshould)
  52. vocab_size = vocab.size()
  53. # print(vocab.size())
  54. logging.info('vocab.size:%d' % vocab_size)

  55. # test_str = '的 在 了 是 东方耀'
  56. # print(vocab.sentence_to_id(test_str))


  57. class CategoryDict:
  58.     def __init__(self, filename):
  59.         self._category_to_id = {}
  60.         with open(filename, 'r', encoding='utf-8') as f:
  61.             lines = f.readlines()
  62.         for line in lines:
  63.             category = line.strip('\r\n')
  64.             idx = len(self._category_to_id)
  65.             self._category_to_id[category] = idx

  66.     def category_to_id(self, category_name):
  67.         if not category_name in self._category_to_id:
  68.             raise Exception('%s is not in category list' % category_name)
  69.         return self._category_to_id[category_name]

  70.     def size(self):
  71.         return len(self._category_to_id)


  72. category_vocab = CategoryDict(category_file)
  73. test_str = '娱乐'
  74. logging.info('label: %s, id: %d' % (test_str, category_vocab.category_to_id(test_str)))
  75. num_classes = category_vocab.size()
  76. tf.logging.info('num_classes: %d' % num_classes)


  77. class TextDataSet:
  78.     """文本数据集的封装"""
  79.     def __init__(self, filename, vocab, category_vocab, num_timesteps):
  80.         # filename 分词后的文件
  81.         self._vocab = vocab
  82.         self._category_vocab = category_vocab
  83.         self._num_timesteps = num_timesteps
  84.         # matrix
  85.         self._inputs = []
  86.         # vector
  87.         self._outputs = []
  88.         self._indicator = 0
  89.         # 解析文件 把inputs与outputs填充的过程
  90.         self._parse_file(filename)

  91.     def _parse_file(self, filename):
  92.         logging.info('Loading data from %s.' % filename)
  93.         with open(filename, 'r', encoding='utf-8') as f:
  94.             lines = f.readlines()
  95.         for line in lines:
  96.             label, content = line.strip('\r\n').split('\t')
  97.             id_label = self._category_vocab.category_to_id(label)
  98.             id_words = self._vocab.sentence_to_id(content)
  99.             # 一行多少个id 与 num_timesteps比较 有多就截断 少了就补齐
  100.             # 对齐操作 id_words长度与_num_timesteps之间比较
  101.             id_words = id_words[0: self._num_timesteps]
  102.             padding_num = self._num_timesteps - len(id_words)
  103.             # get_unk_id加了@property 相当于调成员变量 而不是函数
  104.             id_words = id_words + [self._vocab.get_unk_id for i in range(padding_num)]
  105.             self._inputs.append(id_words)
  106.             self._outputs.append(id_label)
  107.         # 变成矩阵
  108.         self._inputs = np.asarray(self._inputs, dtype=np.int32)
  109.         self._outputs = np.asarray(self._outputs, dtype=np.int32)
  110.         self._random_shuffle()

  111.     def _random_shuffle(self):
  112.         p = np.random.permutation(len(self._inputs))
  113.         self._inputs = self._inputs[p]
  114.         self._outputs = self._outputs[p]

  115.     def next_batch(self, batch_size):
  116.         end_indicator = self._indicator + batch_size
  117.         if end_indicator > len(self._inputs):
  118.             # 已经到了数据集的后面了 不足以取一次batch_size大小的数据了 指针初始化为0
  119.             self._random_shuffle()
  120.             self._indicator = 0
  121.             end_indicator = self._indicator + batch_size
  122.         if end_indicator > len(self._inputs):
  123.             raise Exception('batch_size: %d is too large' % batch_size)
  124.         # 下面是正常的取batch_size大小的数据
  125.         batch_inputs = self._inputs[self._indicator: end_indicator]
  126.         batch_outputs = self._outputs[self._indicator: end_indicator]
  127.         # 不要忘记
  128.         self._indicator = end_indicator
  129.         return batch_inputs, batch_outputs


  130. train_dataset = TextDataSet(seg_train_file, vocab, category_vocab, hps.num_timesteps)
  131. test_dataset = TextDataSet(seg_test_file, vocab, category_vocab, hps.num_timesteps)
  132. val_dataset = TextDataSet(seg_val_file, vocab, category_vocab, hps.num_timesteps)
  133. #
  134. #
  135. print(train_dataset.next_batch(2))
  136. print(test_dataset.next_batch(2))
  137. print(val_dataset.next_batch(2))
复制代码


让天下人人学会人工智能!人工智能的前景一片大好!
回复

使用道具 举报

0

主题

137

帖子

604

积分

2W人工智能培训

Rank: 10Rank: 10Rank: 10

积分
604
沙发
发表于 2019-3-22 14:08:52 | 只看该作者
6
回复

使用道具 举报

0

主题

243

帖子

796

积分

2W人工智能培训

Rank: 10Rank: 10Rank: 10

积分
796
板凳
发表于 2019-3-22 20:58:51 | 只看该作者
发送到发送到
回复

使用道具 举报

0

主题

28

帖子

62

积分

注册会员

Rank: 2

积分
62
地板
发表于 2019-4-11 09:04:07 | 只看该作者
ai111.vipai111.vipai111.vipai111.vipai111.vip
回复

使用道具 举报

0

主题

319

帖子

724

积分

2W人工智能培训

Rank: 10Rank: 10Rank: 10

积分
724
5#
发表于 2019-5-15 09:33:17 | 只看该作者
06、词表封装、类别封装与数据集的封装_视频
回复

使用道具 举报

0

主题

205

帖子

460

积分

2W人工智能培训

Rank: 10Rank: 10Rank: 10

积分
460
6#
发表于 2019-5-26 18:09:05 | 只看该作者
谢谢楼主分享
回复

使用道具 举报

0

主题

364

帖子

864

积分

2W人工智能培训

Rank: 10Rank: 10Rank: 10

积分
864
7#
发表于 2019-5-29 15:12:40 | 只看该作者
1111111
回复

使用道具 举报

0

主题

155

帖子

324

积分

2W人工智能培训

Rank: 10Rank: 10Rank: 10

积分
324
8#
发表于 2019-6-16 21:03:17 | 只看该作者
1
回复

使用道具 举报

0

主题

266

帖子

586

积分

2W人工智能培训

Rank: 10Rank: 10Rank: 10

积分
586
9#
发表于 2019-7-17 13:59:57 | 只看该作者
11111
回复

使用道具 举报

0

主题

287

帖子

784

积分

2W人工智能培训

Rank: 10Rank: 10Rank: 10

积分
784
10#
发表于 2019-9-6 10:58:56 | 只看该作者
好视频看看
回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

QQ|Archiver|手机版|小黑屋|人工智能工程师的摇篮 ( 湘ICP备2020019608号-1 )

GMT+8, 2024-4-20 11:24 , Processed in 0.178843 second(s), 18 queries .

Powered by Discuz! X3.4

© 2001-2017 Comsenz Inc.

快速回复 返回顶部 返回列表