// blob: 16c2135c67070ba19086aadc1ee841944619c750
/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.openejb.resource.jdbc;
import org.apache.openejb.jee.EjbJar;
import org.apache.openejb.jee.SingletonBean;
import org.apache.openejb.junit.ApplicationComposer;
import org.apache.openejb.testing.Configuration;
import org.apache.openejb.testing.Module;
import org.apache.openejb.resource.jdbc.managed.local.ManagedConnection;
import org.junit.After;
import org.junit.BeforeClass;
import org.junit.Test;
import org.junit.runner.RunWith;
import javax.annotation.Resource;
import javax.ejb.EJB;
import javax.ejb.EJBContext;
import javax.ejb.LocalBean;
import javax.ejb.Singleton;
import javax.sql.DataSource;
import javax.transaction.Transaction;
import java.lang.reflect.Field;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
/**
 * Hammers a JTA-managed {@link DataSource} from a pool of worker threads and then
 * checks, through plain {@link DriverManager} connections, which rows ended up in
 * the table.
 *
 * <p>The table is created once per class and never emptied, so every test compares
 * row counts against a baseline taken at its own start rather than against an
 * absolute value.</p>
 *
 * <p>After each test {@link #checkTxMapIsEmpty()} inspects the static
 * {@code CONNECTION_BY_TX_BY_DS} map of {@link ManagedConnection} to make sure no
 * per-transaction entry is left behind.</p>
 */
@RunWith(ApplicationComposer.class)
public class MultiThreadedManagedDataSourceTest {
    private static final String URL = "jdbc:hsqldb:mem:multi-tx-managed;hsqldb.tx=MVCC"; // mvcc otherwise multiple transaction tests will fail
    private static final String USER = "sa";
    private static final String PASSWORD = "";
    private static final String TABLE = "PUBLIC.MULTI_TX_MANAGED_DATASOURCE_TEST";

    /** Number of tasks submitted by {@link #run(Runnable)}. */
    private static final int INSERTS_NB = 200;

    /** Size of the worker pool used by {@link #run(Runnable)}. */
    private static final int THREADS_NB = 20;

    @EJB
    private Persister persistManager;

    /**
     * Loads the HSQLDB driver and creates the test table through a direct
     * (non-managed) connection.
     *
     * @throws SQLException           if the connection or the DDL statement fails
     * @throws ClassNotFoundException if the HSQLDB driver is not on the classpath
     */
    @BeforeClass
    public static void createTable() throws SQLException, ClassNotFoundException {
        Class.forName("org.hsqldb.jdbcDriver");
        final Connection connection = DriverManager.getConnection(URL, USER, PASSWORD);
        try {
            final Statement statement = connection.createStatement();
            try {
                statement.execute("CREATE TABLE " + TABLE + "(ID INTEGER)");
            } finally {
                statement.close();
            }
            connection.commit();
        } finally {
            // always release the connection, even when the DDL or the commit fails
            connection.close();
        }
    }

    /**
     * Container configuration: selects the {@code dbcp-alternative} data source
     * creator and declares a {@code DataSource} resource named {@code managed}
     * pointing at {@link #URL} with {@code JtaManaged=true}.
     *
     * @return the properties handed to the container
     */
    @Configuration
    public Properties config() {
        final Properties p = new Properties();
        p.put("openejb.jdbc.datasource-creator", "dbcp-alternative");
        p.put("managed", "new://Resource?type=DataSource");
        p.put("managed.JdbcDriver", "org.hsqldb.jdbcDriver");
        p.put("managed.JdbcUrl", URL);
        p.put("managed.UserName", USER);
        p.put("managed.Password", PASSWORD);
        p.put("managed.JtaManaged", "true");
        return p;
    }

    /**
     * Describes the module under test: an {@link EjbJar} containing only the
     * {@link Persister} singleton, exposed as a local bean.
     *
     * @return the EJB module descriptor
     * @throws Exception declared for signature compatibility; nothing here throws
     */
    @Module
    public EjbJar app() throws Exception {
        return new EjbJar()
                .enterpriseBean(new SingletonBean(Persister.class).localBean());
    }

    /**
     * Singleton bean that inserts one row per call through the injected
     * {@code managed} data source. Ids come from a static counter so they are
     * unique across all calls made in the same JVM.
     */
    @LocalBean
    @Singleton
    public static class Persister {
        private static final AtomicInteger ID = new AtomicInteger(1);

        @Resource(name = "managed")
        private DataSource ds;

        @Resource
        private EJBContext context;

        /**
         * Inserts a row carrying a fresh id.
         *
         * @return the id that was inserted
         * @throws SQLException if the insert fails
         */
        public int save() throws SQLException {
            final int id = ID.getAndIncrement();
            MultiThreadedManagedDataSourceTest.save(ds, id);
            return id;
        }

        /**
         * Inserts a row carrying a fresh id and, when {@code ok} is {@code false},
         * marks the current transaction rollback-only afterwards.
         *
         * @param ok {@code true} to leave the transaction untouched, {@code false}
         *           to call {@link EJBContext#setRollbackOnly()}
         * @return the id that was used for the insert
         * @throws SQLException if the insert fails
         */
        public int saveRollback(boolean ok) throws SQLException {
            final int id = ID.getAndIncrement();
            MultiThreadedManagedDataSourceTest.save(ds, id);
            if (!ok) {
                context.setRollbackOnly();
            }
            return id;
        }
    }

    /**
     * Runs {@link #INSERTS_NB} concurrent {@link Persister#save()} calls and checks
     * that each inserted id is visible afterwards and that the table grew by
     * exactly {@link #INSERTS_NB} rows.
     *
     * @throws SQLException if counting rows fails
     */
    @Test
    public void inserts() throws SQLException {
        // Baseline: the table is shared with the other tests of this class and is
        // never emptied, so an absolute count would depend on test execution order.
        final int initialCount = count("");

        final AtomicInteger errors = new AtomicInteger(0);
        final AtomicInteger missing = new AtomicInteger(0);
        final List<Exception> ex = new CopyOnWriteArrayList<Exception>();

        run(new Runnable() {
            @Override
            public void run() {
                try {
                    final int id = persistManager.save();
                    // only look the row up when the save itself succeeded
                    if (!exists(id)) {
                        missing.incrementAndGet();
                    }
                } catch (SQLException e) {
                    errors.incrementAndGet();
                    ex.add(e);
                }
            }
        });

        printAll(ex);
        assertEquals(0, errors.get());
        assertEquals(0, missing.get());
        assertEquals(INSERTS_NB, count("") - initialCount);
    }

    /**
     * Runs {@link #INSERTS_NB} concurrent {@link Persister#saveRollback(boolean)}
     * calls, randomly asking about half of them to roll back, and checks that the
     * table grew by exactly the number of calls that were not asked to roll back.
     *
     * @throws SQLException if counting rows fails
     */
    @Test
    public void insertsWithRollback() throws SQLException {
        final int initialCount = count("");

        final AtomicInteger errors = new AtomicInteger(0);
        final AtomicInteger missing = new AtomicInteger(0);
        final AtomicInteger ok = new AtomicInteger(0);
        final List<Exception> ex = new CopyOnWriteArrayList<Exception>();

        run(new Runnable() {
            @Override
            public void run() {
                final boolean rollback = Math.random() > 0.5;
                if (!rollback) {
                    ok.incrementAndGet();
                }
                try {
                    final int id = persistManager.saveRollback(!rollback);
                    // a committed row must be visible; rolled back rows are
                    // covered by the final count comparison
                    if (!rollback && !exists(id)) {
                        missing.incrementAndGet();
                    }
                } catch (SQLException e) {
                    errors.incrementAndGet();
                    ex.add(e);
                }
            }
        });

        printAll(ex);
        assertEquals(0, errors.get());
        assertEquals(0, missing.get());
        assertEquals(ok.get(), count("") - initialCount);
    }

    /**
     * Reads the static {@code CONNECTION_BY_TX_BY_DS} field of
     * {@link ManagedConnection} reflectively and asserts that it holds exactly one
     * data source entry whose per-transaction map is empty.
     *
     * @throws Exception if the field cannot be found or read
     */
    @After
    public void checkTxMapIsEmpty() throws Exception { // avoid memory leak
        final Field map = ManagedConnection.class.getDeclaredField("CONNECTION_BY_TX_BY_DS");
        map.setAccessible(true);
        // Read with a null receiver, i.e. as a static field. The generic cast cannot
        // be verified at runtime, hence the narrowly scoped suppression.
        @SuppressWarnings("unchecked")
        final Map<DataSource, Map<Transaction, Connection>> instance =
                (Map<DataSource, Map<Transaction, Connection>>) map.get(null);
        assertEquals("number of tracked data sources", 1, instance.size());
        assertEquals("connections still bound to a transaction", 0, instance.values().iterator().next().size());
    }

    /**
     * @param id the id to look for
     * @return {@code true} when exactly one row with that id is in the table
     * @throws SQLException if the query fails
     */
    private static boolean exists(int id) throws SQLException {
        return count(" WHERE ID = " + id) == 1;
    }

    /**
     * Counts rows of the test table through a fresh direct connection.
     *
     * @param where SQL appended verbatim after the table name, e.g. {@code ""} or
     *              {@code " WHERE ID = 3"}
     * @return the value of {@code count(*)}
     * @throws SQLException if connecting or querying fails
     */
    private static int count(String where) throws SQLException {
        final Connection connection = DriverManager.getConnection(URL, USER, PASSWORD);
        try {
            final Statement statement = connection.createStatement();
            try {
                final ResultSet result = statement.executeQuery("SELECT count(*) AS NB FROM " + TABLE + where);
                try {
                    assertTrue(result.next());
                    return result.getInt(1);
                } finally {
                    result.close();
                }
            } finally {
                statement.close();
            }
        } finally {
            connection.close();
        }
    }

    /**
     * Inserts one row with the given id through {@code ds}.
     *
     * @param ds data source to take the connection from
     * @param id value of the {@code ID} column
     * @throws SQLException if the insert fails
     */
    private static void save(final DataSource ds, int id) throws SQLException {
        execute(ds, "INSERT INTO " + TABLE + "(ID) VALUES(" + id + ")");
    }

    /**
     * Runs an update statement on a connection obtained from {@code ds}; the
     * connection is closed afterwards.
     *
     * @param ds  data source to take the connection from
     * @param sql statement to execute
     * @throws SQLException if obtaining the connection or the statement fails
     */
    private static void execute(final DataSource ds, final String sql) throws SQLException {
        execute(ds.getConnection(), sql);
    }

    /**
     * Runs an update statement on {@code connection} and closes the connection,
     * whether or not the statement succeeded.
     *
     * @param connection connection to use; it is closed by this method
     * @param sql        statement to execute
     * @throws SQLException if the statement fails
     */
    private static void execute(final Connection connection, final String sql) throws SQLException {
        try {
            final Statement statement = connection.createStatement();
            try {
                statement.executeUpdate(sql);
            } finally {
                statement.close();
            }
        } finally {
            connection.close();
        }
    }

    /**
     * Prints the stack trace of every collected failure to {@code System.err} so
     * that the cause is visible next to the failing assertion.
     *
     * @param failures failures gathered by worker threads, possibly empty
     */
    private static void printAll(final List<? extends Throwable> failures) {
        for (Throwable t : failures) {
            t.printStackTrace(System.err);
        }
    }

    /**
     * Submits {@code runnable} {@link #INSERTS_NB} times to a pool of
     * {@link #THREADS_NB} threads and waits up to five minutes for all of them.
     * The calling test fails when the tasks do not finish in time, when the wait
     * is interrupted, or when any task ends with an exception it did not handle.
     *
     * @param runnable the task to execute concurrently
     */
    private void run(final Runnable runnable) {
        // submit() would otherwise park worker failures in a Future nobody reads
        final List<Throwable> unexpected = new CopyOnWriteArrayList<Throwable>();
        final ExecutorService es = Executors.newFixedThreadPool(THREADS_NB);
        try {
            for (int i = 0; i < INSERTS_NB; i++) {
                es.submit(new Runnable() {
                    @Override
                    public void run() {
                        try {
                            runnable.run();
                        } catch (Throwable t) { // includes AssertionError raised by helpers
                            unexpected.add(t);
                        }
                    }
                });
            }
            es.shutdown();
            if (!es.awaitTermination(5, TimeUnit.MINUTES)) {
                fail("worker tasks did not complete within 5 minutes");
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt(); // keep the interrupt visible to the caller
            fail("interrupted while waiting for worker tasks");
        } finally {
            es.shutdownNow(); // no-op once terminated, stops stragglers otherwise
        }

        if (!unexpected.isEmpty()) {
            printAll(unexpected);
            fail(unexpected.size() + " worker task(s) threw an unexpected exception, first: " + unexpected.get(0));
        }
    }
}