/* Copyright 2009-2012 Andreas Biegert, Christof Angermueller This file is part of the CS-BLAST package. The CS-BLAST package is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version. The CS-BLAST package is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details. You should have received a copy of the GNU General Public License along with this program. If not, see . */ #ifndef CS_CONTEXT_WEIGHT_STATE_INL_H_ #define CS_CONTEXT_WEIGHT_STATE_INL_H_ #include "crf_state.h" namespace cs { template void CrfState::Read(FILE* fin) { // Parse and check header information if (!StreamStartsWith(fin, "CrfState")) throw Exception("Stream does not start with class id 'CrfState'!"); char buffer[KB]; cs::fgetline(buffer, KB, fin); if (strstr(buffer, "NAME")) { name = ReadString(buffer, "NAME", "Unable to parse CRF state 'NAME'!"); cs::fgetline(buffer, KB, fin); } bias_weight = ReadDouble(buffer, "BIAS", "Unable to parse CRF state 'BIAS'!"); cs::fgetline(buffer, KB, fin); size_t len = ReadInt(buffer, "LENG", "Unable to parse CRF state 'LENG'!"); cs::fgetline(buffer, KB, fin); size_t nalph = ReadInt(buffer, "ALPH", "Unable to parse CRF state 'ALPH'!"); assert(len & 1); if (nalph != Abc::kSize) throw Exception("Alphabet size of serialized CRF state should be %d" "but is actually %d!", Abc::kSize, nalph); // If everything went fine we can resize our data memmbers context_weights.Resize(len); // Read context weights and pseudocount weights const char* ptr = buffer; size_t i = 0; cs::fgetline(buffer, KB, fin); // skip alphabet description line while (cs::fgetline(buffer, KB, fin) && buffer[0] != '/' && buffer[1] != '/') { ptr = buffer; if (buffer[0] != 'P' && buffer[1] != 'C') { i = strtoi(ptr) - 1; assert(i < len); // TODO: include ANY char in serialization for (size_t a = 0; a < Abc::kSize; ++a) context_weights[i][a] = static_cast(strastoi(ptr)) / kScale; context_weights[i][Abc::kAny] = 0.0; } else { for (size_t a = 0; a < Abc::kSize; ++a) pc_weights[a] = static_cast(strastoi(ptr)) / kScale; } context_weights[i][Abc::kAny] = 0.0; pc_weights[Abc::kAny] = 0.0; } if (i != len - 1) throw Exception("CRF state should have %i columns but actually has %i!", len, i+1); UpdatePseudocounts(*this); } template void CrfState::Write(FILE* fout) const { // Print header section fputs("CrfState\n", fout); if (!name.empty()) fprintf(fout, "NAME\t%s\n", name.c_str()); fprintf(fout, "BIAS\t%-10.8g\n", bias_weight); fprintf(fout, "LENG\t%d\n", static_cast(context_weights.length())); fprintf(fout, "ALPH\t%d\n", static_cast(Abc::kSize)); // Print alphabet description line fputs("WEIGHTS", fout); for (size_t a = 0; a < Abc::kSize; ++a) fprintf(fout, "\t%c", Abc::kIntToChar[a]); fputs("\n", fout); // Print context weights scaled by 'kScale' for (size_t i = 0; i < context_weights.length(); ++i) { fprintf(fout, "%i", static_cast(i+1)); // TODO: include ANY char in serialization for (size_t a = 0; a < Abc::kSize; ++a) { if (context_weights[i][a] == -INFINITY) fputs("\t*", fout); else fprintf(fout, "\t%i", iround(context_weights[i][a] * kScale)); } fputs("\n", fout); } // Print pseudocount weights fputs("PC", fout); for (size_t a = 0; a < Abc::kSize; ++a) fprintf(fout, "\t%i", iround(pc_weights[a] * kScale)); fputs("\n//\n", fout); } // Prints CRF state weights probabilities in human-readable format for debugging. template std::ostream& operator<< (std::ostream& out, const CrfState& state) { out << "CrfState" << std::endl; out << "name:\t" << state.name << std::endl; out << "bias:\t" << strprintf("%-10.8g", state.bias_weight) << std::endl; const int c = (state.context_weights.length() - 1) / 2; for (size_t i = 0; i < state.context_weights.length(); ++i) out << "\t " << abs(static_cast(i) - c); out << "\t PCW\t PC" << std::endl; for (size_t a = 0; a < Abc::kSizeAny; ++a) { out << Abc::kIntToChar[a]; for (size_t i = 0; i < state.context_weights.length(); ++i) { out << strprintf("\t%+6.2f", state.context_weights[i][a]); } out << strprintf("\t%+6.2f\t%6.4f", state.pc_weights[a], state.pc[a]) << std::endl; } return out; } // Updates pseudocount emission probs in given CRF state based on 'pc_weights' and // rescales 'pc_weights' template inline void UpdatePseudocounts(CrfState& state) { // Calculate maximum of pseudocount weights double max = -DBL_MAX; double mean = 0.0; for (size_t a = 0; a < Abc::kSize; ++a) { mean += state.pc_weights[a]; if (state.pc_weights[a] > max) max = state.pc_weights[a]; } mean /= Abc::kSize; // Rescale pseudocount weights and calculate their sum in lin-space long double sum = 0.0; for (size_t a = 0; a < Abc::kSize; ++a) sum += exp(state.pc_weights[a] - max); // Update emission pseudocount vector double tmp = max + log(sum); for (size_t a = 0; a < Abc::kSize; ++a) { state.pc[a] = DBL_MIN + exp(state.pc_weights[a] - tmp); // state.pc_weights[a] -= mean; // Not necessary if pc_weights are centered on central context weights } } // Calculates context score between a CRF state and a sequence window template inline double ContextScore(const Profile& context_weights, const Sequence& seq, size_t idx, size_t center) { assert(context_weights.length() & 1); const size_t beg = MAX(0, static_cast(idx - center)); const size_t end = MIN(seq.length(), idx + center + 1); double score = 0.0; for(size_t i = beg, j = beg - idx + center; i < end; ++i, ++j) score += context_weights[j][seq[i]]; return score; } // Calculates context score between a CRF state and a count profile window template inline double ContextScore(const Profile& context_weights, const CountProfile& cp, size_t idx, size_t center) { assert(context_weights.length() & 1); const size_t beg = MAX(0, static_cast(idx - center)); const size_t end = MIN(cp.counts.length(), idx + center + 1); double score = 0.0; for(size_t i = beg, j = beg - idx + center; i < end; ++i, ++j) { for (size_t a = 0; a < Abc::kSize; ++a) score += context_weights[j][a] * cp.counts[i][a]; } return score; } } // namespace cs #endif // CS_CRF_STATE_INL_H_