热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

朴素贝叶斯文本分类java实现

packagecom.data.ml.classify;importjava.io.File;importjava.util.ArrayList;importjava.util.C

package com.data.ml.classify;

import java.io.File;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import com.data.util.IoUtil;

public class NativeBayes {
    /**
     * 默认频率
     */
    private double defaultFreq = 0.1;
    
    /**
     * 训练数据的比例
     */
    private Double trainingPercent = 0.8;

    private Map> files_all = new HashMap>();

    private Map> files_train = new HashMap>();

    private Map> files_test = new HashMap>();

    public NativeBayes() {

    }

    /**
     * 每个分类的频率
     */
    private Map classFreq = new HashMap();
    
    private Map ClassProb = new HashMap();
    
    /**
     * 特征总数
     */
    private Set WordDict = new HashSet();
    
    private Map> classFeaFreq = new HashMap>();
    
    private Map> ClassFeaProb = new HashMap>();
    
    private Map ClassDefaultProb = new HashMap();
    
    /**
     * 计算准确率
     * @param reallist 真实类别
     * @param pridlist 预测类别
     */
    public void Evaluate(List reallist, List pridlist){
        double correctNum = 0.0;
        for (int i = 0; i ) {
            if(reallist.get(i) == pridlist.get(i)){
                correctNum += 1;
            }
        }
        double accuracy = correctNum / reallist.size();
        System.out.println("准确率为:" + accuracy);
    }
    
    /**
     * 计算精确率和召回率
     * @param reallist
     * @param pridlist
     * @param classname
     */
    public void CalPreRec(List reallist, List pridlist, String classname){
        double correctNum = 0.0;
        double allNum = 0.0;//测试数据中,某个分类的文章总数
        double preNum = 0.0;//测试数据中,预测为该分类的文章总数
        
        for (int i = 0; i ) {
            if(reallist.get(i) == classname){
                allNum += 1;
                if(reallist.get(i) == pridlist.get(i)){
                    correctNum += 1;
                }
            }
            if(pridlist.get(i) == classname){
                preNum += 1;
            }
        }
        System.out.println(classname + " 精确率(跟预测分类比较):" + correctNum / preNum + " 召回率(跟真实分类比较):" + correctNum / allNum);
    }
    
    /**
     * 用模型进行预测
     */
    public void PredictTestData() {
        List reallist=new ArrayList();
        List pridlist=new ArrayList();
        
        for (Entry> entry : files_test.entrySet()) {
            String realclassname = entry.getKey();
            List files = entry.getValue();

            
            for (String file : files) {
                reallist.add(realclassname);
                
                
                List classnamelist=new ArrayList();
                List scorelist=new ArrayList();
                for (Entry entry_1 : ClassProb.entrySet()) {
                    String classname = entry_1.getKey();
                    //先验概率
                    Double score = Math.log(entry_1.getValue());
                    
                    String[] words = IoUtil.readFromFile(new File(file)).split(" ");
                    for (String word : words) {
                        if(!WordDict.contains(word)){
                            continue;
                        }
                        
                        if(ClassFeaProb.get(classname).containsKey(word)){
                            score += Math.log(ClassFeaProb.get(classname).get(word));
                        }else{
                            score += Math.log(ClassDefaultProb.get(classname));
                        }
                    }
                    
                    classnamelist.add(classname);
                    scorelist.add(score);
                }
                
                Double maxProb = Collections.max(scorelist);
                int idx = scorelist.indexOf(maxProb);
                pridlist.add(classnamelist.get(idx));
            }
        }
        
        Evaluate(reallist, pridlist);
        
        for (String cname : files_test.keySet()) {
            CalPreRec(reallist, pridlist, cname);
        }
        
    }
    
    /**
     * 模型训练
     */
    public void createModel() {
        double sum = 0.0;
        for (Entry entry : classFreq.entrySet()) {
            sum+=entry.getValue();
        }
        for (Entry entry : classFreq.entrySet()) {
            ClassProb.put(entry.getKey(), entry.getValue()/sum);
        }
        
        
        for (Entry> entry : classFeaFreq.entrySet()) {
            sum = 0.0;
            String classname = entry.getKey();
            for (Entry entry_1 : entry.getValue().entrySet()){
                sum += entry_1.getValue();
            }
            double newsum = sum + WordDict.size()*defaultFreq;
            
            Map feaProb = new HashMap();
            ClassFeaProb.put(classname, feaProb);
            
            for (Entry entry_1 : entry.getValue().entrySet()){
                String word = entry_1.getKey();
                feaProb.put(word, (entry_1.getValue() +defaultFreq) /newsum);
            }
            ClassDefaultProb.put(classname, defaultFreq/newsum);
        }
    }
    
    /**
     * 加载训练数据
     */
    public void loadTrainData(){
        for (Entry> entry : files_train.entrySet()) {
            String classname = entry.getKey();
            List docs = entry.getValue();
            
            classFreq.put(classname, docs.size());
            
            Map feaFreq = new HashMap();
            classFeaFreq.put(classname, feaFreq);
            
            for (String doc : docs) {
                String[] words = IoUtil.readFromFile(new File(doc)).split(" ");
                for (String word : words) {
                    
                    WordDict.add(word);
                    
                    if(feaFreq.containsKey(word)){
                        int num = feaFreq.get(word) + 1;
                        feaFreq.put(word, num);
                    }else{
                        feaFreq.put(word, 1);
                    }
                }
            }    
            
            
        }
        System.out.println(classFreq.size()+" 分类, " + WordDict.size()+" 特征词");
    }
    
    /**
     * 将数据分为训练数据和测试数据
     * 
     * @param dataDir
     */
    public void splitData(String dataDir) {
        // 用文件名区分类别
        Pattern pat = Pattern.compile("\\d+([a-z]+?)\\.");
        dataDir = "testdata/allfiles";
        File f = new File(dataDir);
        File[] files = f.listFiles();
        for (File file : files) {
            String fname = file.getName();
            Matcher m = pat.matcher(fname);
            if (m.find()) {
                String cname = m.group(1);
                if (files_all.containsKey(cname)) {
                    files_all.get(cname).add(file.toString());
                } else {
                    List tmp = new ArrayList();
                    tmp.add(file.toString());
                    files_all.put(cname, tmp);
                }
            } else {
                System.out.println("err: " + file);
            }
        }

        System.out.println("统计数据:");
        for (Entry> entry : files_all.entrySet()) {
            String cname = entry.getKey();
            List value = entry.getValue();
            // System.out.println(cname + " : " + value.size());

            List train = new ArrayList();
            List test = new ArrayList();

            for (String str : value) {
                if (Math.random() <= trainingPercent) {// 80%用来训练 , 20%测试
                    train.add(str);
                } else {
                    test.add(str);
                }
            }

            files_train.put(cname, train);
            files_test.put(cname, test);
        }

        System.out.println("所有文件数:");
        printStatistics(files_all);
        System.out.println("训练文件数:");
        printStatistics(files_train);
        System.out.println("测试文件数:");
        printStatistics(files_test);

    }

    /**
     * 打印统计信息
     * 
     * @param m
     */
    public void printStatistics(Map> m) {
        for (Entry> entry : m.entrySet()) {
            String cname = entry.getKey();
            List value = entry.getValue();
            System.out.println(cname + " : " + value.size());
        }
        System.out.println("--------------------------------");
    }

    public static void main(String[] args) {
        NativeBayes bayes = new NativeBayes();
        bayes.splitData(null);
        bayes.loadTrainData();
        bayes.createModel();
        bayes.PredictTestData();

    }

}

所有文件数:
sports : 1018
auto : 1020
business : 1028
--------------------------------
训练文件数:
sports : 791
auto : 812
business : 808
--------------------------------
测试文件数:
sports : 227
auto : 208
business : 220
--------------------------------
3 分类, 39613 特征词
准确率为:0.9801526717557252
sports 精确率(跟预测分类比较):0.9956140350877193 召回率(跟真实分类比较):1.0
auto 精确率(跟预测分类比较):0.9579439252336449 召回率(跟真实分类比较):0.9855769230769231
business 精确率(跟预测分类比较):0.9859154929577465 召回率(跟真实分类比较):0.9545454545454546

统计数据:
所有文件数:
sports : 1018
auto : 1020
business : 1028
--------------------------------
训练文件数:
sports : 827
auto : 833
business : 825
--------------------------------
测试文件数:
sports : 191
auto : 187
business : 203
--------------------------------
3 分类, 39907 特征词
准确率为:0.9759036144578314
sports 精确率(跟预测分类比较):0.9894736842105263 召回率(跟真实分类比较):0.9842931937172775
auto 精确率(跟预测分类比较):0.9836956521739131 召回率(跟真实分类比较):0.9679144385026738
business 精确率(跟预测分类比较):0.9565217391304348 召回率(跟真实分类比较):0.9753694581280788


 


推荐阅读
  • Android源码深入理解JNI技术的概述和应用
    本文介绍了Android源码中的JNI技术,包括概述和应用。JNI是Java Native Interface的缩写,是一种技术,可以实现Java程序调用Native语言写的函数,以及Native程序调用Java层的函数。在Android平台上,JNI充当了连接Java世界和Native世界的桥梁。本文通过分析Android源码中的相关文件和位置,深入探讨了JNI技术在Android开发中的重要性和应用场景。 ... [详细]
  • 本文介绍了在满足特定条件时如何在输入字段中使用默认值的方法和相应的代码。当输入字段填充100或更多的金额时,使用50作为默认值;当输入字段填充有-20或更多(负数)时,使用-10作为默认值。文章还提供了相关的JavaScript和Jquery代码,用于动态地根据条件使用默认值。 ... [详细]
  • Mac OS 升级到11.2.2 Eclipse打不开了,报错Failed to create the Java Virtual Machine
    本文介绍了在Mac OS升级到11.2.2版本后,使用Eclipse打开时出现报错Failed to create the Java Virtual Machine的问题,并提供了解决方法。 ... [详细]
  • 本文介绍了如何在给定的有序字符序列中插入新字符,并保持序列的有序性。通过示例代码演示了插入过程,以及插入后的字符序列。 ... [详细]
  • eclipse学习(第三章:ssh中的Hibernate)——11.Hibernate的缓存(2级缓存,get和load)
    本文介绍了eclipse学习中的第三章内容,主要讲解了ssh中的Hibernate的缓存,包括2级缓存和get方法、load方法的区别。文章还涉及了项目实践和相关知识点的讲解。 ... [详细]
  • 本文讨论了一个关于cuowu类的问题,作者在使用cuowu类时遇到了错误提示和使用AdjustmentListener的问题。文章提供了16个解决方案,并给出了两个可能导致错误的原因。 ... [详细]
  • Go Cobra命令行工具入门教程
    本文介绍了Go语言实现的命令行工具Cobra的基本概念、安装方法和入门实践。Cobra被广泛应用于各种项目中,如Kubernetes、Hugo和Github CLI等。通过使用Cobra,我们可以快速创建命令行工具,适用于写测试脚本和各种服务的Admin CLI。文章还通过一个简单的demo演示了Cobra的使用方法。 ... [详细]
  • 本文介绍了iOS数据库Sqlite的SQL语句分类和常见约束关键字。SQL语句分为DDL、DML和DQL三种类型,其中DDL语句用于定义、删除和修改数据表,关键字包括create、drop和alter。常见约束关键字包括if not exists、if exists、primary key、autoincrement、not null和default。此外,还介绍了常见的数据库数据类型,包括integer、text和real。 ... [详细]
  • 本文讨论了在openwrt-17.01版本中,mt7628设备上初始化启动时eth0的mac地址总是随机生成的问题。每次随机生成的eth0的mac地址都会写到/sys/class/net/eth0/address目录下,而openwrt-17.01原版的SDK会根据随机生成的eth0的mac地址再生成eth0.1、eth0.2等,生成后的mac地址会保存在/etc/config/network下。 ... [详细]
  • (三)多表代码生成的实现方法
    本文介绍了一种实现多表代码生成的方法,使用了java代码和org.jeecg框架中的相关类和接口。通过设置主表配置,可以生成父子表的数据模型。 ... [详细]
  • 如何用JNI技术调用Java接口以及提高Java性能的详解
    本文介绍了如何使用JNI技术调用Java接口,并详细解析了如何通过JNI技术提高Java的性能。同时还讨论了JNI调用Java的private方法、Java开发中使用JNI技术的情况以及使用Java的JNI技术调用C++时的运行效率问题。文章还介绍了JNIEnv类型的使用方法,包括创建Java对象、调用Java对象的方法、获取Java对象的属性等操作。 ... [详细]
  • 使用eclipse创建一个Java项目的步骤
    本文介绍了使用eclipse创建一个Java项目的步骤,包括启动eclipse、选择New Project命令、在对话框中输入项目名称等。同时还介绍了Java Settings对话框中的一些选项,以及如何修改Java程序的输出目录。 ... [详细]
  • 本文概述了JNI的原理以及常用方法。JNI提供了一种Java字节码调用C/C++的解决方案,但引用类型不能直接在Native层使用,需要进行类型转化。多维数组(包括二维数组)都是引用类型,需要使用jobjectArray类型来存取其值。此外,由于Java支持函数重载,根据函数名无法找到对应的JNI函数,因此介绍了JNI函数签名信息的解决方案。 ... [详细]
  • Mono为何能跨平台
    概念JIT编译(JITcompilation),运行时需要代码时,将Microsoft中间语言(MSIL)转换为机器码的编译。CLR(CommonLa ... [详细]
  • 引号快捷键_首选项和设置——自定义快捷键
    3.3自定义快捷键(CustomizingHotkeys)ChemDraw快捷键由一个XML文件定义,我们可以根据自己的需要, ... [详细]
author-avatar
靠小号3
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有