| package org.eclipse.stem.analysis.impl; |
| |
| /******************************************************************************* |
| * Copyright (c) 2009 IBM Corporation and others. |
| * All rights reserved. This program and the accompanying materials |
| * are made available under the terms of the Eclipse Public License v1.0 |
| * which accompanies this distribution, and is available at |
| * http://www.eclipse.org/legal/epl-v10.html |
| * |
| * Contributors: |
| * IBM Corporation - initial API and implementation |
| *******************************************************************************/ |
| |
| import java.util.HashMap; |
| import java.util.Iterator; |
| import java.util.List; |
| import java.util.Map; |
| |
| import org.eclipse.emf.common.util.BasicEList; |
| import org.eclipse.emf.common.util.EList; |
| import org.eclipse.emf.ecore.EClass; |
| import org.eclipse.stem.analysis.AnalysisFactory; |
| import org.eclipse.stem.analysis.AnalysisPackage; |
| import org.eclipse.stem.analysis.ErrorResult; |
| import org.eclipse.stem.analysis.ReferenceScenarioDataMap; |
| import org.eclipse.stem.analysis.SimpleErrorFunction; |
| import org.eclipse.stem.analysis.impl.ReferenceScenarioDataMapImpl.ReferenceScenarioDataInstance; |
| |
| /** |
| * <!-- begin-user-doc --> |
| * An implementation of the model object '<em><b>Simple Error Function</b></em>'. |
| * <!-- end-user-doc --> |
| * <p> |
| * </p> |
| * |
| * @generated |
| */ |
| public class SimpleErrorFunctionImpl extends ErrorFunctionImpl implements SimpleErrorFunction { |
| /** |
| * <!-- begin-user-doc --> |
| * <!-- end-user-doc --> |
| * @generated NOT |
| */ |
| public SimpleErrorFunctionImpl() { |
| super(); |
| } |
| |
| /** |
| * <!-- begin-user-doc --> |
| * <!-- end-user-doc --> |
| * @generated |
| */ |
| @Override |
| protected EClass eStaticClass() { |
| return AnalysisPackage.Literals.SIMPLE_ERROR_FUNCTION; |
| } |
| |
| /** |
| * input |
| */ |
| Map<String,List<Double>> commonInfectiousLocationsA = new HashMap<String,List<Double>>(); |
| Map<String,List<Double>> commonInfectiousLocationsB = new HashMap<String,List<Double>>(); |
| Map<String,List<Double>> commonPopulationLocationsA = new HashMap<String,List<Double>>(); |
| Map<String,List<Double>> commonPopulationLocationsB = new HashMap<String,List<Double>>(); |
| |
| Map<String,Double> commonAvgPopulationLocationsA = new HashMap<String,Double>(); |
| Map<String,Double> commonAvgPopulationLocationsB = new HashMap<String,Double>(); |
| Map<String, Double> commonMaxLocationsA = new HashMap<String, Double>(); |
| |
| /** |
| * number common locations with nonzero Inf count at time t |
| */ |
| public double[] locationCount; |
| /** |
| * The Result |
| */ |
| public double[] meanSqDiff; |
| /* |
| * time |
| */ |
| public double[] time; |
| |
| protected AnalysisFactory aFactory = new AnalysisFactoryImpl(); |
| |
| // Set to true to weight the average by population size |
| private static boolean AGGREGATE_NRMSE = true; // True if aggregate signal for locations first, then calculate NRMSE. False if use NRMSE per location then average. |
| private static boolean WEIGHTED_AVERAGE = true; // Only used if AGGREGATE_NRMSE = false; |
| private static boolean FIT_INCIDENCE = true; |
| private static boolean USE_THRESHOLD = false; |
| private static double THRESHOLD = 0.05; |
| |
| // The year to use to validate and hence exclude from the error calculation(for cross-validation methods) |
| // The first year is year 0. If no year should be excluded, set to -1 |
| |
| int validationYear = -1; |
| |
| |
| /** |
| * calculate delta for a simple error function |
| * |
| * |
| * @override |
| */ |
| |
| @Override |
| public ErrorResult calculateError(ReferenceScenarioDataMap reference, ReferenceScenarioDataMap data) { |
| final ReferenceScenarioDataMapImpl _ref = (ReferenceScenarioDataMapImpl)reference; |
| final ReferenceScenarioDataMapImpl _data = (ReferenceScenarioDataMapImpl)data; |
| |
| // clear |
| time = null; |
| |
| Iterator<String> iteratorA = _ref.getLocations().iterator(); |
| int maxTime = -1; |
| while(iteratorA.hasNext()) { |
| String id = iteratorA.next(); |
| |
| if(_data.containsLocation(id)) { |
| // get the lists of data only for those locations that are common to both maps ReferenceScenarioDataInstance dataMapA = mapA.getLocation(id); |
| ReferenceScenarioDataInstance dataMapA = _ref.getLocation(id); |
| List<Double> dataAI = null; |
| if(FIT_INCIDENCE) dataAI = getIncidence(dataMapA); |
| else dataAI = getInfectious(dataMapA); |
| List<Double> dataAP = getPopulation(dataMapA); |
| commonInfectiousLocationsA.put(id,dataAI); |
| commonPopulationLocationsA.put(id, dataAP); |
| |
| // Map B |
| ReferenceScenarioDataInstance dataMapB = _data.getLocation(id); |
| List<Double> dataBI = null; |
| if(FIT_INCIDENCE) dataBI = getIncidence(dataMapB); |
| else dataBI = getInfectious(dataMapB); |
| List<Double> dataBP = getPopulation(dataMapB); |
| commonInfectiousLocationsB.put(id,dataBI); |
| commonPopulationLocationsB.put(id, dataBP); |
| |
| // init the array size |
| if (maxTime == -1) maxTime = dataAI.size(); |
| |
| // dimension the arrays to the length of the SMALLEST array for which we have data |
| if(maxTime >= dataBI.size() ) maxTime = dataBI.size(); |
| if(maxTime >= dataAI.size() ) maxTime = dataAI.size(); |
| }// if |
| }// while |
| if(maxTime<=0) maxTime = 0; |
| if(time==null) { |
| time = new double[maxTime]; |
| meanSqDiff = new double[maxTime]; |
| locationCount = new double[maxTime]; |
| for(int i = 0; i < maxTime; i ++) { |
| time[i] = i; |
| meanSqDiff[i] = 0.0; |
| locationCount[i] = 0.0; |
| } |
| } |
| |
| // Now figure out the actual error |
| |
| double [] Xref = new double[time.length]; |
| double [] Xdata = new double[time.length]; |
| |
| double finalerror = 0.0; |
| double verror = 0.0; |
| |
| BasicEList<Double> list = new BasicEList<Double>(); |
| for(int i=0;i<time.length;++i)list.add(0.0); |
| |
| // Get the average population for each location |
| for(String loc:commonPopulationLocationsA.keySet()) { |
| List<Double>ld = commonPopulationLocationsA.get(loc); |
| double sum = 0;for(double d:ld)sum+=d; |
| sum /= (double)ld.size(); |
| commonAvgPopulationLocationsA.put(loc, sum); |
| } |
| |
| // Get the average population for each location |
| for(String loc:commonPopulationLocationsB.keySet()) { |
| List<Double>ld = commonPopulationLocationsB.get(loc); |
| double sum = 0;for(double d:ld)sum+=d; |
| sum /= (double)ld.size(); |
| commonAvgPopulationLocationsB.put(loc, sum); |
| } |
| |
| // Get the maximum value for the A series (reference) |
| for(String loc:commonPopulationLocationsA.keySet()) { |
| List<Double>ld = commonInfectiousLocationsA.get(loc); |
| double max = Double.MIN_VALUE; |
| for(double d:ld)if(d >max)max=d; |
| commonMaxLocationsA.put(loc, max); |
| } |
| |
| // Calculate the normalized root mean square error for each location, then |
| // divide by the number of locatins |
| |
| double weighted_denom = 0.0; |
| |
| if(!AGGREGATE_NRMSE) { // Use NRMSE per location first |
| for(String loc:commonInfectiousLocationsA.keySet()) { |
| double maxRef = 0.0; |
| double minRef = Double.MAX_VALUE; |
| // Get the numbers at each time step for the location |
| for(int icount =0; icount < time.length; icount ++) { |
| List<Double> dataAI = commonInfectiousLocationsA.get(loc); |
| List<Double> dataBI = commonInfectiousLocationsB.get(loc); |
| |
| double iA = dataAI.get(icount).doubleValue(); |
| double iB = dataBI.get(icount).doubleValue(); |
| |
| Xref[icount]=iA; |
| Xdata[icount]=iB; |
| } |
| |
| double nominator = 0.0; |
| double timesteps = 0; |
| for(int icount =0; icount < time.length; icount ++) { |
| if(Xref[icount]>maxRef)maxRef = Xref[icount]; |
| if(Xref[icount]<minRef)minRef = Xref[icount]; |
| |
| // If we use the threshold and both the reference and the model is less than |
| // the THRESHOLD*MAXref(loc) we don't measure the data point |
| |
| if(USE_THRESHOLD && (Xref[icount]<=THRESHOLD*commonMaxLocationsA.get(loc) && |
| Xdata[icount]<=THRESHOLD*commonMaxLocationsA.get(loc))) continue; |
| |
| nominator = nominator + Math.pow(Xref[icount]-Xdata[icount], 2); |
| list.set(icount, list.get(icount)+Math.abs(Xref[icount]-Xdata[icount])); |
| ++timesteps; |
| } |
| double error = Double.MAX_VALUE; |
| if(timesteps > 0 && maxRef-minRef > 0.0) { |
| error = Math.sqrt(nominator/timesteps); |
| error = error / (maxRef-minRef); |
| if(WEIGHTED_AVERAGE) finalerror += commonAvgPopulationLocationsA.get(loc) * error; |
| else finalerror += error; |
| if(WEIGHTED_AVERAGE) weighted_denom += commonAvgPopulationLocationsA.get(loc); |
| else weighted_denom += 1.0; |
| } |
| |
| } |
| |
| // Divide the error by the number of locations |
| finalerror /= weighted_denom; |
| } else { // Aggregate signal, then calculate NRMSE |
| for(int icount =0; icount < time.length; icount ++) { |
| for(String loc:commonInfectiousLocationsA.keySet()) { |
| List<Double> dataAI = commonInfectiousLocationsA.get(loc); |
| List<Double> dataBI = commonInfectiousLocationsB.get(loc); |
| |
| double iA = dataAI.get(icount).doubleValue(); |
| double iB = dataBI.get(icount).doubleValue(); |
| |
| Xref[icount]+=iA; |
| Xdata[icount]+=iB; |
| } |
| } |
| |
| double maxRef = Double.MIN_VALUE; |
| double minRef = Double.MAX_VALUE; |
| double maxValidationRef = Double.MIN_VALUE; |
| double minValidationRef = Double.MAX_VALUE; |
| |
| for(int icount =0; icount < time.length; icount ++) { |
| if(icount >= validationYear*365.25 && icount <= (validationYear+1)*365.25) { |
| if(Xref[icount]>maxValidationRef)maxValidationRef = Xref[icount]; |
| if(Xref[icount]<minValidationRef)minValidationRef = Xref[icount]; |
| continue; |
| } |
| if(Xref[icount]>maxRef)maxRef = Xref[icount]; |
| if(Xref[icount]<minRef)minRef = Xref[icount]; |
| } |
| double nominator = 0.0, vnominator = 0.0; |
| double timesteps = 0.0, vtimesteps = 0.0; |
| for(int icount =0; icount < time.length; icount ++) { |
| |
| // Calculate validation error then skip |
| if(icount >= validationYear*365.25 && icount <= (validationYear+1)*365.25) { |
| if(USE_THRESHOLD && (Xref[icount]<=THRESHOLD*maxValidationRef && |
| Xdata[icount]<=THRESHOLD*maxValidationRef)) continue; |
| |
| vnominator = vnominator + Math.pow(Xref[icount]-Xdata[icount], 2); |
| list.set(icount, new Double(0)); // Set to 0 for validation data points |
| ++vtimesteps; |
| continue; |
| } |
| // If we use the threshold and both the reference and the model is less than |
| // the THRESHOLD*MAXref(loc) we don't measure the data point |
| |
| if(USE_THRESHOLD && (Xref[icount]<=THRESHOLD*maxRef && |
| Xdata[icount]<=THRESHOLD*maxRef)) continue; |
| |
| nominator = nominator + Math.pow(Xref[icount]-Xdata[icount], 2); |
| list.set(icount, Math.abs(Xref[icount]-Xdata[icount])); |
| ++timesteps; |
| } |
| |
| double error = Double.MAX_VALUE; |
| if(timesteps > 0 && maxRef-minRef > 0.0) { |
| error = Math.sqrt(nominator/timesteps); |
| finalerror = error / (maxRef-minRef); |
| } |
| // Validation |
| error = Double.MAX_VALUE; |
| if(vtimesteps > 0 && maxValidationRef-minValidationRef > 0.0) { |
| error = Math.sqrt(vnominator/vtimesteps); |
| verror = error / (maxValidationRef-minValidationRef); |
| } |
| } // else |
| ErrorResult resultobj = aFactory.createErrorResult(); |
| resultobj.setErrorByTimeStep(list); |
| resultobj.setError(finalerror); |
| resultobj.setValidationError(verror); |
| |
| EList<Double>refByTime = new BasicEList<Double>(); |
| EList<Double>dataByTime = new BasicEList<Double>(); |
| |
| // Set the reference and model by time |
| for(int icount=0;icount<time.length;++icount) { |
| refByTime.add(0.0);dataByTime.add(0.0);} |
| for(String loc:commonInfectiousLocationsA.keySet()) { |
| for(int icount =0; icount < time.length; icount ++) { |
| List<Double> dataAI = commonInfectiousLocationsA.get(loc); |
| List<Double> dataBI = commonInfectiousLocationsB.get(loc); |
| |
| double iA = dataAI.get(icount).doubleValue(); |
| double iB = dataBI.get(icount).doubleValue(); |
| |
| refByTime.set(icount, refByTime.get(icount)+iA); |
| dataByTime.set(icount, dataByTime.get(icount)+iB); |
| } |
| } |
| resultobj.setReferenceByTime(refByTime); |
| resultobj.setModelByTime(dataByTime); |
| return resultobj; |
| } |
| } //SimpleErrorFunctionImpl |