跳到主要內容

java實作隱馬可夫模式(Hidden Markov Model)-猜拳預測

Hidden Markov Model(HMM)是機器學習領域中常用的一種模型,諸如語音辨識、手勢辨識都在此範圍內,而在了解隱馬可夫模型之前可先了解馬可夫鏈模式,所謂馬可夫鏈即是同類型的事件(不同的狀態)依序所發生的機率,一樣可以用來預測猜拳,首先我們將剪刀、石頭、布分別以數值0、1、2表示,然而玩家出過的拳為010101101,假設觀察值為3則我們可以將其分解為010、101、011、110而出現次數分別為010*2、101*3、011*1、110*1,由於觀察值為3故我們可以從序列中倒數前(3-1)個來分析下一拳可能的值,故可能為010、011、012但我們可從統計數中得知010*2、011*1、012從未出現,因此可能的值為010,此便為馬可夫鏈預測方式。

而為什麼馬可夫鏈即可預測猜拳還要使用隱馬可夫模型呢?

因為在猜拳中我們只關心已知的狀態(剪刀、石頭、布),故此時我們若不考慮狀態,單純從觀察值去推估出隱藏的狀態或許能更精確的預測出下一拳可能值。

實作說明:


1.      序列:[0120120120]
2.      觀察值:2
3.      狀態矩陣
(1)   由於觀察值為2,代表狀態從序列中第二個位置之後開始計算。
(2)   [0120120120]
狀態
次數
0
3
1
2
2
3
4.      狀態轉移矩陣
(1)    由於觀察值為2,代表狀態從序列中第二個位置之後開始計算。
(2)    [0120120120]
狀態
次數
00
0
01
2
02
0
10
0
11
0
12
2
20
3
21
0
22
0
5.      觀察值矩陣
(1)   從頭開始計算,由於windowsize2則舉例012代表01201為觀察值、2為狀態。
狀態
次數
000
0
010
0
020
0
100
0
110
0
120
3
200
0
210
0
220
0

狀態
次數
001
0
011
0
021
0
101
0
111
0
121
0
201
2
211
0
221
0

狀態
次數
002
0
012
3
022
0
102
0
112
0
122
0
202
0
212
0
222
0
6.      計算機率
(1)0
PiPi_total=3+2+3=83/Pi_total=0.375
Aij000102次數Aij_total2,故
             0           

(0/Aij_total=0) (2//Aij_total=1) (0//Aij_total=0)

Opdf: 000220Opdf_total=3
 (0)   (0)  (0)   (0)    (0)    (1)   (0)   (0)    (0)
000   010 020 100  110  120  200  210  220


(2)1
Pi: 0.25
Aij: 0 0 1
Opdf: Integer distribution --- 0 0 0 0 0 0 1 0 0
(3)2
Pi: 0.375
Aij: 1 0 0
Opdf: Integer distribution --- 0 1 0 0 0 0 0 0 0
7.      取最可能的下一狀態:從序列中[0120120120]可得windowsize2時,應該取的觀察值[0120120120]20,故從各矩陣機率來比較分別為
200=>0
201=>1(此為最可能狀態)
202=>0

程式實作:

一、首先到:https://code.google.com/p/jahmm/downloads/list抓取HMM原始碼套件,如此省去實作整個HMM模式的設計,僅須將資料丟入運算即可。
二、程式碼

import jahmmSourceCode.Hmm;
import jahmmSourceCode.ObservationInteger;
import jahmmSourceCode.OpdfInteger;
import jahmmSourceCode.OpdfIntegerFactory;
import java.io.*;
import java.util.*;
public class HMM_Model implements Serializable{
    private HashMap hsData = new HashMap();
    private StringBuilder sbPerm = new StringBuilder();
    private Hmm hmm;
    private final double eps = 1.0E-9;
    private final int states = 3;
    private String sequence;//存觀察值
    private int windowsize;
    private static final long serialVersionUID = 2L; 
    public HMM_Model(String sequence, int windowsize) {//建構
        this.windowsize = windowsize;
        this.sequence = sequence.substring(sequence.length()-windowsize);//全域sequence設成觀察植
        InitHash(sequence);//初始Hash
        getKeys();//排列組合
        BuildHmm();//建立模型
    }
    public int getStrategy(){//next Strategy
         int predict = getLikelyState(sequence);
         if (predict == 0) return 1;
         else if (predict == 1) return 2;
         else return 0;
    }
    public void addState(String State){
        sequence += State;
        HashPush(State);//State
        HashPush(sequence.substring(sequence.length() - 2));//TransferMatrix
        HashPush(sequence);//Observations
        BuildHmm();
        sequence = sequence.substring(sequence.length()-windowsize);//取最後的觀察值
    }
    private void BuildHmm(){
        Hmm hmm = new Hmm(3, new OpdfIntegerFactory(3));
        String sKeys[] = sbPerm.toString().split("\n");
        double d;
        int NumOfPerm = sKeys.length;
        for (int i = 0; i < states; i++) {
            d = eps;
            if (hsData.get(i + "") != null) d =  hsData.get(i + "");//Pi
            hmm.setPi(i, d);
            for (int j = 0; j < states; j++) {//狀態轉移矩陣
                d = eps;
                if (hsData.get(i+""+j) != null) d = hsData.get(i + "" + j);
                hmm.setAij(i, j, d);
            }
            double dPerm[] = new double[NumOfPerm];
            d = 0;
            for (int j = 0; j < sKeys.length; j++) {//觀察值矩陣
                 if (hsData.get(sKeys[j] + i) != null) dPerm[j] = hsData.get(sKeys[j] + i);
                 else dPerm[j] = eps;
                 d += dPerm[j];
            }
            for (int j = 0; j < dPerm.length; j++) {
                dPerm[j]/=d;
            }
            hmm.setOpdf(i, new OpdfInteger(dPerm));
        }
        normalize(hmm);
        this.hmm = hmm;
    }
    private void normalize(Hmm hmm) {
        double sum = 0;
        for (int i = 0; i < states; i++) {
            sum += hmm.getPi(i);
        }
        for (int i = 0; i < states; i++) {
            hmm.setPi(i, hmm.getPi(i) / sum);
        }
        for (int i = 0; i < states; i++) {
            sum = 0;
            for (int j = 0; j < states; j++) {
                sum += hmm.getAij(i, j);
            }
            for (int j = 0; j < states; j++) {
                hmm.setAij(i, j, hmm.getAij(i, j) / sum);
            }
        }
    }  
    private int getLikelyState(String sObservations){//predict
           String sKeys[] = sbPerm.toString().split("\n");    
           Arrays.sort(sKeys);
//           for (int i = 0; i < sKeys.length; i++) {
//               System.out.println(sKeys[i] + " ");
//            
//             }
           List test_seq = new ArrayList();
           
           ObservationInteger obs = new ObservationInteger(Arrays.binarySearch(sKeys, sObservations));
           test_seq.add(obs);
           int[] n = hmm.mostLikelyStateSequence(test_seq);
           return n[n.length-1];
    }
    
    private void InitHash(String sequence){
        for (int i = windowsize; i < sequence.length(); i++) {
            HashPush(sequence.substring(i,i+1));//加入狀態
            if (i < sequence.length() - 1)HashPush(sequence.substring(i, i+2));//加入狀態轉移矩陣
            HashPush(sequence.substring(i - windowsize,i + 1));//加入觀察矩陣
        }
    }
    private void HashPush(String s){
        if (hsData.containsKey(s)) hsData.put(s, hsData.get(s)+1);   
        else hsData.put(s, new Integer(1));
    }
    private void getKeys(){
        int a[] = new int[windowsize];
        perm(a,windowsize,0);
    }
    private void perm(int a[], int k, int t) {
        if (t == k) {
            String s = "";
            for (int i = 0; i < k; i++) {
                s += a[i];
            }
            sbPerm.append(s).append("\n");
        } else {
            for (int i = 0; i < states; i++) {
                a[t] = i;
                perm(a, k, t + 1);
            }
        }
    }
    public String getText(){return hmm.toString() + "\nSequence = " + sequence;}
    public static void main(String[] args) {
       String sequences = "0120120120";
        int windowsize = 2;
        HMM_Model rp = new HMM_Model(sequences,windowsize);
        System.out.println("sequence = " + sequences);
        System.out.println(rp.getText());
        System.out.println("predict = " + rp.getLikelyState(sequences.substring(sequences.length()-windowsize)));//預測值
        System.out.println("------------------------------------------------------------------------");
        sequences += "2";
        rp.addState("2");
        System.out.println("sequence = " + sequences);
        System.out.println(rp.getText());
        System.out.println("predict = " +rp.getLikelyState(sequences.substring(sequences.length()-windowsize)));
    }
}

三、預測結果

留言

這個網誌中的熱門文章

java西元民國轉換_各種不同格式

C#資料庫操作(新增、修改、刪除、查詢)