blob: 4be911df2f84fa11a15745e5a90b2cbca70fa38d [file] [log] [blame]
package org.eclipse.stem.analysis.impl;
/*******************************************************************************
* Copyright (c) 2009 IBM Corporation and others.
* All rights reserved. This program and the accompanying materials
* are made available under the terms of the Eclipse Public License v1.0
* which accompanies this distribution, and is available at
* http://www.eclipse.org/legal/epl-v10.html
*
* Contributors:
* IBM Corporation - initial API and implementation
*******************************************************************************/
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.eclipse.emf.common.util.BasicEList;
import org.eclipse.emf.common.util.EList;
import org.eclipse.emf.ecore.EClass;
import org.eclipse.stem.analysis.AnalysisFactory;
import org.eclipse.stem.analysis.AnalysisPackage;
import org.eclipse.stem.analysis.ErrorResult;
import org.eclipse.stem.analysis.ReferenceScenarioDataMap;
import org.eclipse.stem.analysis.SimpleErrorFunction;
import org.eclipse.stem.analysis.impl.ReferenceScenarioDataMapImpl.ReferenceScenarioDataInstance;
/**
* <!-- begin-user-doc -->
* An implementation of the model object '<em><b>Simple Error Function</b></em>'.
* <!-- end-user-doc -->
* <p>
* </p>
*
* @generated
*/
public class SimpleErrorFunctionImpl extends ErrorFunctionImpl implements SimpleErrorFunction {
/**
* <!-- begin-user-doc -->
* <!-- end-user-doc -->
* @generated NOT
*/
public SimpleErrorFunctionImpl() {
super();
}
/**
* <!-- begin-user-doc -->
* <!-- end-user-doc -->
* @generated
*/
@Override
protected EClass eStaticClass() {
return AnalysisPackage.Literals.SIMPLE_ERROR_FUNCTION;
}
/**
* input
*/
Map<String,List<Double>> commonInfectiousLocationsA = new HashMap<String,List<Double>>();
Map<String,List<Double>> commonInfectiousLocationsB = new HashMap<String,List<Double>>();
Map<String,List<Double>> commonPopulationLocationsA = new HashMap<String,List<Double>>();
Map<String,List<Double>> commonPopulationLocationsB = new HashMap<String,List<Double>>();
Map<String,Double> commonAvgPopulationLocationsA = new HashMap<String,Double>();
Map<String,Double> commonAvgPopulationLocationsB = new HashMap<String,Double>();
Map<String, Double> commonMaxLocationsA = new HashMap<String, Double>();
/**
* number common locations with nonzero Inf count at time t
*/
public double[] locationCount;
/**
* The Result
*/
public double[] meanSqDiff;
/*
* time
*/
public double[] time;
protected AnalysisFactory aFactory = new AnalysisFactoryImpl();
// Set to true to weight the average by population size
private static boolean AGGREGATE_NRMSE = true; // True if aggregate signal for locations first, then calculate NRMSE. False if use NRMSE per location then average.
private static boolean WEIGHTED_AVERAGE = true; // Only used if AGGREGATE_NRMSE = false;
private static boolean FIT_INCIDENCE = true;
private static boolean USE_THRESHOLD = false;
private static double THRESHOLD = 0.05;
// The year to use to validate and hence exclude from the error calculation(for cross-validation methods)
// The first year is year 0. If no year should be excluded, set to -1
int validationYear = -1;
/**
* calculate delta for a simple error function
*
*
* @override
*/
@Override
public ErrorResult calculateError(ReferenceScenarioDataMap reference, ReferenceScenarioDataMap data) {
final ReferenceScenarioDataMapImpl _ref = (ReferenceScenarioDataMapImpl)reference;
final ReferenceScenarioDataMapImpl _data = (ReferenceScenarioDataMapImpl)data;
// clear
time = null;
Iterator<String> iteratorA = _ref.getLocations().iterator();
int maxTime = -1;
while(iteratorA.hasNext()) {
String id = iteratorA.next();
if(_data.containsLocation(id)) {
// get the lists of data only for those locations that are common to both maps ReferenceScenarioDataInstance dataMapA = mapA.getLocation(id);
ReferenceScenarioDataInstance dataMapA = _ref.getLocation(id);
List<Double> dataAI = null;
if(FIT_INCIDENCE) dataAI = getIncidence(dataMapA);
else dataAI = getInfectious(dataMapA);
List<Double> dataAP = getPopulation(dataMapA);
commonInfectiousLocationsA.put(id,dataAI);
commonPopulationLocationsA.put(id, dataAP);
// Map B
ReferenceScenarioDataInstance dataMapB = _data.getLocation(id);
List<Double> dataBI = null;
if(FIT_INCIDENCE) dataBI = getIncidence(dataMapB);
else dataBI = getInfectious(dataMapB);
List<Double> dataBP = getPopulation(dataMapB);
commonInfectiousLocationsB.put(id,dataBI);
commonPopulationLocationsB.put(id, dataBP);
// init the array size
if (maxTime == -1) maxTime = dataAI.size();
// dimension the arrays to the length of the SMALLEST array for which we have data
if(maxTime >= dataBI.size() ) maxTime = dataBI.size();
if(maxTime >= dataAI.size() ) maxTime = dataAI.size();
}// if
}// while
if(maxTime<=0) maxTime = 0;
if(time==null) {
time = new double[maxTime];
meanSqDiff = new double[maxTime];
locationCount = new double[maxTime];
for(int i = 0; i < maxTime; i ++) {
time[i] = i;
meanSqDiff[i] = 0.0;
locationCount[i] = 0.0;
}
}
// Now figure out the actual error
double [] Xref = new double[time.length];
double [] Xdata = new double[time.length];
double finalerror = 0.0;
double verror = 0.0;
BasicEList<Double> list = new BasicEList<Double>();
for(int i=0;i<time.length;++i)list.add(0.0);
// Get the average population for each location
for(String loc:commonPopulationLocationsA.keySet()) {
List<Double>ld = commonPopulationLocationsA.get(loc);
double sum = 0;for(double d:ld)sum+=d;
sum /= (double)ld.size();
commonAvgPopulationLocationsA.put(loc, sum);
}
// Get the average population for each location
for(String loc:commonPopulationLocationsB.keySet()) {
List<Double>ld = commonPopulationLocationsB.get(loc);
double sum = 0;for(double d:ld)sum+=d;
sum /= (double)ld.size();
commonAvgPopulationLocationsB.put(loc, sum);
}
// Get the maximum value for the A series (reference)
for(String loc:commonPopulationLocationsA.keySet()) {
List<Double>ld = commonInfectiousLocationsA.get(loc);
double max = Double.MIN_VALUE;
for(double d:ld)if(d >max)max=d;
commonMaxLocationsA.put(loc, max);
}
// Calculate the normalized root mean square error for each location, then
// divide by the number of locatins
double weighted_denom = 0.0;
if(!AGGREGATE_NRMSE) { // Use NRMSE per location first
for(String loc:commonInfectiousLocationsA.keySet()) {
double maxRef = 0.0;
double minRef = Double.MAX_VALUE;
// Get the numbers at each time step for the location
for(int icount =0; icount < time.length; icount ++) {
List<Double> dataAI = commonInfectiousLocationsA.get(loc);
List<Double> dataBI = commonInfectiousLocationsB.get(loc);
double iA = dataAI.get(icount).doubleValue();
double iB = dataBI.get(icount).doubleValue();
Xref[icount]=iA;
Xdata[icount]=iB;
}
double nominator = 0.0;
double timesteps = 0;
for(int icount =0; icount < time.length; icount ++) {
if(Xref[icount]>maxRef)maxRef = Xref[icount];
if(Xref[icount]<minRef)minRef = Xref[icount];
// If we use the threshold and both the reference and the model is less than
// the THRESHOLD*MAXref(loc) we don't measure the data point
if(USE_THRESHOLD && (Xref[icount]<=THRESHOLD*commonMaxLocationsA.get(loc) &&
Xdata[icount]<=THRESHOLD*commonMaxLocationsA.get(loc))) continue;
nominator = nominator + Math.pow(Xref[icount]-Xdata[icount], 2);
list.set(icount, list.get(icount)+Math.abs(Xref[icount]-Xdata[icount]));
++timesteps;
}
double error = Double.MAX_VALUE;
if(timesteps > 0 && maxRef-minRef > 0.0) {
error = Math.sqrt(nominator/timesteps);
error = error / (maxRef-minRef);
if(WEIGHTED_AVERAGE) finalerror += commonAvgPopulationLocationsA.get(loc) * error;
else finalerror += error;
if(WEIGHTED_AVERAGE) weighted_denom += commonAvgPopulationLocationsA.get(loc);
else weighted_denom += 1.0;
}
}
// Divide the error by the number of locations
finalerror /= weighted_denom;
} else { // Aggregate signal, then calculate NRMSE
for(int icount =0; icount < time.length; icount ++) {
for(String loc:commonInfectiousLocationsA.keySet()) {
List<Double> dataAI = commonInfectiousLocationsA.get(loc);
List<Double> dataBI = commonInfectiousLocationsB.get(loc);
double iA = dataAI.get(icount).doubleValue();
double iB = dataBI.get(icount).doubleValue();
Xref[icount]+=iA;
Xdata[icount]+=iB;
}
}
double maxRef = Double.MIN_VALUE;
double minRef = Double.MAX_VALUE;
double maxValidationRef = Double.MIN_VALUE;
double minValidationRef = Double.MAX_VALUE;
for(int icount =0; icount < time.length; icount ++) {
if(icount >= validationYear*365.25 && icount <= (validationYear+1)*365.25) {
if(Xref[icount]>maxValidationRef)maxValidationRef = Xref[icount];
if(Xref[icount]<minValidationRef)minValidationRef = Xref[icount];
continue;
}
if(Xref[icount]>maxRef)maxRef = Xref[icount];
if(Xref[icount]<minRef)minRef = Xref[icount];
}
double nominator = 0.0, vnominator = 0.0;
double timesteps = 0.0, vtimesteps = 0.0;
for(int icount =0; icount < time.length; icount ++) {
// Calculate validation error then skip
if(icount >= validationYear*365.25 && icount <= (validationYear+1)*365.25) {
if(USE_THRESHOLD && (Xref[icount]<=THRESHOLD*maxValidationRef &&
Xdata[icount]<=THRESHOLD*maxValidationRef)) continue;
vnominator = vnominator + Math.pow(Xref[icount]-Xdata[icount], 2);
list.set(icount, new Double(0)); // Set to 0 for validation data points
++vtimesteps;
continue;
}
// If we use the threshold and both the reference and the model is less than
// the THRESHOLD*MAXref(loc) we don't measure the data point
if(USE_THRESHOLD && (Xref[icount]<=THRESHOLD*maxRef &&
Xdata[icount]<=THRESHOLD*maxRef)) continue;
nominator = nominator + Math.pow(Xref[icount]-Xdata[icount], 2);
list.set(icount, Math.abs(Xref[icount]-Xdata[icount]));
++timesteps;
}
double error = Double.MAX_VALUE;
if(timesteps > 0 && maxRef-minRef > 0.0) {
error = Math.sqrt(nominator/timesteps);
finalerror = error / (maxRef-minRef);
}
// Validation
error = Double.MAX_VALUE;
if(vtimesteps > 0 && maxValidationRef-minValidationRef > 0.0) {
error = Math.sqrt(vnominator/vtimesteps);
verror = error / (maxValidationRef-minValidationRef);
}
} // else
ErrorResult resultobj = aFactory.createErrorResult();
resultobj.setErrorByTimeStep(list);
resultobj.setError(finalerror);
resultobj.setValidationError(verror);
EList<Double>refByTime = new BasicEList<Double>();
EList<Double>dataByTime = new BasicEList<Double>();
// Set the reference and model by time
for(int icount=0;icount<time.length;++icount) {
refByTime.add(0.0);dataByTime.add(0.0);}
for(String loc:commonInfectiousLocationsA.keySet()) {
for(int icount =0; icount < time.length; icount ++) {
List<Double> dataAI = commonInfectiousLocationsA.get(loc);
List<Double> dataBI = commonInfectiousLocationsB.get(loc);
double iA = dataAI.get(icount).doubleValue();
double iB = dataBI.get(icount).doubleValue();
refByTime.set(icount, refByTime.get(icount)+iA);
dataByTime.set(icount, dataByTime.get(icount)+iB);
}
}
resultobj.setReferenceByTime(refByTime);
resultobj.setModelByTime(dataByTime);
return resultobj;
}
} //SimpleErrorFunctionImpl