diff --git a/prob.py b/prob.py index 4f35c701fd9c57a976e834e398e80a514061a7ae..cf4a36057bf48c11d6787512d3edd0cdc216b80e 100644 --- a/prob.py +++ b/prob.py @@ -1,5 +1,6 @@ ''' Module for classes and functions that are representing and processing basic probabilities. +Also includes Markov chain and hidden Markov model. Uses and depends on "Alphabet" that is used to define discrete random variables. ''' import random @@ -675,6 +676,10 @@ class IndepJoint(Joint): return sorted(ret, key=lambda v: v, reverse=True) return ret +################################################################################################# +# Naive Bayes' classifier +################################################################################################# + class NaiveBayes(): """ NaiveBayes implements a classifier: a model defined over a class variable and conditional on a list of discrete feature variables. @@ -719,26 +724,29 @@ class NaiveBayes(): out.observe(outsym, prob) return out +################################################################################################# +# Markov chain +################################################################################################# + class MarkovChain(): - """ Markov Chain in a very simple form + """ Markov Chain in a simple form; supports higher-orders and can determine (joint) probability of sequence. """ - def __init__(self, alpha, order = 1): + def __init__(self, alpha, order = 1, startsym = '^', endsym = '\$'): + """ Construct a new Markov chain based on a given alphabet of characters. + alpha: alphabet of allowed characters and states + order: the number of states to include in memory (default is 1) + startsym: the symbol to mark the first character in the internal sequence, and the first state + endsym: the symbol to mark the termination of the internal sequence (and the last state) + """ self.order = order - self.startsym = '^' - self.endsym = '\$' - self.alpha = Alphabet(alpha.symbols + tuple([self.startsym, self.endsym])) + self.startsym = startsym + self.endsym = endsym + self.alpha = getTerminatedAlphabet(alpha, self.startsym, self.endsym) self.transit = TupleStore([self.alpha for _ in range(order)]) # transition probs, i.e. given key (prev state/s) what is the prob of current state - def _terminate(self, unterm_seq): - """ Terminate sequence with start and end symbols """ - term_seq = [self.startsym for _ in range(self.order)] - term_seq.extend(unterm_seq) - term_seq.append(self.endsym) - return term_seq - def _getpairs(self, term_seq): - """ Return a tuple of all (tuple) Markov pairs from a sequence """ + """ Return a tuple of all (tuple) Markov pairs from a sequence. Used internally. """ ret = [] for i in range(len(term_seq) - self.order): past = tuple(term_seq[i:i + self.order]) @@ -747,8 +755,10 @@ class MarkovChain(): return tuple(ret) def observe(self, wholeseq): - """ Set parameters by counting transitions """ - myseq = self._terminate(wholeseq) + """ Set parameters of Markov chain by counting transitions, as observed in the sequence. + wholeseq: the sequence not including the termination symbols. + """ + myseq = _terminate(wholeseq, self.order, self.startsym, self.endsym) for (past, present) in self._getpairs(myseq): d = self.transit[past] if not d: # no distrib @@ -757,8 +767,11 @@ class MarkovChain(): d.observe(present) def __getitem__(self, wholeseq): - """ Determine the log probability of a given sequence """ - myseq = self._terminate(wholeseq) + """ Determine the log probability of a given sequence. + wholeseq: the sequence not including the termination symbols. + returns the joint probability + """ + myseq = _terminate(wholeseq, self.order, self.startsym, self.endsym) logp = 0 for (past, present) in self._getpairs(myseq): d = self.transit[past] @@ -768,4 +781,173 @@ class MarkovChain(): if p == 0: return None logp += math.log(p) - return logp \ No newline at end of file + return logp + + +def _terminate(unterm_seq, order = 1, startsym = '^', endsym = '\$'): + """ Terminate sequence with start and end symbols """ + term_seq = [startsym for _ in range(order)] + term_seq.extend(unterm_seq) + term_seq.append(endsym) + return term_seq + +def getTerminatedAlphabet(alpha, startsym = '^', endsym = '\$'): + """ Amend the given alphabet with termination symbols """ + return Alphabet(alpha.symbols + tuple([startsym, endsym])) + +################################################################################################# +# Hidden Markov model (HMM) +################################################################################################# + +class HMM(): + """ Basic, first-order HMM. + Has functionality to set up HMM, and query it with Viterbi and Forward algorithms.""" + + def __init__(self, states, symbols, startstate = '^', endstate = '\$'): + """ Construct HMM with states and symbols, here given as strings of characters. + > cpg_hmm = prob.HMM('HL','ACGT') + """ + if isinstance(states, str): + states = Alphabet(states) + self.mystates = getTerminatedAlphabet(states, startstate, endstate) + if isinstance(symbols, str): + symbols = Alphabet(symbols) + self.mysymbols = getTerminatedAlphabet(symbols, startstate, endstate) + self.a = dict() # transition probabilities + self.e = dict() # emission probabilities + self.startsym = startstate + self.endsym = endstate + + def transition(self, fromstate, distrib): + """ Add a transition to the HMM, determining with the probability of transitions, e.g. + > cpg_hmm.transition('^',{'^':0,'\$':0,'H':0.5,'L':0.5}) + > cpg_hmm.transition('H',{'^':0,'\$':0.001,'H':0.5,'L':0.5}) + > cpg_hmm.transition('L',{'^':0,'\$':0.001,'H':0.4,'L':0.6}) + > cpg_hmm.transition('\$',{'^':1,'\$':0,'H':0,'L':0}) + """ + if not isinstance(distrib, Distrib): + distrib = Distrib(self.mystates, distrib) + self.a[fromstate] = distrib + + def emission(self, state, distrib): + """ Add an emission probability to the HMM, e.g. + > cpg_hmm.emission('^',{'^':1,'\$':0,'A':0,'C':0,'G':0,'T':0}) + > cpg_hmm.emission('H',{'^':0,'\$':0,'A':0.2,'C':0.3,'G':0.3,'T':0.2}) + > cpg_hmm.emission('L',{'^':0,'\$':0,'A':0.3,'C':0.2,'G':0.2,'T':0.3}) + > cpg_hmm.emission('\$',{'^':0,'\$':1,'A':0,'C':0,'G':0,'T':0}) + """ + if not isinstance(distrib, Distrib): + distrib = Distrib(self.mysymbols, distrib) + self.e[state] = distrib + + def joint(self, symseq, stateseq): + """ + Determine the joint probability of the sequence and the given path. + :param symseq: sequence of characters + :param stateseq: sequence of states + :return: the probability + """ + X = _terminate(symseq, 1, self.startsym, self.endsym) + P = _terminate(stateseq, 1, self.startsym, self.endsym) + p = 1 + for i in range(len(X) - 1): + p = p * self.e[P[i]][X[i]] * self.a[P[i]][P[i + 1]] + return p + + def viterbi(self, symseq, V = dict(), trace = dict()): + """ + Determine the Viterbi path (the most probable sequence of states) given a sequence of symbols + :param symseq: sequence of symbols + :param V: the Viterbi dynamic programming variable as a matrix (optional; pass an empty dict if you need it) + :param trace: the traceback (optional; pass an empty dict if you need it) + :return: the Viterbi path as a string of characters + > X = 'GGCACTGAA' # sequence of characters + > states = cpg_hmm.viterbi(X) + > print(states) + """ + X = _terminate(symseq, 1, self.startsym, self.endsym) + # Initialise state scores for each index in X + for state in self.mystates: + # Fill in emission probabilities for each index in X + V[state] = [self.e[state][x] for x in X] + trace[state] = [] + for j in range(len(X) - 1): + i = j + 1 # sequence index that we're processing + for tostate in self.mystates: + tracemax = 0 + beststate = None + for fromstate in self.mystates: + score = V[fromstate][i - 1] * self.a[fromstate][tostate] + if score > tracemax: + beststate = fromstate + tracemax = score + trace[tostate].append(beststate) + V[tostate][i] = self.e[tostate][X[i]] * tracemax + ret = '' + traced = '\$' + for j in range(len(X)): + i = len(X) - 2 - j + traced = trace[traced][i] + if j > 0 and j < len(X) - 2: + ret = traced + ret + return ret + + def forward(self, symseq, F = dict()): + """ + Determine the probability of the sequence, summing over all possible state paths + :param symseq: sequence of symbols + :param F: the Forward dynamic programming variable as a matrix (optional; pass an empty dict if you need it) + :return: the probability + > X = 'GGCACTGAA' # sequence of characters + > prob = cpg_hmm.forward(X) + > print(prob) + """ + X = _terminate(symseq, 1, self.startsym, self.endsym) + # Initialise state scores for each index in X + for state in self.mystates: + # Fill in emission probabilities for each index in X + F[state] = [self.e[state][x] for x in X] + for j in range(len(X) - 1): + i = j + 1 # sequence index that we're processing + for tostate in self.mystates: + mysum = 0 + for fromstate in self.mystates: + mysum += F[fromstate][i - 1] * self.a[fromstate][tostate] + F[tostate][i] = self.e[tostate][X[i]] * mysum + traced = '\$' + return F[traced][len(X) - 1] + + def writeHTML(self, X, Viterbi, Trace = None, filename = None): + """ Generate HTML that displays a DP matrix from Viterbi (or Forward) algorithms. + > from IPython.core.display import HTML + > X = 'GGCACTGAA' # sequence of characters + > V = dict() + > T = dict() + > cpg_hmm.viterbi(X, V, T) + > HTML(cpg_hmm.writeHTML(X, V, T)) + """ + html = '''\nHMM dynamic programming matrix\n
```\n'''
+        html += '\n'
+        html += '\n'
+        html += '\n'
+        for state in Viterbi:
+            html += '\n' % str(state)
+        html += '\n'
+        # process each sequence symbol
+        X = _terminate(X, 1, self.startsym, self.endsym)
+        for row in range(len(X)):
+            html += '\n'
+            html += '\n' % (row, str(X[row]))
+            for state in Viterbi:
+                if Trace and row > 0:
+                    html += '\n' % (str(state),X[row],self.e[state][X[row]],str(state),Viterbi[state][row],Trace[state][row - 1],self.a[Trace[state][row - 1]][state] if Trace[state][row - 1] != None else 0)
+                else:
+                    html += '\n' % (str(state),X[row],self.e[state][X[row]],str(state),Viterbi[state][row])
+            html += '\n'
+        html += 'X%sx%d=%se%s(%s)=%4.2fV%s=%3.1e↑%s[%4.2f]e%s(%s)=%4.2fV%s=%3.1e\n'
+        html += '```
' + if filename: + fh = open(filename, 'w') + fh.write(html) + fh.close() + return html