1 Introduction

Text classification is one of the most common applications in natural language processing. This article walks through building an embedding-based text classification model with TensorFlow. Because the TensorFlow API changes quickly and backward compatibility can be shaky, everything below uses the latest releases as of April 16, 2022: TensorFlow (v2.8.0), TensorFlow Datasets (tfds v4.0.1), and TensorFlow Text (tensorflow_text v2.8.1). If you hit a bug, check the versions of your TensorFlow-related libraries first.

The main APIs used in this workflow are: tf.strings, tfds, tensorflow_text, tf.data.Dataset, and tf.keras (Sequential & Functional API).

2 Getting the data

TensorFlow Datasets (tfds) ships with a large collection of example datasets [1] for research and experimentation. This article uses the classic IMDB movie review dataset for a binary sentiment classification task. First, load the data directly through the tfds API; the result is stored in a tf.data.Dataset [2] object.

```python
import collections
import pathlib

import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras import losses
from tensorflow.keras import utils
from tensorflow.keras.layers import TextVectorization
import tensorflow_datasets as tfds
import tensorflow_text as tf_text
import plotly.express as px
import matplotlib.pyplot as plt

BATCH_SIZE = 32

# Training set.
train_ds = tfds.load(
    'imdb_reviews',
    split='train[:80%]',
    shuffle_files=True,
    as_supervised=True)

# Validation set: a tf.data.Dataset object
val_ds = tfds.load(
    'imdb_reviews',
    split='train[80%:]',
    shuffle_files=True,
    as_supervised=True)

# Check the count of records
print(train_ds.cardinality().numpy())
print(val_ds.cardinality().numpy())
```

The output is:

```
20000
5000
```

Inspect a single example like this:

```python
for data, label in train_ds.take(1):
    print(type(data))
    print('Text:', data.numpy())
    print('Label:', label.numpy())
```

The output is:

```
<class 'tensorflow.python.framework.ops.EagerTensor'>
Text: b"This was an absolutely terrible movie. Don't be lured in by Christopher Walken or Michael Ironside. Both are great actors, but this must simply be their worst role in history. Even their great acting could not redeem this movie's ridiculous storyline. This movie is an early nineties US propaganda piece. The most pathetic scenes were those when the Columbian rebels were making their cases for revolutions. Maria Conchita Alonso appeared phony, and her pseudo love affair with Walken was nothing but a pathetic emotional plug in a movie that was devoid of any real meaning. I am disappointed that there are movies like this, ruining actor's like Christopher Walken's good name. I could barely sit through it."
Label: 0
```

3 Text preprocessing

This section processes the data with the text-handling APIs in tensorflow_text and tf.strings. tf.data.Dataset makes it very convenient to map such functions over the data, and is well worth learning and using.

3.1 Lowercasing

Character case contributes nothing to the model's predictions in this classification task, so use a map over the dataset to convert every character to lowercase. Be careful to match the element structure of the tf.data.Dataset (here, (text, label) pairs).

```python
train_ds = train_ds.map(
    lambda text, label: (tf_text.case_fold_utf8(text), label))
val_ds = val_ds.map(
    lambda text, label: (tf_text.case_fold_utf8(text), label))
```

3.2 Normalizing the text

This step normalizes the text with regular expressions, for example inserting spaces around punctuation, so that the next step can tokenize on whitespace.

```python
str_regex_pattern = [
    (r"[^A-Za-z0-9(),!?']", " "),
    (r"'s", " 's"),
    (r"'ve", " 've"),
    (r"n't", " n't"),
    (r"'re", " 're"),
    (r"'d", " 'd"),
    (r"'ll", " 'll"),
    (",", " , "),
    ("!", " ! "),
    (r"\(", " ( "),
    (r"\)", " ) "),
    (r"\?", " ? "),
    (r"\s{2,}", " ")]

for pattern, rewrite in str_regex_pattern:
    train_ds = train_ds.map(
        lambda text, label: (
            tf.strings.regex_replace(text, pattern=pattern, rewrite=rewrite),
            label))
    val_ds = val_ds.map(
        lambda text, label: (
            tf.strings.regex_replace(text, pattern=pattern, rewrite=rewrite),
            label))
```

3.3 Building the vocabulary

Build the vocabulary from the training set only (never from the validation or test set, which would leak information). This step maps each token to an index, so the data can be turned into a format the model can train on and predict from.

```python
# Do not use validation set as that will lead to data leak
train_text = train_ds.map(lambda text, label: text)

tokenizer = tf_text.WhitespaceTokenizer()

unique_tokens = collections.defaultdict(lambda: 0)
sentence_length = []
for text in train_text.as_numpy_iterator():
    tokens = tokenizer.tokenize(text).numpy()
    sentence_length.append(len(tokens))
    for token in tokens:
        unique_tokens[token] += 1

# check out the average sentence length: ~250 tokens
print(sum(sentence_length) / len(sentence_length))

# print the 10 most used tokens as (token, frequency)
d_view = [(v, k) for k, v in unique_tokens.items()]
d_view.sort(reverse=True)
for v, k in d_view[:10]:
    print("%s: %d" % (k, v))
```

The output shows that the most frequently used tokens are all common English words (plus b'br', residue of HTML line-break tags in the reviews):

```
b'the': 269406
b',': 221098
b'and': 131502
b'a': 130309
b'of': 116695
b'to': 108605
b'is': 88351
b'br': 81558
b'it': 77094
b'in': 75177
```

You can also chart each token's usage frequency, which helps when choosing the vocabulary size.

```python
fig = px.scatter(x=range(len(d_view)), y=[cnt for cnt, word in d_view])
fig.show()
```

(Figure: token usage frequency distribution)

As the figure shows, most of the 70,000+ unique tokens occur extremely rarely, so we set the vocabulary size to 20,000.
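Before committing to a cutoff, it can also help to put a number on how much of the corpus a given vocabulary size covers. The sketch below is illustrative rather than part of the original workflow; it reuses the `d_view` list built above, and the candidate sizes are arbitrary.

```python
# Illustrative sketch (not from the original workflow): the fraction of all
# token occurrences covered by the n most frequent tokens. d_view holds
# (count, token) pairs sorted by count in descending order.
total_count = sum(cnt for cnt, _ in d_view)

def coverage(n):
    return sum(cnt for cnt, _ in d_view[:n]) / total_count

for n in (5000, 10000, 20000, 50000):
    print("top %6d tokens cover %.2f%% of the corpus" % (n, 100 * coverage(n)))
```

With a heavy-tailed distribution like the one in the scatter plot, coverage typically flattens out well before the long tail, which is the intuition behind stopping at 20,000 entries here.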
3.4 Building the vocabulary lookup table

Use TensorFlow's tf.lookup.StaticVocabularyTable to map each token to its index, and test it on a small sample.

```python
vocab_size = 20000  # chosen based on the frequency plot above

keys = [token for cnt, token in d_view][:vocab_size]
# Reserve 0 for padding, 1 for OOV tokens.
values = range(2, len(keys) + 2)

num_oov_buckets = 1

# Note: must assign the key_dtype and value_dtype when the keys and values
# are Python arrays
init = tf.lookup.KeyValueTensorInitializer(
    keys=keys,
    values=values,
    key_dtype=tf.string,
    value_dtype=tf.int64)

table = tf.lookup.StaticVocabularyTable(init, num_oov_buckets=num_oov_buckets)

# Test the lookup table with sample input
input_tensor = tf.constant(['emerson', 'lake', 'palmer', 'king'])
print(table[input_tensor].numpy())
```

The output is:

```
array([20000,  2065, 14207,   618])
```

Note that 'emerson' is out of vocabulary, so it is assigned to the OOV bucket, whose ids start at the size of the lookup table (20,000).

Now the text can be mapped onto indices: define a conversion function and apply it to the datasets:

```python
def text_index_lookup(text, label):
    tokenized = tokenizer.tokenize(text)
    vectorized = table.lookup(tokenized)
    return vectorized, label

train_ds = train_ds.map(text_index_lookup)
val_ds = val_ds.map(text_index_lookup)
```

3.5 Configuring the dataset

The cache and prefetch APIs of tf.data.Dataset improve performance: cache keeps the data in memory for fast reads, while prefetch prepares upcoming batches in parallel while the model is busy, making better use of time.

```python
AUTOTUNE = tf.data.AUTOTUNE

def configure_dataset(dataset):
    return dataset.cache().prefetch(buffer_size=AUTOTUNE)

train_ds = configure_dataset(train_ds)
val_ds = configure_dataset(val_ds)
```

Texts come in different lengths, but the network needs inputs with fixed dimensions, so pad the sequences to a uniform length and split them into batches.

```python
BATCH_SIZE = 32
train_ds = train_ds.padded_batch(BATCH_SIZE)
val_ds = val_ds.padded_batch(BATCH_SIZE)
```

3.6 Preparing the test set

The test set used to evaluate model performance goes through the same pipeline, so the model can predict on it directly:

```python
# Test set. Batching is done later with padded_batch.
test_ds = tfds.load(
    'imdb_reviews',
    split='test',
    shuffle_files=True,
    as_supervised=True)

test_ds = test_ds.map(
    lambda text, label: (tf_text.case_fold_utf8(text), label))

for pattern, rewrite in str_regex_pattern:
    test_ds = test_ds.map(
        lambda text, label: (
            tf.strings.regex_replace(text, pattern=pattern, rewrite=rewrite),
            label))

test_ds = test_ds.map(text_index_lookup)
test_ds = configure_dataset(test_ds)
test_ds = test_ds.padded_batch(BATCH_SIZE)
```

4 Building the models

4.1 Building a convolutional neural network with the Sequential API

```python
vocab_size += 2  # 0 for padding and 1 for oov token

def create_model(vocab_size, num_labels, dropout_rate):
    model = tf.keras.Sequential([
        tf.keras.layers.Embedding(vocab_size, 128, mask_zero=True),
        tf.keras.layers.Conv1D(32, 3, padding='valid', activation='relu', strides=1),
        tf.keras.layers.MaxPooling1D(pool_size=2),
        tf.keras.layers.Conv1D(64, 4, padding='valid', activation='relu', strides=1),
        tf.keras.layers.MaxPooling1D(pool_size=2),
        tf.keras.layers.Conv1D(128, 5, padding='valid', activation='relu', strides=1),
        tf.keras.layers.GlobalMaxPooling1D(),
        tf.keras.layers.Dropout(dropout_rate),
        tf.keras.layers.Dense(num_labels)])
    return model

tf.keras.backend.clear_session()
model = create_model(vocab_size=vocab_size, num_labels=2, dropout_rate=0.5)

# Using momentum with SGD speeds up convergence significantly
loss = losses.SparseCategoricalCrossentropy(from_logits=True)
optimizer = tf.keras.optimizers.SGD(learning_rate=0.01, momentum=0.9)
model.compile(loss=loss, optimizer=optimizer, metrics=['accuracy'])

model.summary()
```

The output is:

```
Model: "sequential"
_________________________________________________________________
 Layer (type)                               Output Shape        Param #
=================================================================
 embedding (Embedding)                      (None, None, 128)   2560256
 conv1d (Conv1D)                            (None, None, 32)    12320
 max_pooling1d (MaxPooling1D)               (None, None, 32)    0
 conv1d_1 (Conv1D)                          (None, None, 64)    8256
 max_pooling1d_1 (MaxPooling1D)             (None, None, 64)    0
 conv1d_2 (Conv1D)                          (None, None, 128)   41088
 global_max_pooling1d (GlobalMaxPooling1D)  (None, 128)         0
 dropout (Dropout)                          (None, 128)         0
 dense (Dense)                              (None, 2)           258
=================================================================
Total params: 2,622,178
Trainable params: 2,622,178
Non-trainable params: 0
_________________________________________________________________
```

Now train and evaluate the model:

```python
# early stopping reduces the risk of overfitting
early_stopping = tf.keras.callbacks.EarlyStopping(patience=10)
epochs = 100

history = model.fit(
    x=train_ds,
    validation_data=val_ds,
    epochs=epochs,
    callbacks=[early_stopping])

loss, accuracy = model.evaluate(test_ds)

print("Loss: ", loss)
print("Accuracy: {:2.2%}".format(accuracy))
```

Considering how simple the model architecture is, the results are acceptable:

```
782/782 [==============================] - 57s 72ms/step - loss: 0.4583 - accuracy: 0.8678
Loss:  0.45827823877334595
Accuracy: 86.78%
```
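As a sanity check, the trained CNN can classify a raw string by reusing the exact preprocessing chain from section 3 (case folding, regex normalization, whitespace tokenization, and the lookup table). The `predict_sentiment` helper below is a hypothetical sketch written for this article, not part of the original workflow; it assumes the objects defined above (`str_regex_pattern`, `tokenizer`, `table`, `model`) are in scope.

```python
# Hypothetical helper: push one raw review through the same preprocessing
# used for training, then classify it with the trained CNN.
def predict_sentiment(raw_text):
    text = tf_text.case_fold_utf8(tf.constant(raw_text))  # scalar string
    for pattern, rewrite in str_regex_pattern:
        text = tf.strings.regex_replace(text, pattern=pattern, rewrite=rewrite)
    tokens = tokenizer.tokenize(text)    # 1-D tensor of string tokens
    ids = table.lookup(tokens)           # 1-D tensor of token ids
    # Zero-pad short inputs: the stacked conv/pool layers need a minimum
    # sequence length, and 0 is the padding id used by padded_batch.
    ids = tf.pad(ids, [[0, tf.maximum(0, 30 - tf.shape(ids)[0])]])
    ids = tf.expand_dims(ids, 0)         # add a batch dimension
    logits = model.predict(ids)
    return int(tf.argmax(logits, axis=-1)[0])  # 0 = negative, 1 = positive

print(predict_sentiment("A moving, beautifully shot film. I would happily watch it again."))
```

In a real deployment it is worth baking this preprocessing into the model itself (for example, inside a serving tf.function), so that training and serving cannot drift apart.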
4.2 Building a bidirectional LSTM with the Functional API

The steps are similar to the Sequential API, but the Functional API is more flexible.

```python
num_labels = 2      # same settings as the CNN above
dropout_rate = 0.5

input = tf.keras.layers.Input([None])
x = tf.keras.layers.Embedding(
    input_dim=vocab_size,
    output_dim=128,
    # Use masking to handle the variable sequence lengths
    mask_zero=True)(input)
x = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(64))(x)
x = tf.keras.layers.Dense(64, activation='relu')(x)
x = tf.keras.layers.Dropout(dropout_rate)(x)
output = tf.keras.layers.Dense(num_labels)(x)

lstm_model = tf.keras.Model(inputs=input, outputs=output, name='text_lstm_model')

loss = losses.SparseCategoricalCrossentropy(from_logits=True)
optimizer = tf.keras.optimizers.SGD(learning_rate=0.01, momentum=0.9)
lstm_model.compile(loss=loss, optimizer=optimizer, metrics=['accuracy'])
lstm_model.summary()
```

The output is:

```
Model: "text_lstm_model"
_________________________________________________________________
 Layer (type)                     Output Shape        Param #
=================================================================
 input_5 (InputLayer)             [(None, None)]      0
 embedding_5 (Embedding)          (None, None, 128)   2560256
 bidirectional_4 (Bidirectional)  (None, 128)         98816
 dense_4 (Dense)                  (None, 64)          8256
 dropout_2 (Dropout)              (None, 64)          0
 dense_5 (Dense)                  (None, 2)           130
=================================================================
Total params: 2,667,458
Trainable params: 2,667,458
Non-trainable params: 0
_________________________________________________________________
```

Train and evaluate this model the same way:

```python
history_2 = lstm_model.fit(
    x=train_ds,
    validation_data=val_ds,
    epochs=epochs,
    callbacks=[early_stopping])

loss, accuracy = lstm_model.evaluate(test_ds)

print("Loss: ", loss)
print("Accuracy: {:2.2%}".format(accuracy))
```

Again, considering how simple the model architecture is, the results are acceptable:

```
782/782 [==============================] - 84s 106ms/step - loss: 0.4105 - accuracy: 0.8160
Loss:  0.4105057716369629
Accuracy: 81.60%
```

5 Summary

There are many newer techniques for text classification still worth trying, and many of the decisions in the workflow above can be tuned experimentally (the usual model "alchemy"). The aim of this article was to use the latest TensorFlow APIs to walk through the key concepts and commonly used APIs of a text classification task; in real-world work there is still plenty of room for optimization. I hope this writeup helps, and you are welcome to discuss in the comments!

References

[1] TensorFlow Datasets: https://www.tensorflow.org/datasets
[2] tf.data.Dataset: https://www.tensorflow.org/api_docs/python/tf/data/Dataset