#!/usr/bin/perl -w

##############################################################################
#Evaluate the beta-residue pairing performance for both Precision-Recall and TPR/FPR.
#F-measure, break-even point, 0.05 FPR.
#Input: a directory of predicted H-bond (more accurately, partner-pair) matrix files.
#File format: name, seq, ss, bp1, bp2, matrix
#Author: Jianlin Cheng
#Date: 10/20/2004
#############################################################################

# Command line: <input matrix dir> <output statistics file>.
if (@ARGV != 2)
{
	die "need 2 params: input matrix dir, output statistics file\n"; 
}

$input_dir = shift @ARGV;
$output_file = shift @ARGV;  

# File paths are built later by plain concatenation, so make sure the
# directory name ends with a slash.
if ($input_dir !~ m{/\z})
{
	$input_dir .= "/";
}

# Read the whole directory listing up front. Sorting makes the processing
# order (and therefore the first file reported on error) reproducible.
opendir(INPUT, $input_dir) || die "can't open input dir $input_dir: $!\n"; 
@files = sort readdir(INPUT);
closedir INPUT; 

# Three-argument open: the output name is taken literally, so characters
# such as ">" or "|" in a user-supplied name cannot change the open mode.
open(OUTPUT, '>', $output_file) || die "can't create output file $output_file: $!\n"; 

# Threshold grid and per-threshold confusion-matrix counters.
# $step intervals give $step + 1 evenly spaced cut-offs: 0, 1/$step, ..., 1.
#$step = 100; 
$step = 200; 
@thresholds = map { $_ / $step } 0 .. $step;

# One counter per threshold, all starting at zero.
@TP = (0) x ($step + 1);
@FN = (0) x ($step + 1);
@FP = (0) x ($step + 1);
@TN = (0) x ($step + 1);

# Accumulate confusion-matrix counts over every prediction file.
# Each file holds five header lines (name, sequence, secondary structure,
# bp1, bp2) followed by one row of predicted probabilities per beta residue.
# Counts are added to $TP[$i], $FN[$i], $FP[$i], $TN[$i] for threshold $i.
for my $file (@files)
{
	my $full_file = $input_dir . $file; 

	# readdir also returns ".", ".." and sub-directories; only plain
	# files can hold a matrix.
	next unless -f $full_file;

	# Three-argument open so odd characters in a file name are taken literally.
	open(my $fh, '<', $full_file) || die "can't read matrix file: $full_file: $!\n"; 
	my @content = <$fh>;
	close $fh; 

	# Strip line endings; accepting CRLF keeps the sequence length correct
	# for files written on Windows.
	s/\r?\n\z// for @content;

	if (@content < 5)
	{
		die "matrix file is truncated (needs 5 header lines): $full_file\n";
	}
	# The first header line is the name, which is not needed here.
	my (undef, $seq, $ss, $bp1, $bp2) = splice(@content, 0, 5);

	my $length = length($seq);
	if (length($ss) < $length)
	{
		die "secondary structure is shorter than sequence: $file\n";
	}

	# split(' ') ignores leading whitespace, so an indented line does not
	# produce a spurious empty first field.
	my @bp1_vec = split(' ', $bp1);
	my @bp2_vec = split(' ', $bp2); 
	if (@bp1_vec != $length || @bp2_vec != $length)
	{
		die "size of bp doesn't equal sequence length: $file\n"; 
	}

	# Walk the sequence once, collecting the beta residues (E or B) and
	# numbering the strands, i.e. the maximal runs of consecutive E/B.
	my @ee2aa = ();      # beta index (0-based) -> sequence position (1-based)
	my %aa2ee = ();      # sequence position (1-based) -> beta index (0-based)
	my @strand_of = ();  # beta index (0-based) -> strand number
	my $strand_count = 0;
	my $in_strand = 0;
	for my $pos (1 .. $length)
	{
		my $sec = substr($ss, $pos - 1, 1);
		if ($sec eq "E" || $sec eq "B")
		{
			if (!$in_strand)
			{
				$strand_count++;
				$in_strand = 1;
			}
			$aa2ee{$pos} = scalar @ee2aa;
			push @ee2aa, $pos;
			push @strand_of, $strand_count;
		}
		else
		{
			$in_strand = 0;
		}
	}
	my $size = @ee2aa;

	#generate the true inter-strand residue pair matrix
	# $tmatrix[$m][$n] is 1 when beta residue $n is listed as a partner of
	# beta residue $m; entries that are never set count as 0.
	my @tmatrix = (); 
	for my $m (0 .. $size - 1)
	{
		my $aa1 = $ee2aa[$m];
		my $h1 = $bp1_vec[$aa1 - 1];
		my $h2 = $bp2_vec[$aa1 - 1];

		for my $partner ($h1, $h2)
		{
			# Partners that are not beta residues are ignored.
			next unless exists $aa2ee{$partner};

			my $n = $aa2ee{$partner};
			my $aa2 = $ee2aa[$n];
			my $b1 = $bp1_vec[$aa2 - 1];
			my $b2 = $bp2_vec[$aa2 - 1];
			# The pairing must be mutual: the partner has to list $aa1 back.
			if ($b1 != $aa1 && $b2 != $aa1)
			{
				die "h-bond not consistent: aa1: $aa1 ($h1, $h2), aa2: $aa2 ($b1, $b2)\n";  
			}
			$tmatrix[$m][$n] = 1; 
		}
	}

	#read predicted probability matrix
	if (@content != $size)
	{
		die "predicted matrix size doesn't equal the number of beta-residues.\n"; 
	}

	my @pmatrix = ();
	for my $m (0 .. $size - 1)
	{
		my @probs = split(' ', $content[$m]);
		# A short row would otherwise be read as undef values, which
		# silently behave like probability 0.
		if (@probs < $size)
		{
			die "predicted matrix row " . ($m + 1) . " has fewer than $size columns: $file\n";
		}
		$pmatrix[$m] = [ @probs[0 .. $size - 1] ];
	}

	#compute the confusion matrix for all thresholds
	#Be careful: we only count the contacts between inter-strand pairs
	# Only the upper triangle ($n > $m) is visited. The same-strand test
	# does not depend on the threshold, so it is done once per pair and
	# the threshold loop is innermost.
	for my $m (0 .. $size - 1)
	{
		for my $n ($m + 1 .. $size - 1)
		{
			next if $strand_of[$m] == $strand_of[$n];

			my $target = $tmatrix[$m][$n] ? 1 : 0;
			my $prob = $pmatrix[$m][$n];

			for my $i (0 .. $step)
			{
				# A pair is predicted positive when its probability is
				# strictly above the threshold.
				if ($prob > $thresholds[$i])
				{
					if ($target)
					{
						$TP[$i]++; 
					}
					else
					{
						$FP[$i]++; 
					}
				}
				else
				{
					if ($target)
					{
						$FN[$i]++; 
					}
					else
					{
						$TN[$i]++; 
					}
				}
			}
		}
	}
}

# Write the raw per-threshold counts as R vector assignments.
print OUTPUT "tp <- c(", join(",", @TP), ")\n"; 
print OUTPUT "fn <- c(", join(",", @FN), ")\n"; 
print OUTPUT "tn <- c(", join(",", @TN), ")\n"; 
print OUTPUT "fp <- c(", join(",", @FP), ")\n"; 

my @precision = ();
my @recall = ();
my @tpr = ();
my @fpr = (); 

# Derive the rates for every threshold. A rate whose denominator is zero
# (for example no positive pairs at all, or an empty input directory) is
# written as R's missing value NA instead of aborting the script with
# "Illegal division by zero".
for my $i (0 .. $step)
{
	my $positives = $TP[$i] + $FN[$i];   # pairs that are truly paired
	my $negatives = $FP[$i] + $TN[$i];   # pairs that are truly unpaired
	my $predicted = $TP[$i] + $FP[$i];   # pairs predicted as paired

	my $sensitivity = $positives > 0 ? $TP[$i] / $positives : "NA";

	# Precision and recall are only recorded while something is predicted
	# positive, so both vectors end at the same index.
	if ($predicted > 0)
	{
		$precision[$i] = $TP[$i] / $predicted; 
		$recall[$i] = $sensitivity; 
	}
	$tpr[$i] = $sensitivity;
	$fpr[$i] = $negatives > 0 ? $FP[$i] / $negatives : "NA"; 
}

# Thresholds ascend, so reversing lists tpr/fpr from the highest threshold
# down to the lowest.
@tpr = reverse @tpr;
@fpr = reverse @fpr; 

print OUTPUT "precision <- c(", join(",", @precision), ")\n";
print OUTPUT "recall <- c(", join(",", @recall), ")\n";
print OUTPUT "tpr <- c(", join(",", @tpr), ")\n";
print OUTPUT "fpr <- c(", join(",", @fpr), ")\n";

# Buffered write errors (e.g. a full disk) only surface at close.
close(OUTPUT) || die "can't finish writing output file: $!\n"; 






