blob: 3273473fd08c2af58165db35f632aa30f21afe22 [file] [log] [blame]
/*******************************************************************************
* Copyright (c) 2008, 2010 VMware Inc.
* All rights reserved. This program and the accompanying materials
* are made available under the terms of the Eclipse Public License v1.0
* which accompanies this distribution, and is available at
* http://www.eclipse.org/legal/epl-v10.html
*
* Contributors:
* VMware Inc. - initial contribution
*******************************************************************************/
package org.eclipse.virgo.kernel.userregion.internal.equinox;
import java.io.IOException;
import java.lang.instrument.ClassFileTransformer;
import java.lang.instrument.IllegalClassFormatException;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.net.URL;
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.security.ProtectionDomain;
import java.sql.Driver;
import java.sql.DriverManager;
import java.util.ArrayList;
import java.util.Enumeration;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.Vector;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.CopyOnWriteArrayList;
import org.eclipse.osgi.baseadaptor.BaseData;
import org.eclipse.osgi.baseadaptor.bundlefile.BundleEntry;
import org.eclipse.osgi.baseadaptor.loader.ClasspathEntry;
import org.eclipse.osgi.baseadaptor.loader.ClasspathManager;
import org.eclipse.osgi.framework.adaptor.ClassLoaderDelegate;
import org.eclipse.osgi.internal.baseadaptor.DefaultClassLoader;
import org.eclipse.osgi.service.resolver.PlatformAdmin;
import org.osgi.framework.Bundle;
import org.osgi.framework.BundleContext;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.eclipse.virgo.kernel.osgi.framework.ExtendedClassNotFoundException;
import org.eclipse.virgo.kernel.osgi.framework.ExtendedNoClassDefFoundError;
import org.eclipse.virgo.kernel.osgi.framework.InstrumentableClassLoader;
import org.eclipse.virgo.kernel.osgi.framework.OsgiFrameworkUtils;
/**
* Extension to {@link DefaultClassLoader} that adds instrumentation support.
* <p/>
*
* <strong>Concurrent Semantics</strong><br />
*
* As threadsafe as <code>DefaultClassLoader</code>.
*
*/
public final class KernelBundleClassLoader extends DefaultClassLoader implements InstrumentableClassLoader {

    static {
        // Reflectively invoke ClassLoader.registerAsParallelCapable(); presumably reflection is used
        // because the method is not available on every supported JVM -- TODO confirm.
        try {
            Method parallelCapableMethod = ClassLoader.class.getDeclaredMethod("registerAsParallelCapable", (Class[]) null);
            parallelCapableMethod.setAccessible(true);
            parallelCapableMethod.invoke(null, new Object[0]);
        } catch (Throwable ignored) {
            // must avoid failing in clinit
        }
    }

    /** Class name prefixes that are never passed to class file transformers. */
    private static final String[] EXCLUDED_PACKAGES = new String[] { "java.", "javax.", "sun.", "oracle." };

    /** Manifest header holding a comma-separated list of class name prefixes to instrument. */
    private static final String HEADER_INSTRUMENT_PACKAGE = "Instrument-Package";

    private static final Logger LOGGER = LoggerFactory.getLogger(KernelBundleClassLoader.class);

    /** Declared with the concrete type so that the atomic <code>addIfAbsent</code> can be used. */
    private final CopyOnWriteArrayList<ClassFileTransformer> classFileTransformers = new CopyOnWriteArrayList<ClassFileTransformer>();

    private final String[] instrumentedPackages;

    private final String[] classpath;

    private final String bundleScope;

    /** JDBC driver classes seen by this class loader; guarded by {@link #monitor}. */
    private final Set<Class<Driver>> loadedDriverClasses = new HashSet<Class<Driver>>();

    private final Object monitor = new Object();

    private volatile boolean instrumented;

    /**
     * Constructs a new <code>KernelBundleClassLoader</code>.
     *
     * @param parent the parent <code>ClassLoader</code>.
     * @param delegate the delegate for this <code>ClassLoader</code>
     * @param domain the domain for this <code>ClassLoader</code>
     * @param bundledata the bundledata for this <code>ClassLoader</code>
     * @param classpath the classpath for this <code>ClassLoader</code>
     */
    KernelBundleClassLoader(ClassLoader parent, ClassLoaderDelegate delegate, ProtectionDomain domain, BaseData bundledata, String[] classpath) {
        super(parent, delegate, domain, bundledata, classpath);
        this.classpath = classpath;
        this.bundleScope = OsgiFrameworkUtils.getScopeName(bundledata.getBundle());
        this.instrumentedPackages = findInstrumentedPackages(bundledata.getBundle());
    }

    /**
     * {@inheritDoc}
     * <p/>
     * Marks this class loader as instrumented and registers the transformer unless it is already registered. A newly
     * registered transformer is also added to the class loaders of the direct dependency bundles that are in the same
     * scope as this bundle. Because an already registered transformer is not propagated again, cyclic dependencies do
     * not recurse forever.
     */
    public void addClassFileTransformer(ClassFileTransformer transformer) {
        this.instrumented = true;
        // addIfAbsent performs the "check then add" atomically, so no external lock is needed.
        if (!this.classFileTransformers.addIfAbsent(transformer)) {
            return;
        }
        for (Bundle dependency : getDependencyBundles(false)) {
            if (propagateInstrumentationTo(dependency)) {
                ClassLoader dependencyClassLoader = getBundleClassLoader(dependency);
                if (dependencyClassLoader instanceof KernelBundleClassLoader) {
                    ((KernelBundleClassLoader) dependencyClassLoader).addClassFileTransformer(transformer);
                }
            }
        }
    }

    /**
     * Looks up the class loader of the supplied bundle via {@link EquinoxUtils}.
     *
     * @param bundle the bundle whose class loader is required
     * @return whatever <code>EquinoxUtils.getBundleClassLoader</code> returns for the bundle
     */
    private ClassLoader getBundleClassLoader(Bundle bundle) {
        return EquinoxUtils.getBundleClassLoader(bundle);
    }

    /**
     * {@inheritDoc}
     * <p/>
     * Remembers the loaded class if it is a JDBC {@link Driver} and rethrows load failures as their
     * <code>Extended*</code> counterparts, which are constructed with this class loader.
     */
    @Override
    protected Class<?> loadClass(String name, boolean resolve) throws ClassNotFoundException {
        try {
            Class<?> loadedClass = super.loadClass(name, resolve);
            storeClassIfDriver(loadedClass);
            return loadedClass;
        } catch (ClassNotFoundException e) {
            throw new ExtendedClassNotFoundException(this, e);
        } catch (NoClassDefFoundError e) {
            throw new ExtendedNoClassDefFoundError(this, e);
        }
    }

    /**
     * {@inheritDoc}
     */
    public boolean isInstrumented() {
        return this.instrumented;
    }

    /**
     * {@inheritDoc}
     */
    public int getClassFileTransformerCount() {
        return this.classFileTransformers.size();
    }

    /**
     * Finds the explicit list of packages to include in instrumentation (if specified).
     * <p/>
     * The header value is split on commas and each entry is trimmed. Empty entries (for example from
     * <code>"a,,b"</code>) are dropped: an empty prefix would match every class name and thereby silently disable the
     * filter.
     *
     * @param bundle the bundle whose headers are inspected
     * @return the configured prefixes, or an empty array if none are configured
     */
    private String[] findInstrumentedPackages(Bundle bundle) {
        String headerValue = (String) bundle.getHeaders().get(HEADER_INSTRUMENT_PACKAGE);
        if (headerValue == null || headerValue.length() == 0) {
            return new String[0];
        }
        List<String> packageNames = new ArrayList<String>();
        for (String token : headerValue.split(",")) {
            String packageName = token.trim();
            if (packageName.length() > 0) {
                packageNames.add(packageName);
            }
        }
        return packageNames.toArray(new String[packageNames.size()]);
    }

    /**
     * Determines whether transformers added to this class loader should also be added to the supplied bundle's class
     * loader: the bundle must not be the system bundle and must have the same (non-null) scope name as this bundle.
     */
    private boolean propagateInstrumentationTo(Bundle bundle) {
        return !EquinoxUtils.isSystemBundle(bundle) && this.bundleScope != null && this.bundleScope.equals(OsgiFrameworkUtils.getScopeName(bundle));
    }

    /**
     * A class is instrumented when it passes the include filter and is not in one of the {@link #EXCLUDED_PACKAGES}.
     */
    private boolean shouldInstrument(String className) {
        return includedForInstrumentation(className) && !excludedFromInstrumentation(className);
    }

    /**
     * Every class is included when no packages were configured; otherwise the class name must start with one of the
     * configured values. This is a plain prefix match, so <code>com.foo</code> also matches
     * <code>com.foobar.X</code>.
     */
    private boolean includedForInstrumentation(String className) {
        if (this.instrumentedPackages.length == 0) {
            return true;
        }
        for (String instrumentedPackage : this.instrumentedPackages) {
            if (className.startsWith(instrumentedPackage)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Checks the class name against the fixed list of excluded prefixes.
     */
    private boolean excludedFromInstrumentation(String className) {
        for (String excludedPackage : EXCLUDED_PACKAGES) {
            if (className.startsWith(excludedPackage)) {
                return true;
            }
        }
        return false;
    }

    /**
     * {@inheritDoc}
     * <p/>
     * The throw-away class loader gets its own, freshly initialized {@link ClasspathManager} built from this class
     * loader's bundle data and classpath.
     */
    public ThrowAwayClassLoader createThrowAway() {
        final ClasspathManager manager = new ClasspathManager(this.manager.getBaseData(), this.classpath, this);
        manager.initialize();
        return AccessController.doPrivileged(new PrivilegedAction<ThrowAwayClassLoader>() {

            public ThrowAwayClassLoader run() {
                return new ThrowAwayClassLoader(manager);
            }
        });
    }

    /**
     * {@inheritDoc}
     * <p/>
     * If the class is eligible for instrumentation its bytes are passed through every registered transformer in
     * registration order; a transformer returning <code>null</code> leaves the bytes unchanged. The defined class is
     * remembered if it is a JDBC {@link Driver}.
     *
     * @throws ClassFormatError if a transformer rejects the class bytes; the transformer's exception is the cause
     */
    @Override
    public Class<?> defineClass(String name, byte[] classbytes, ClasspathEntry classpathEntry, BundleEntry entry) {
        byte[] transformedBytes = classbytes;
        if (shouldInstrument(name)) {
            // Transformers expect the internal (slash separated) form; compute it once for all of them.
            String internalName = name.replace('.', '/');
            for (ClassFileTransformer transformer : this.classFileTransformers) {
                try {
                    byte[] result = transformer.transform(this, internalName, null, this.domain, transformedBytes);
                    if (result != null) {
                        transformedBytes = result;
                    }
                } catch (IllegalClassFormatException e) {
                    ClassFormatError error = new ClassFormatError("Error reading class from bundle entry '" + entry.getName() + "'. "
                        + e.getMessage());
                    // ClassFormatError has no cause-taking constructor, so attach the cause explicitly.
                    error.initCause(e);
                    throw error;
                }
            }
        }
        try {
            Class<?> definedClass = super.defineClass(name, transformedBytes, classpathEntry, entry);
            storeClassIfDriver(definedClass);
            return definedClass;
        } catch (NoClassDefFoundError e) {
            throw new ExtendedNoClassDefFoundError(this, e);
        }
    }

    /**
     * Records the class if it is a JDBC {@link Driver} implementation so that {@link #close()} can later remove the
     * corresponding entries from {@link DriverManager}.
     *
     * @param candidateClass the class that has just been loaded or defined
     */
    @SuppressWarnings("unchecked")
    private void storeClassIfDriver(Class<?> candidateClass) {
        if (Driver.class.isAssignableFrom(candidateClass)) {
            synchronized (this.monitor) {
                this.loadedDriverClasses.add((Class<Driver>) candidateClass);
            }
        }
    }

    /**
     * Removes the JDBC drivers seen by this class loader from {@link DriverManager} and then runs the inherited close
     * logic.
     */
    @Override
    public void close() {
        try {
            clearJdbcDrivers();
        } finally {
            // NOTE(review): the inherited close is expected to release the ClasspathManager's resources; it was
            // previously skipped altogether -- confirm against DefaultClassLoader.
            super.close();
        }
    }

    /**
     * Removes every driver whose class was recorded by {@link #storeClassIfDriver(Class)} from the private state of
     * {@link DriverManager}. The Java 6 field layout is tried first; if that fails for any reason the Java 7 layout is
     * tried. A failure of both strategies is logged and otherwise ignored.
     */
    private void clearJdbcDrivers() {
        Set<Class<Driver>> driverClasses;
        synchronized (this.monitor) {
            driverClasses = new HashSet<Class<Driver>>(this.loadedDriverClasses);
        }
        if (driverClasses.isEmpty()) {
            // Nothing to remove: avoid reflecting into DriverManager (and a spurious warning where that fails).
            return;
        }
        synchronized (DriverManager.class) {
            try {
                clearJdbcDriversJava6(driverClasses);
                LOGGER.debug("Cleared JDBC drivers for {} using Java 6 strategy", this);
            } catch (Exception javaSixWayFailed) {
                try {
                    clearJdbcDriversJava7(driverClasses);
                    LOGGER.debug("Cleared JDBC drivers for {} using Java 7 strategy", this);
                } catch (Exception e) {
                    LOGGER.warn("Failure when clearing JDBC drivers for " + this, e);
                }
            }
        }
    }

    /**
     * Java 6 strategy: removes matching entries from the <code>writeDrivers</code> vector and replaces
     * <code>readDrivers</code> with a clone of the result.
     */
    private static void clearJdbcDriversJava6(Set<Class<Driver>> driverClasses) throws Exception {
        Field writeDriversField = DriverManager.class.getDeclaredField("writeDrivers");
        writeDriversField.setAccessible(true);
        Vector<?> writeDrivers = (Vector<?>) writeDriversField.get(null);
        Iterator<?> driverElements = writeDrivers.iterator();
        while (driverElements.hasNext()) {
            if (wrapsDriverOf(driverElements.next(), driverClasses)) {
                driverElements.remove();
            }
        }
        Vector<?> readDrivers = (Vector<?>) writeDrivers.clone();
        Field readDriversField = DriverManager.class.getDeclaredField("readDrivers");
        readDriversField.setAccessible(true);
        readDriversField.set(null, readDrivers);
    }

    /**
     * Java 7 strategy: removes matching entries from the <code>registeredDrivers</code> list.
     */
    private static void clearJdbcDriversJava7(Set<Class<Driver>> driverClasses) throws Exception {
        Field registeredDriversField = DriverManager.class.getDeclaredField("registeredDrivers");
        registeredDriversField.setAccessible(true);
        CopyOnWriteArrayList<?> registeredDrivers = (CopyOnWriteArrayList<?>) registeredDriversField.get(null);
        // Iterators of CopyOnWriteArrayList do not support remove(), so collect first and remove afterwards.
        List<Object> driverElementsToRemove = new ArrayList<Object>();
        for (Object driverObj : registeredDrivers) {
            if (wrapsDriverOf(driverObj, driverClasses)) {
                driverElementsToRemove.add(driverObj);
            }
        }
        for (Object driverObj : driverElementsToRemove) {
            registeredDrivers.remove(driverObj);
        }
    }

    /**
     * Reads the private <code>driver</code> field of a <code>DriverManager</code> registration entry and reports
     * whether that driver's class is one of the supplied classes.
     */
    private static boolean wrapsDriverOf(Object driverObj, Set<Class<Driver>> driverClasses) throws Exception {
        Field driverField = driverObj.getClass().getDeclaredField("driver");
        driverField.setAccessible(true);
        return driverClasses.contains(driverField.get(driverObj).getClass());
    }

    @Override
    public String toString() {
        return String.format("%s: [bundle=%s]", getClass().getSimpleName(), this.delegate);
    }

    /**
     * Gets the direct dependencies of this class loader's bundle as computed by
     * <code>EquinoxUtils.getDirectDependencies</code>.
     *
     * @param includeDependenciesFragments passed through to <code>EquinoxUtils.getDirectDependencies</code>
     */
    private Bundle[] getDependencyBundles(boolean includeDependenciesFragments) {
        Bundle bundle = this.manager.getBaseData().getBundle();
        return EquinoxUtils.getDirectDependencies(bundle, getBundleContext(), getPlatformAdmin(), includeDependenciesFragments);
    }

    /**
     * Gets the {@link BundleContext} exposed by the adaptor of this ClassLoader's bundle data.
     *
     * @return the <code>BundleContext</code>.
     */
    private BundleContext getBundleContext() {
        return this.manager.getBaseData().getAdaptor().getContext();
    }

    /**
     * Gets the {@link PlatformAdmin} service.
     *
     * @return the <code>PlatformAdmin</code> service.
     */
    private PlatformAdmin getPlatformAdmin() {
        return this.manager.getBaseData().getAdaptor().getPlatformAdmin();
    }

    /**
     * Throwaway classloader for OSGi bundles.
     * <p/>
     *
     * <strong>Concurrent Semantics</strong><br />
     *
     * As threadsafe as {@link ClassLoader}.
     *
     */
    final class ThrowAwayClassLoader extends ClassLoader {

        /** Classes already handed out by this throw-away class loader, keyed by class name. */
        private final ConcurrentMap<String, Class<?>> loadedClasses = new ConcurrentHashMap<String, Class<?>>();

        private final ClasspathManager manager;

        /**
         * @param manager the classpath manager used to find class entries local to the bundle
         */
        private ThrowAwayClassLoader(ClasspathManager manager) {
            this.manager = manager;
        }

        /**
         * {@inheritDoc}
         * <p/>
         * Classes that are not eligible for instrumentation are loaded by the enclosing class loader. For all other
         * classes the lookup order is: a class already loaded by the enclosing class loader, a class already handed
         * out by this class loader, a class defined by this class loader from the bundle or its dependencies, and
         * finally a regular load through the enclosing class loader.
         */
        @Override
        public Class<?> loadClass(String name, boolean resolve) throws ClassNotFoundException {
            if (!shouldInstrument(name)) {
                return KernelBundleClassLoader.this.loadClass(name, resolve);
            }
            Class<?> cls = KernelBundleClassLoader.this.findLoadedClass(name);
            if (cls == null) {
                cls = this.loadedClasses.get(name);
            }
            if (cls == null) {
                cls = findClassInternal(name, true);
            }
            if (cls == null) {
                cls = KernelBundleClassLoader.this.loadClass(name, resolve);
            }
            if (cls == null) {
                throw new ClassNotFoundException(name);
            }
            if (resolve) {
                resolveClass(cls);
            }
            // TODO review findbugs warning vs. failing tests in LoadTimeWeavingTests
            this.loadedClasses.putIfAbsent(name, cls);
            return cls;
        }

        /**
         * Attempts to find a <code>Class</code> from the enclosing bundle.
         *
         * @param name the name of the <code>Class</code>
         * @param traverseDependencies should dependency bundles be checked for the class.
         * @return the class, or <code>null</code> if no entry exists for it or its bytes cannot be read
         */
        Class<?> findClassInternal(String name, boolean traverseDependencies) {
            String path = name.replace('.', '/').concat(".class");
            BundleEntry entry = this.manager.findLocalEntry(path);
            if (entry == null) {
                return traverseDependencies ? findClassFromImport(name) : null;
            }
            byte[] bytes;
            try {
                bytes = entry.getBytes();
            } catch (IOException e) {
                // An unreadable entry is treated like a missing class, but leave a trace for diagnosis.
                LOGGER.debug("Unable to read bundle entry '" + path + "'", e);
                bytes = null;
            }
            return bytes == null ? null : defineClass(name, bytes, 0, bytes.length);
        }

        /**
         * Attempts to locate a <code>Class</code> from one of the imported bundles.
         * <p/>
         * For each direct dependency with a <code>KernelBundleClassLoader</code>, a class already loaded by that
         * class loader is preferred; otherwise a fresh throw-away class loader of the dependency is asked to define
         * the class from the dependency's own entries.
         *
         * @param name the <code>Class</code> name.
         * @return the located <code>Class</code>, or <code>null</code> if no <code>Class</code> can be found.
         */
        private Class<?> findClassFromImport(String name) {
            for (Bundle dep : getDependencyBundles(false)) {
                ClassLoader depClassLoader = getBundleClassLoader(dep);
                if (!(depClassLoader instanceof KernelBundleClassLoader)) {
                    continue;
                }
                KernelBundleClassLoader pbcl = (KernelBundleClassLoader) depClassLoader;
                Class<?> loadedClass = pbcl.publicFindLoaded(name);
                if (loadedClass != null) {
                    return loadedClass;
                }
                Class<?> cls = pbcl.createThrowAway().findClassInternal(name, false);
                if (cls != null) {
                    return cls;
                }
            }
            return null;
        }

        /**
         * {@inheritDoc}
         */
        @Override
        public URL getResource(String name) {
            return KernelBundleClassLoader.this.getResource(name);
        }

        /**
         * {@inheritDoc}
         */
        @Override
        public Enumeration<URL> getResources(String name) throws IOException {
            return KernelBundleClassLoader.this.getResources(name);
        }
    }
}