/*
Copyright 2009-2012 Andreas Biegert, Christof Angermueller
This file is part of the CS-BLAST package.
The CS-BLAST package is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
The CS-BLAST package is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef CS_CRF_STATE_H_
#define CS_CRF_STATE_H_
#include "profile_column.h"
#include "profile-inl.h"
#include "context_profile-inl.h"
#include "substitution_matrix.h"
namespace cs {
template
struct CrfState {
// Default construction
CrfState() : bias_weight(0.0) {};
// Constructs a CRF state with 'len' columns
explicit CrfState(size_t len)
: bias_weight(0.0), context_weights(len) {
assert(len & 1);
}
// Construction from serialized profile read from input stream.
explicit CrfState(FILE* fin) { Read(fin); }
// Constructs a CRF from a profile of probabilities.
CrfState(double prior, Profile prof, ProfileColumn col,
double weight_center = 1.0, double weight_decay = 1.0) {
Init(prior, prof, col, weight_center, weight_decay);
}
// Constructs a CRF from a ContextProfile.
CrfState(ContextProfile p, double weight_center = 1.6, double weight_decay = 0.85) {
if (p.is_log) TransformToLin(p);
Init(p.prior, p.probs, ProfileColumn(p.probs[(p.length() - 1) / 2]),
weight_center, weight_decay);
}
// Initializes the CRF with context weights based on the values in profile 'prof'
// and pseudocount weights based on values in profile column 'col' with
// prior probability 'prior'. We assume that all arguments are in lin space. The
// column weights are defined by wcenter and wdecay.
void Init(double prior, Profile prof, ProfileColumn col,
double weight_center = 1.0, double weight_decay = 1.0) {
assert(prof.length() & 1);
context_weights = Profile(prof.length());
Normalize(prof, 1.0);
Normalize(col, 1.0);
bias_weight = log(MAX(prior, DBL_MIN));
const size_t c = (length() - 1) / 2;
double weights[length()];
for (size_t i = 1; i <= c; ++i) {
double weight = weight_center * pow(weight_decay, i);
weights[c - i] = weight;
weights[c + i] = weight;
}
weights[c] = weight_center;
for (size_t j = 0; j < length(); ++j) {
for (size_t a = 0; a < Abc::kSize; ++a)
context_weights[j][a] = weights[j] * log(MAX(prof[j][a], DBL_MIN));
context_weights[j][Abc::kAny] = 0.0;
}
for (size_t a = 0; a < Abc::kSize; ++a)
pc_weights[a] = log(MAX(col[a], DBL_MIN));
pc_weights[Abc::kAny] = 0.0;
UpdatePseudocounts(*this);
}
// Initializes count profile with a serialized profile read from stream.
void Read(FILE* fin);
// Initializes count profile with a serialized profile read from stream.
void Write(FILE* fin) const;
// Returns number of context weights columns.
inline size_t length() const { return context_weights.length(); }
// Compares two CRF states.
bool operator< (const CrfState& other) const {
return bias_weight < other.bias_weight;
}
std::string name; // name of this state
double bias_weight; // bias weight lamda_k of this state
Profile context_weights; // context weights lamda_k(j,a)
ProfileColumn pc_weights; // unnormalized logs of pseudocounts
ProfileColumn pc; // predicted pseudocounts at central column
};
// Prints CRF state weights in human-readable format for debugging.
template
std::ostream& operator<< (std::ostream& out, const CrfState& crf);
// Updates pseudocount emission probs in given CRF state based on pc_weights.
template
void UpdatePseudocounts(CrfState& state);
// Calculates context score between a CRF state and a sequence window
template
double ContextScore(const Profile& context_weights,
const Sequence& seq,
size_t idx,
size_t center);
// Calculates context score between a CRF state and a count profile window
template
double ContextScore(const Profile& context_weights,
const CountProfile& cp,
size_t idx,
size_t center);
} // namespace cs
#endif // CS_CRF_STATE_H_