// blob: 9eca2371d3095617e44fc73adbf9a85fb4e60eac [file] [log] [blame]
package org.eclipse.stem.fbd.util;
import java.io.BufferedReader;
import java.io.Closeable;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import org.eclipse.stem.fbd.model.EmpiricalDataset;
import org.eclipse.stem.fbd.model.scenario.BatchExperimentScenario;
import org.eclipse.stem.fbd.prediction.MLSourcePredictor;
/**
 * Command-line tool that takes each food product in turn as the contamination
 * source ("guilty product"), runs a batch of experiments through
 * {@link BatchExperimentScenario} with an {@link MLSourcePredictor} attached,
 * and writes the sensitivity and specificity of the predictor, averaged over
 * all processed products, for a fixed set of relative-likelihood thresholds.
 *
 * <p>Three CSV files are written to the working directory:
 * {@code sensitivities.csv}, {@code specificity.csv} and {@code skippedTN.csv}.
 * When a product filter file is supplied the names carry the suffix
 * {@code _sz_<number of filter entries>}.
 *
 * <p>The class keeps its working state in static fields and is therefore not
 * safe to use from several threads at once.
 */
public class GenerateFoodSensitivitySpecificity {

    /**
     * Relative-likelihood thresholds. A product is treated as "flagged" for a
     * case report when (its probability / the largest probability of any
     * product) is at least the threshold.
     */
    private static final double[] THRESHOLDS =
            {1.0/256, 1.0/128, 1.0/64, 1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0/2, 1.0};

    /** Number of experiment repetitions run per guilty product. */
    private static final int REPEATS_PER_PRODUCT = 100;

    /** Noise argument handed to {@link #runBatch} from {@link #main}. */
    private static final double NOISE = 0.0;

    /**
     * Zero-based index of the product currently used as contamination source.
     * Set by the product loop in {@link #main} and read by the sensitivity and
     * specificity computations.
     */
    static int guiltyProd;

    /**
     * One-based product ids to analyse, or {@code null} when no filter file was
     * given (all products are analysed).
     */
    static ArrayList<Integer> productFilter;

    /** Whether to drop products with zero (or close to zero) probabilities from selectivity analysis. */
    static boolean DROP_ZERO_PROBABILITIES_FROM_SELECTIVITY = true;

    /** Probabilities at or below this value are treated as zero; should be small enough. */
    static double ZERO_THRESHOLD = 1E-70;

    /**
     * Scenario shared by all {@link #runBatch} calls. It is created lazily from
     * the first dataset passed in; datasets passed to later calls are ignored.
     */
    static BatchExperimentScenario scenario;

    /**
     * Entry point.
     *
     * @param args {@code args[0]} population density file, {@code args[1]} food
     *             distribution file, optional {@code args[2]} product filter
     *             file holding one one-based product id per line
     */
    public static void main(String[] args) {
        if (args.length < 2) {
            System.out.println("Usage: GenerateFoodSensitivitySpecificity <Population Density File> <Food Distribution file> [<Product filter file>]");
            System.exit(-1);
        }

        // Declared outside the try so the finally block can always release them.
        FileWriter fwSensitivity = null;
        FileWriter fwSpecificity = null;
        FileWriter fwSkippedTn = null;
        try {
            EmpiricalDataset ed = EmpiricalDataset.createDatasetFromFiles(
                    new File(args[0]), new File(args[1]));

            boolean hasFilter = args.length >= 3;
            if (hasFilter) {
                readFilterFile(args[2]);
            }
            String suffix = hasFilter ? "_sz_" + productFilter.size() + ".csv" : ".csv";
            fwSensitivity = new FileWriter("sensitivities" + suffix);
            fwSpecificity = new FileWriter("specificity" + suffix);
            fwSkippedTn = new FileWriter("skippedTN" + suffix);

            int reports = BatchExperimentScenario.MAX_NO_OF_REPORTS;

            // Header rows (the trailing comma after the last column is kept
            // for compatibility with existing consumers of these files).
            fwSensitivity.write("threshold,");
            fwSpecificity.write("threshold,");
            fwSkippedTn.write("pid,");
            for (int i = 0; i < reports; ++i) {
                String column = "report " + (i + 1) + ",";
                fwSensitivity.write(column);
                fwSpecificity.write(column);
                fwSkippedTn.write(column);
            }
            fwSensitivity.write("\n");
            fwSpecificity.write("\n");
            fwSkippedTn.write("\n");

            // Running sums indexed [threshold][case report]; divided by the
            // number of processed products once all products are done.
            double[][] sensitivitySums = new double[THRESHOLDS.length][reports];
            double[][] specificitySums = new double[THRESHOLDS.length][reports];
            int processedProducts = 0;

            int productCount = ed.getFoodDistributionNames().length;
            for (guiltyProd = 0; guiltyProd < productCount; ++guiltyProd) {
                // +1: product ids in the filter file start at 1
                if (productFilter != null && !productFilter.contains(guiltyProd + 1)) {
                    continue;
                }

                Object[] results = runBatch(ed, guiltyProd, REPEATS_PER_PRODUCT, NOISE);
                @SuppressWarnings("unchecked")
                List<List<List<Double>>> probs = (List<List<List<Double>>>) results[1];
                if (probs.isEmpty()) {
                    // runBatch already reported the failure; leave this product
                    // out of the averages rather than aborting the whole run.
                    System.err.println("No results for product: " + guiltyProd + ", skipping");
                    continue;
                }

                double[] skippedTrueNegatives = null;
                for (int t = 0; t < THRESHOLDS.length; ++t) {
                    double[] sensitivities = getSensitivities(THRESHOLDS[t], probs);
                    Object[] spec = getSpecificities(THRESHOLDS[t], probs);
                    double[] specificities = (double[]) spec[0];
                    // The row written below uses the value from the last threshold.
                    skippedTrueNegatives = (double[]) spec[1];
                    for (int i = 0; i < reports; ++i) {
                        sensitivitySums[t][i] += sensitivities[i];
                        specificitySums[t][i] += specificities[i];
                    }
                }
                ++processedProducts;

                fwSkippedTn.write("" + guiltyProd + ",");
                for (int r = 0; r < skippedTrueNegatives.length; ++r) {
                    if (r > 0) {
                        fwSkippedTn.write(",");
                    }
                    fwSkippedTn.write("" + skippedTrueNegatives[r]);
                }
                fwSkippedTn.write("\n");
                System.out.println("Processing data for product: " + guiltyProd);
            } // for all products

            if (processedProducts == 0) {
                System.err.println("No products were processed; no averages written");
            } else {
                for (int t = 0; t < THRESHOLDS.length; ++t) {
                    // The first column holds the reciprocal of the threshold.
                    String label = "" + 1 / THRESHOLDS[t] + ",";
                    fwSensitivity.write(label);
                    fwSpecificity.write(label);
                    for (int i = 0; i < reports; ++i) {
                        if (i > 0) {
                            fwSensitivity.write(",");
                            fwSpecificity.write(",");
                        }
                        // Divide by the products actually processed, so that
                        // duplicate or out-of-range filter ids and skipped
                        // products do not distort the average.
                        fwSensitivity.write("" + sensitivitySums[t][i] / processedProducts);
                        fwSpecificity.write("" + specificitySums[t][i] / processedProducts);
                    }
                    fwSensitivity.write("\n");
                    fwSpecificity.write("\n");
                    fwSensitivity.flush();
                    fwSpecificity.flush();
                }
            }
        } catch (Exception e) {
            e.printStackTrace();
        } finally {
            closeQuietly(fwSensitivity);
            closeQuietly(fwSpecificity);
            closeQuietly(fwSkippedTn);
        }
        // TODO: write AUC.csv file (not implemented)
    }

    /** Closes {@code resource} if it is non-null, reporting (not propagating) any I/O failure. */
    private static void closeQuietly(Closeable resource) {
        if (resource == null) {
            return;
        }
        try {
            resource.close();
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * Reads the product filter file into {@link #productFilter}. The file holds
     * one integer product id per line; blank lines are ignored. The field is
     * only replaced when the whole file was read successfully.
     *
     * @param fname path of the filter file
     * @throws IOException           if the file cannot be read
     * @throws NumberFormatException if a non-blank line is not an integer
     */
    private static void readFilterFile(String fname) throws IOException {
        ArrayList<Integer> filter = new ArrayList<Integer>();
        BufferedReader reader = new BufferedReader(new FileReader(fname));
        try {
            for (String line = reader.readLine(); line != null; line = reader.readLine()) {
                String id = line.trim();
                if (id.length() == 0) {
                    continue;
                }
                filter.add(Integer.parseInt(id));
            }
        } finally {
            closeQuietly(reader);
        }
        productFilter = filter;
    }

    /**
     * Runs {@code repeats} experiments with the given contamination source and
     * collects what the predictor reported for each of them.
     *
     * <p>If a {@code MagicTableException} is thrown the stack trace is printed
     * and whatever was collected up to that point is returned (possibly empty
     * lists).
     *
     * @param ed                  dataset used to create the shared scenario on
     *                            the first call; ignored once {@link #scenario}
     *                            is set
     * @param contaminationSource zero-based index of the guilty product
     * @param repeats             number of experiments to run
     * @param noise               noise argument passed to
     *                            {@code BatchExperimentScenario.runExperiment}
     * @return a two-element array: {@code [0]} is an
     *         {@code ArrayList<List<Integer>>} holding the predictor history of
     *         each run; {@code [1]} is an
     *         {@code ArrayList<List<List<Double>>>} indexed
     *         {@code [run][case report][product]} holding {@code Math.exp} of
     *         the values returned by
     *         {@code MLSourcePredictor.getRelativeProbabilityFor}
     */
    public static Object[] runBatch(EmpiricalDataset ed, int contaminationSource, int repeats, double noise) {
        ArrayList<List<Integer>> results = new ArrayList<List<Integer>>();
        ArrayList<List<List<Double>>> relativeProbs = new ArrayList<List<List<Double>>>();
        if (scenario == null) {
            scenario = new BatchExperimentScenario(ed);
        }
        try {
            scenario.init(contaminationSource, contaminationSource);
            for (int i = 0; i < repeats; i++) {
                MLSourcePredictor pred = new MLSourcePredictor(
                        scenario.getFoodDistributions(), contaminationSource);
                scenario.registerContaminationEventObserver(pred);
                try {
                    scenario.runExperiment(noise, i);
                } finally {
                    // The scenario is shared between calls, so the observer must
                    // be detached even when the experiment fails.
                    scenario.unregisterContaminationEventObserver(pred);
                }
                results.add(pred.getHistory());

                List<List<Double>> runCaseReportProbs = new ArrayList<List<Double>>();
                relativeProbs.add(runCaseReportProbs);
                for (int j = 0; j < BatchExperimentScenario.MAX_NO_OF_REPORTS; ++j) {
                    // NOTE(review): values are presumably log-probabilities, since
                    // they are exponentiated here - confirm against MLSourcePredictor.
                    double[] raw = pred.getRelativeProbabilityFor(j);
                    List<Double> probsForReport = new ArrayList<Double>(raw.length);
                    for (double value : raw) {
                        // Convert into a fresh list; the predictor's array is left untouched.
                        probsForReport.add(Math.exp(value));
                    }
                    runCaseReportProbs.add(probsForReport);
                }
            }
        } catch (MagicTableException e) {
            e.printStackTrace();
        }
        return new Object[] {results, relativeProbs};
    }

    /**
     * Returns the largest value in {@code values}, or -1 when the list is empty
     * or holds no value greater than -1.
     */
    private static double maxOf(List<Double> values) {
        double max = -1;
        for (int j = 0; j < values.size(); ++j) {
            double v = values.get(j);
            if (v > max) {
                max = v;
            }
        }
        return max;
    }

    /**
     * Compute sensitivity, i.e. the probability that a guilty product is
     * correctly identified as such.
     *
     * <p>Sensitivity = "true positives" / ("true positives" + "false negatives").
     * A run counts as a true positive for a case report when the guilty
     * product's probability divided by the largest probability is not below the
     * threshold, otherwise as a false negative; the denominator is therefore
     * the number of runs.
     *
     * <p>If the largest probability is 0 the ratio is NaN, which does not
     * compare as below the threshold, so the run counts as a true positive.
     *
     * @param threshold     relative-likelihood threshold
     * @param probabilities probabilities indexed {@code [run][case report][product]}
     * @return sensitivity per case report, length
     *         {@code BatchExperimentScenario.MAX_NO_OF_REPORTS}; all zeros when
     *         there are no runs
     */
    private static double[] getSensitivities(double threshold,
            List<List<List<Double>>> probabilities) {
        int reports = BatchExperimentScenario.MAX_NO_OF_REPORTS;
        double[] res = new double[reports];
        int runs = probabilities.size();
        if (runs == 0) {
            return res;
        }
        for (List<List<Double>> run : probabilities) {
            int available = Math.min(run.size(), reports);
            for (int caseReport = 0; caseReport < available; ++caseReport) {
                List<Double> runProbs = run.get(caseReport);
                double max = maxOf(runProbs);
                if (!(runProbs.get(guiltyProd) / max < threshold)) {
                    res[caseReport] += 1;
                }
            }
        }
        for (int i = 0; i < reports; ++i) {
            res[i] /= runs;
        }
        return res;
    }

    /**
     * Compute specificity, i.e. the probability that a non-guilty product is
     * correctly identified as such.
     *
     * <p>Specificity = "true negatives" / ("true negatives" + "false positives"),
     * computed per run and case report and then averaged over the runs.
     * True negatives are non-guilty products below the threshold, false
     * positives are non-guilty products at or above it. When
     * {@link #DROP_ZERO_PROBABILITIES_FROM_SELECTIVITY} is set, products below
     * the threshold whose probability is at most {@link #ZERO_THRESHOLD} are
     * not counted as true negatives but tallied separately as "skipped". When
     * a run has neither true negatives nor false positives its specificity is
     * taken to be 1.
     *
     * @param threshold     relative-likelihood threshold
     * @param probabilities probabilities indexed {@code [run][case report][product]}
     * @return a two-element array of {@code double[]}, each of length
     *         {@code BatchExperimentScenario.MAX_NO_OF_REPORTS}: {@code [0]} the
     *         average specificity per case report, {@code [1]} the average
     *         number of skipped true negatives per case report; all zeros when
     *         there are no runs
     */
    private static Object[] getSpecificities(double threshold,
            List<List<List<Double>>> probabilities) {
        int reports = BatchExperimentScenario.MAX_NO_OF_REPORTS;
        double[] specificities = new double[reports];
        double[] skipped = new double[reports];
        int runs = probabilities.size();
        if (runs == 0) {
            return new Object[] {specificities, skipped};
        }
        for (List<List<Double>> run : probabilities) {
            int available = Math.min(run.size(), reports);
            for (int caseReport = 0; caseReport < available; ++caseReport) {
                List<Double> runProbs = run.get(caseReport);
                double max = maxOf(runProbs);
                double trueNegatives = 0.0;
                double falsePositives = 0.0;
                double skippedTrueNegatives = 0.0;
                for (int j = 0; j < runProbs.size(); ++j) {
                    if (j == guiltyProd) {
                        continue;
                    }
                    double p = runProbs.get(j);
                    double ratio = p / max;
                    // Both comparisons are spelled out so that a NaN ratio
                    // (max == 0) falls into neither category.
                    if (ratio < threshold) {
                        if (!DROP_ZERO_PROBABILITIES_FROM_SELECTIVITY || p > ZERO_THRESHOLD) {
                            ++trueNegatives;
                        } else if (p <= ZERO_THRESHOLD) {
                            ++skippedTrueNegatives;
                        }
                    } else if (ratio >= threshold) {
                        ++falsePositives;
                    }
                }
                double specificity;
                if (trueNegatives + falsePositives == 0.0) {
                    specificity = 1.0; // No product is non-guilty, so perfect specificity
                } else {
                    specificity = trueNegatives / (trueNegatives + falsePositives);
                }
                specificities[caseReport] += specificity;
                skipped[caseReport] += skippedTrueNegatives;
            }
        }
        for (int i = 0; i < reports; ++i) {
            specificities[i] /= runs;
            skipped[i] /= runs;
        }
        return new Object[] {specificities, skipped};
    }
}