"""Doc2Vec-based retrieval of documents similar to a query text.

Running this file as a script builds the training set from ``FLAGS.train_dir``,
writes the index file, and trains and saves a Doc2Vec model. Other modules can
use ``retrieve().getMatchInfos(text)`` to look up the most similar documents.
"""
import os
import re

import gensim
import jieba.posseg as pseg
import tensorflow as tf
from gensim.models.doc2vec import Doc2Vec

from loadData import loadData

tf.flags.DEFINE_string("base_dir", ".", "files base_dir")
tf.flags.DEFINE_string("train_dir", ".\\train", "trainning files base_dir")
tf.flags.DEFINE_string("test_dir", ".\\test", "test files base_dir")
tf.flags.DEFINE_string("model_dir", "./doc2vecmodel",
                       "Model directory from training run")
tf.flags.DEFINE_integer('vector_dim', 500, 'dimensionality of characters')
tf.flags.DEFINE_integer('epoch_num', 70, 'the number of epoch')
tf.flags.DEFINE_integer('min_count', 1,
                        'ignore the words which freq lower than min_count')
tf.flags.DEFINE_integer('window', 3,
                        'the max distance between relative content')
tf.flags.DEFINE_integer('negative', 5,
                        'the number of negative that we can accept')
tf.flags.DEFINE_integer('workers', 4, 'the module number of worker')

FLAGS = tf.flags.FLAGS
# NOTE(review): the return value of is_parsed() is discarded, so this call
# looks like a no-op here -- confirm whether flag parsing was intended.
FLAGS.is_parsed()
print("\nParameters:")
for attr, value in sorted(FLAGS.__flags.items()):
    print("{}={}".format(attr.upper(), value))
print("")

# Name of the file (inside FLAGS.base_dir) that persists retrieve.doc_dict.
_INDEX_FILE_NAME = "index_file.txt"
# Extension removed from a training file name to obtain the document name.
_TXT_SUFFIX = ".txt"

# Patterns used by retrieve.clean_str, compiled once.
_WHITESPACE_RE = re.compile(r'\s+')
_NOISE_RE = re.compile(
    u'[A-Za-z0-9’!"#$%&\'()*+,-./:;<=>?@,。?★、…【】《》?“”‘’![\\]^_`{|}~]+')


def _strip_txt_suffix(name):
    """Return *name* without a trailing ".txt"; other names are unchanged.

    ``name.strip(".txt")`` must not be used for this: it removes any run of
    the characters '.', 't' and 'x' from BOTH ends ("text.txt" -> "e").
    """
    if name.endswith(_TXT_SUFFIX):
        return name[:-len(_TXT_SUFFIX)]
    return name


class Singleton(object):
    """Base class that hands out one shared instance per class."""

    def __new__(cls, *args, **kw):
        # Look in cls.__dict__ (not hasattr) so that a subclass never reuses
        # an instance that was created for one of its base classes.
        if '_instance' not in cls.__dict__:
            # object.__new__ rejects extra arguments on Python 3, so the
            # constructor arguments are deliberately not forwarded.
            cls._instance = super(Singleton, cls).__new__(cls)
        return cls._instance


class retrieve(Singleton):
    """Builds, trains and queries a Doc2Vec model over the training files."""

    # Maps a document index (the database id) to the document name.
    doc_dict = {}
    # The trained or loaded Doc2Vec model.
    model_dm = None

    def __init__(self):
        self.load_doc_index()

    # Step 1: before training, turn the corpus into TaggedDocument objects.
    def get_trainset(self):
        """Build the training set from the files in ``FLAGS.train_dir``.

        Each regular file is read as UTF-8, parsed with ``getInfoDetail``,
        given an index (inserted through ``loadData.insertInfo`` when it has
        none yet) and recorded in ``self.doc_dict``.

        Returns:
            A list of TaggedDocument objects (word list + ``[index]`` tag).
        """
        x_train = []
        # A TaggedDocument holds a bag of words plus a tag list; it records
        # the content of one document and the index assigned to it.
        tagged_document = gensim.models.doc2vec.TaggedDocument
        load = loadData()
        for name in os.listdir(FLAGS.train_dir):
            user_file = os.path.join(FLAGS.train_dir, name)
            if os.path.isdir(user_file):
                continue
            with open(user_file, mode='rb') as f:
                text = f.read().decode('utf-8')
            item = self.getInfoDetail(text)
            doc_name = _strip_txt_suffix(name)
            # NOTE(review): the right-hand side of this assignment was missing
            # in the source. It is reconstructed as a lookup of the document
            # name in the index, which returns -1 when absent -- confirm.
            index = self.getIndexFromDoct(doc_name)
            if index == -1:
                index = load.insertInfo(item)
            # Every document needs a corresponding index.
            self.doc_dict[index] = doc_name
            line = ''
            if '公司名称' in item:
                line = line + item['公司名称']
            if '经营范围' in item:
                # NOTE(review): '/t' is a literal slash + 't', possibly meant
                # as a tab. It is kept as is: clean_str turns it into a space
                # separator, whereas a real tab would be deleted outright.
                line = line + '/t' + item['经营范围']
            if line == '':
                # Neither field was found: fall back to the whole file text.
                line = text
            words = self.seperate_line(self.clean_str(line))
            x_train.append(tagged_document(words, tags=[index]))
        return x_train

    # Step 2: set up the model parameters, train, then save the result.
    def train(self, x_train, size=500, epoch_num=1):
        """Train a Doc2Vec model on *x_train* and save it to ``FLAGS.model_dir``.

        Args:
            x_train: list of TaggedDocument objects (see ``get_trainset``).
            size: dimensionality of the feature vectors.
            epoch_num: number of epochs for the explicit ``train`` call.

        Returns:
            The trained model, which is also stored in ``self.model_dm``.
        """
        # min_count: ignore words whose total frequency is lower than this.
        # window: maximum distance between the predicted word and its context.
        # negative: number of negative samples; workers: worker threads.
        # NOTE(review): the vector_dim flag is defined but never read; `size`
        # is what sets the dimensionality -- confirm that this is intended.
        self.model_dm = gensim.models.Doc2Vec(
            x_train,
            min_count=FLAGS.min_count,
            window=FLAGS.window,
            size=size,
            sample=1e-3,
            negative=FLAGS.negative,
            workers=FLAGS.workers,
        )
        # NOTE(review): gensim appears to train already when the corpus is
        # passed to the constructor, so this adds further epochs -- confirm.
        self.model_dm.train(
            x_train,
            total_examples=self.model_dm.corpus_count,  # number of documents
            epochs=epoch_num,
        )
        # Save the trained model so that it can be restored with load() later.
        self.model_dm.save(FLAGS.model_dir)
        return self.model_dm

    # Step 3: use the trained model to find documents similar to a text.
    def getMatchInfos(self, text):
        """Return the stored records of the 5 documents most similar to *text*.

        Reloads the index file and the saved model, infers a vector for the
        cleaned and segmented *text*, and fetches each match through
        ``loadData.getMatchInfo``. The name and similarity of every match are
        printed as well.
        """
        load = loadData()
        self.load_doc_index()
        # The loaded model prints like Doc2Vec(dm/m,d500,n5,w3,s0.001,t4).
        self.model_dm = Doc2Vec.load(FLAGS.model_dir)
        test_text = self.seperate_line(self.clean_str(text))
        inferred_vector_dm = self.model_dm.infer_vector(test_text)
        sims = self.model_dm.docvecs.most_similar([inferred_vector_dm], topn=5)
        matchInfos = []
        for index, sim in sims:
            # .get: an index missing from the index file prints None instead
            # of aborting the whole query with a KeyError.
            print(self.doc_dict.get(index))
            print(sim)
            # The document itself is read from the database.
            matchInfos.append(load.getMatchInfo(index))
        return matchInfos

    # Step 4: persist the index so that a later run can reload it.
    def save_doc_index(self):
        """Write ``self.doc_dict`` to the index file as "<index> <name>" lines."""
        index_file = os.path.join(FLAGS.base_dir, _INDEX_FILE_NAME)
        lines = ["{} {}\n".format(index, name)
                 for index, name in self.doc_dict.items()]
        with open(index_file, 'w') as f:
            f.writelines(lines)

    def load_doc_index(self):
        """Reload ``self.doc_dict`` from the index file and return it.

        The dictionary is reset first; it stays empty when the file does not
        exist.
        """
        self.doc_dict = {}
        index_file = os.path.join(FLAGS.base_dir, _INDEX_FILE_NAME)
        if os.path.exists(index_file):
            with open(index_file) as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue
                    # Split on the first space only, so that a name which
                    # itself contains spaces is kept whole.
                    tokens = line.split(" ", 1)
                    name = tokens[1] if len(tokens) > 1 else ''
                    self.doc_dict[int(tokens[0])] = name
        return self.doc_dict

    def setModel(self):
        """Build the training set, save the index and train the model."""
        x_train = self.get_trainset()
        self.save_doc_index()
        # Training can be skipped when a trained model already exists.
        self.model_dm = self.train(x_train, epoch_num=FLAGS.epoch_num)
        return ("set model success!")

    def clean_str(self, string):
        """Remove all whitespace, then replace each run of ASCII letters,
        digits and the listed punctuation with one space; return the result
        stripped of leading and trailing spaces.
        """
        string = _WHITESPACE_RE.sub("", string)
        string = _NOISE_RE.sub(' ', string)
        return string.strip()

    def seperate_line(self, line):
        """Segment *line* with jieba and return the words that are kept.

        A word is kept when its part-of-speech flag starts with 'n' or is
        exactly 'v'. Flags 'nr' and 'ns' (person and place names in jieba's
        tag set) and empty flags are dropped.
        """
        new_line = []
        for words, flag in pseg.cut(line):
            if flag in ('nr', 'ns') or len(flag) == 0:
                continue
            if flag[0:1] == 'n' or flag == 'v':
                new_line.append(words)
        return new_line

    def getInfoDetail(self, text):
        """Parse "key:value" fields separated by "$#&#$" into a dict.

        Fields without a ':' are ignored; the value is everything after the
        first ':'.
        """
        new_text = {}
        for item in text.split("$#&#$"):
            i_arr = item.split(":", 1)
            if len(i_arr) > 1:
                new_text[i_arr[0]] = i_arr[1]
        return new_text

    def getIndexFromDoct(self, name):
        """Return the index stored for document *name*, or -1 if not found."""
        for key, value in self.doc_dict.items():
            if value == name:
                return key
        return -1


if __name__ == '__main__':
    # This section runs only when the file is executed as a script, not when
    # it is imported by another module.
    r = retrieve()
    x_train = r.get_trainset()    # build the preprocessed corpus
    r.save_doc_index()            # save index_file
    r.load_doc_index()            # reload index_file
    # Train the model; this can be skipped if it has been trained already.
    model_dm = r.train(x_train, epoch_num=FLAGS.epoch_num)

# Email: [email protected]