Trying topic classification with LDA
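I performed topic classification using LDA (Latent Dirichlet Allocation).
Naive Bayes can also be used to classify documents, but it is a supervised method and needs labeled training data.
LDA is an unsupervised method, so no labels are required.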
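To make the difference concrete, here is a minimal sketch on a made-up toy corpus (the documents and labels are hypothetical). `MultinomialNB.fit` requires a label vector `y`, while `LatentDirichletAllocation.fit` takes only the document-term matrix. Raw counts from `CountVectorizer` are used here for brevity.

```python
import numpy as np
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.naive_bayes import MultinomialNB
from sklearn.decomposition import LatentDirichletAllocation

# Hypothetical toy corpus and labels, for illustration only
docs = [
    "hockey game goal team",
    "baseball pitcher team game",
    "mac apple monitor drive",
    "apple mac scsi card",
]
labels = [0, 0, 1, 1]  # only the supervised model needs these

X = CountVectorizer().fit_transform(docs)

# Naive Bayes (supervised): fit needs both X and y
nb = MultinomialNB().fit(X, labels)

# LDA (unsupervised): fit needs only X; topics come from word co-occurrence
lda = LatentDirichletAllocation(n_components=2, random_state=0).fit(X)

print(nb.predict(X))              # predicted class labels
print(lda.transform(X).round(2))  # per-document topic proportions
```

Note that LDA does not return class labels: it returns a topic mixture per document, and the meaning of each topic index has to be interpreted by looking at its top words.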
Dataset
- 20 newsgroups text dataset
Training
```python
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.decomposition import LatentDirichletAllocation
import mglearn
import numpy as np

# data: 4 categories from the 20 newsgroups dataset
categories = ['rec.sport.baseball', 'rec.sport.hockey',
              'comp.sys.mac.hardware', 'comp.windows.x']
twenty_train = fetch_20newsgroups(subset='train', categories=categories,
                                  shuffle=True, random_state=42)
twenty_test = fetch_20newsgroups(subset='test', categories=categories,
                                 shuffle=True, random_state=42)

# tf-idf features, vocabulary learned from the training data only
tfidf_vec = TfidfVectorizer(lowercase=True, stop_words='english',
                            max_df=0.1, min_df=5).fit(twenty_train.data)
X_train = tfidf_vec.transform(twenty_train.data)
X_test = tfidf_vec.transform(twenty_test.data)
# get_feature_names() was removed in scikit-learn 1.2; use get_feature_names_out()
feature_names = tfidf_vec.get_feature_names_out()
#print(feature_names[1000:1050])

# train
topic_num = 4
lda = LatentDirichletAllocation(n_components=topic_num, max_iter=50,
                                learning_method='batch', random_state=0, n_jobs=-1)
lda.fit(X_train)
```
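TfidfVectorizer is fit on the original training data to extract characteristic words, and the resulting TF-IDF matrix is the input for training. `stop_words='english'` removes common English words, `max_df=0.1` drops words that appear in more than 10% of the documents, and `min_df=5` drops words that appear in fewer than 5 documents.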
Inference
```python
# visualize: top 10 words per topic
print("")
sorting = np.argsort(lda.components_, axis=1)[:, ::-1]
mglearn.tools.print_topics(topics=range(topic_num),
                           feature_names=np.array(feature_names),
                           topics_per_chunk=topic_num,
                           sorting=sorting, n_words=10)

# predict
text11 = ("an American multinational technology company headquartered in Cupertino, "
          "California, that designs, develops, and sells consumer electronics, "
          "computer software, and online services.")
text12 = ("The company's hardware products include the iPhone smartphone, "
          "the iPad tablet computer, the Mac personal computer, "
          "the iPod portable media player, the Apple Watch smartwatch, "
          "the Apple TV digital media player, and the HomePod smart speaker.")
test1 = [text11, text12]
X_test1 = tfidf_vec.transform(test1)
lda_test1 = lda.transform(X_test1)  # topic distribution per text, shape (2, topic_num)
# loop variable is `dist`, not `lda`, so the fitted model is not overwritten
for i, dist in enumerate(lda_test1):
    print("### ", i)
    topicid = [j for j, x in enumerate(dist) if x == max(dist)]  # most likely topic(s)
    print(test1[i])
    print(dist, " >>> topic", topicid)
    print("")
```
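The inference step boils down to three moves: transform new texts with the already-fitted vectorizer (`transform`, not `fit_transform`), call `transform` on the fitted LDA model to get one topic distribution per text, and take the index of the largest value. A minimal sketch of that flow on a hypothetical toy corpus, using `CountVectorizer` for brevity instead of the article's TF-IDF features. Which topic index ends up meaning "sports" or "Mac" depends on the fit and is not fixed in advance.

```python
import numpy as np
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.decomposition import LatentDirichletAllocation

# Hypothetical toy training corpus with two rough themes (sports / Mac hardware)
train_docs = [
    "hockey team goal season player",
    "baseball team pitcher season player",
    "hockey player goal league",
    "mac apple monitor drive scsi",
    "apple mac card memory drive",
    "mac monitor scsi card apple",
]
vec = CountVectorizer()
X_train = vec.fit_transform(train_docs)

model = LatentDirichletAllocation(n_components=2, random_state=0)
model.fit(X_train)

# New texts go through the same fitted vectorizer
new_docs = ["apple mac drive", "hockey goal team"]
doc_topic = model.transform(vec.transform(new_docs))  # shape (n_docs, n_topics)

for text, dist in zip(new_docs, doc_topic):
    best = int(np.argmax(dist))  # index of the most probable topic
    print(text, dist.round(3), ">>> topic", best)
```

`np.argmax` returns a single index; the list comprehension in the article instead returns every index tied for the maximum, which only matters in the rare case of exact ties.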
Results
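Checking which words characterize each topic (top 10 words per topic). Topic 0 looks like sports (hockey and baseball), topic 1 like the X Window System, topic 2 like Mac hardware, and topic 3 is dominated by university and host names (wpi, jhu, cmu, ...) that probably come from the e-mail headers rather than from a subject.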
```
topic 0       topic 1       topic 2       topic 3
--------      --------      --------      --------
nhl           window        mac           wpi
toronto       mit           apple         nada
teams         motif         drive         kth
league        uk            monitor       hcf
player        server        quadra        jhunix
roger         windows       se            jhu
pittsburgh    program       scsi          unm
cmu           widget        card          admiral
runs          ac            simms         liu
fan           file          centris       carina
```
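`mglearn.tools.print_topics` is a helper from the companion library of the book "Introduction to Machine Learning with Python". The same top-word listing can be produced with NumPy alone by sorting each row of `lda.components_` in descending order, exactly as the `sorting` array above does. A sketch with a hypothetical `top_words` helper and a made-up 2×5 weight matrix:

```python
import numpy as np

def top_words(components, feature_names, n_words):
    """Hypothetical helper: return the n_words highest-weighted words per topic.

    components: array of shape (n_topics, n_features), e.g. lda.components_
    feature_names: array of shape (n_features,)
    """
    # argsort is ascending, so reverse each row to get descending weights
    sorting = np.argsort(components, axis=1)[:, ::-1]
    return [list(feature_names[row[:n_words]]) for row in sorting]

# Made-up topic-word weights for illustration
feature_names = np.array(["apple", "goal", "hockey", "mac", "team"])
components = np.array([
    [0.1, 3.0, 5.0, 0.2, 2.0],  # a "sports"-like topic
    [4.0, 0.1, 0.3, 6.0, 0.2],  # a "Mac"-like topic
])

for k, words in enumerate(top_words(components, feature_names, n_words=3)):
    print(f"topic {k}:", " ".join(words))
```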
Checking which topic each text is assigned to. Both Apple-related texts are assigned to topic 2, the Mac hardware topic.
```
### 0
an American multinational technology company headquartered in Cupertino, California, that designs, develops, and sells consumer electronics, computer software, and online services.
[0.06391161 0.06149079 0.81545564 0.05914196] >>> topic [2]

### 1
The company's hardware products include the iPhone smartphone, the iPad tablet computer, the Mac personal computer, the iPod portable media player, the Apple Watch smartwatch, the Apple TV digital media player, and the HomePod smart speaker.
[0.34345051 0.05899806 0.54454404 0.05300738] >>> topic [2]
```
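Each output row is the text's topic distribution, so its values sum to 1 and the assigned topic is the index of the largest value. Plugging in the two distributions printed above shows that both texts land on topic 2, but the second one is noticeably less certain (topic 0 also gets about 0.34). The "gap" between the best and second-best topic used below is only a simple heuristic of my own choosing, not something LDA itself defines.

```python
import numpy as np

# Topic distributions reported above for text 0 and text 1
doc_topic = np.array([
    [0.06391161, 0.06149079, 0.81545564, 0.05914196],
    [0.34345051, 0.05899806, 0.54454404, 0.05300738],
])

print(doc_topic.sum(axis=1))         # each row sums to ~1
print(np.argmax(doc_topic, axis=1))  # most likely topic per text

# Gap between the best and second-best topic: a rough confidence indicator
top2 = np.sort(doc_topic, axis=1)[:, -2:]
gap = top2[:, 1] - top2[:, 0]
print(gap.round(3))
```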
Other notes
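- CountVectorizer: converts text into a bag-of-words (token count) matrix
- TfidfTransformer: the processing TfidfVectorizer uses internally (TfidfVectorizer is equivalent to CountVectorizer followed by TfidfTransformer). Details still to be checked.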
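The scikit-learn documentation describes `TfidfVectorizer` as equivalent to `CountVectorizer` followed by `TfidfTransformer`. With default parameters this can be checked directly on a hypothetical toy corpus:

```python
import numpy as np
from sklearn.feature_extraction.text import (CountVectorizer, TfidfTransformer,
                                             TfidfVectorizer)

# Hypothetical toy corpus
docs = [
    "apple mac monitor",
    "mac drive scsi drive",
    "hockey team goal",
    "baseball team pitcher",
]

# Two steps: raw token counts (bag of words), then tf-idf weighting
counts = CountVectorizer().fit_transform(docs)
two_step = TfidfTransformer().fit_transform(counts)

# One step: TfidfVectorizer does both
one_step = TfidfVectorizer().fit_transform(docs)

print(np.allclose(two_step.toarray(), one_step.toarray()))
```

Splitting the steps is handy when the raw counts are needed as well, for example to feed `LatentDirichletAllocation`, which is usually fit on counts rather than tf-idf weights.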