blob: dc79b9e950a131864fffbecd84b8f7c3865183ed [file] [log] [blame]
/*******************************************************************************
* Copyright (c) 2013 IBM Corp.
*
* All rights reserved. This program and the accompanying materials
* are made available under the terms of the Eclipse Public License v1.0
* and Eclipse Distribution License v1.0 which accompany this distribution.
*
* The Eclipse Public License is available at
* http://www.eclipse.org/legal/epl-v10.html
* and the Eclipse Distribution License is available at
* http://www.eclipse.org/org/documents/edl-v10.php.
*
* Contributors:
* Takahiro Inaba - initial API and implementation and/or initial
* documentation
*******************************************************************************/
package com.ibm.jmeter.protocol.mqtt;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.commons.codec.binary.Base64;
import org.apache.jmeter.samplers.AbstractSampler;
import org.apache.jmeter.samplers.Entry;
import org.apache.jmeter.samplers.SampleResult;
import org.apache.jmeter.testbeans.TestBean;
import org.apache.jmeter.testelement.TestStateListener;
import org.apache.jorphan.logging.LoggingManager;
import org.apache.log.Logger;
/**
 * JMeter sampler that publishes one MQTT message per sample.
 *
 * <p>Clients are cached per thread, keyed by client id and server URI, so that
 * {@code reuseConnection} can keep a connected client across samples. Every client created
 * by {@link #send(SampleResult)} is also tracked in a process-wide list until it is closed,
 * and all still-tracked clients are closed when the test ends.
 */
public class MqttSampler extends AbstractSampler
        implements TestBean, TestStateListener {
    private static final long serialVersionUID = 300L;
    private static final Logger log = LoggingManager.getLoggerForClass();

    /** Per-thread client cache, keyed by {@link #createCacheKey(String, String)}. */
    private static final ThreadLocal<Map<String, IMqttClient>> tlClientCache = new ThreadLocal<Map<String, IMqttClient>>();

    /**
     * Clients created on any thread that have not yet been discarded by {@code send}.
     * Iteration must hold the list's monitor (see {@link Collections#synchronizedList}).
     */
    private static final List<IMqttClient> allCreatedClients = Collections.synchronizedList(new ArrayList<IMqttClient>());

    private String serverURI;
    private String clientId;
    /** When true, a cached, still-connected client is reused instead of creating a new one. */
    private boolean reuseConnection;
    /** When true, the client is closed and evicted right after publishing. */
    private boolean closeConnection;
    private String topicString;
    private int qos;
    private boolean retained;
    /** When true, {@link #messageBody} is Base64-decoded to obtain the payload bytes. */
    private boolean base64Encoded;
    /** Charset name for the payload; null/empty means the platform default charset. */
    private String characterEncoding;
    private String messageBody;

    public MqttSampler() {
        log.debug("created " + System.identityHashCode(this) + ", " + Thread.currentThread().getId());
    }

    /**
     * Publishes the configured message and reports the outcome.
     *
     * @param entry unused
     * @return a result that is successful only if the publish completed without an exception;
     *         on failure the response message holds the exception's {@code toString()}
     */
    @Override
    public SampleResult sample(Entry entry) {
        SampleResult res = new SampleResult();
        boolean isOK = false;
        res.setSampleLabel("MQTT Sampler");
        if (log.isDebugEnabled()) {
            log.debug(toString());
        }
        try {
            send(res);
            isOK = true;
        } catch (Exception e) {
            log.error("An error was occurred during the message sending.", e);
            res.setResponseMessage(e.toString());
        }
        res.setSamplerData(messageBody);
        res.setSuccessful(isOK);
        return res;
    }

    /**
     * Obtains a client (cached or new), publishes the message, and optionally closes it.
     * Message construction happens before the timer starts. Connect and publish time are
     * measured, and the timer is always stopped once started, even when an exception escapes.
     *
     * @param result the sample result whose timer is started and stopped here
     * @throws Exception any failure while building the payload, connecting, or publishing
     */
    private void send(SampleResult result) throws Exception {
        MqttMessage message = new MqttMessage(createMessageBody());
        message.setQos(qos);
        message.setRetained(retained);

        result.sampleStart();
        try {
            Map<String, IMqttClient> clientCache = tlClientCache.get();
            if (clientCache == null) {
                clientCache = new HashMap<String, IMqttClient>();
                tlClientCache.set(clientCache);
            }
            String clientCacheKey = createCacheKey(clientId, serverURI);
            IMqttClient client = clientCache.get(clientCacheKey);

            if (!reuseConnection || client == null || !client.isConnected()) {
                // Drop the stale/unwanted client so it is not tracked forever.
                discard(client, clientCache, clientCacheKey);
                client = createClient();
                if (log.isDebugEnabled()) {
                    log.debug("A client was created. (" + client.getClientId() + ":" + System.identityHashCode(client) + ")");
                }
                client.connect();
                if (log.isDebugEnabled()) {
                    log.debug("A client has been connected. (" + client.getClientId() + ":" + System.identityHashCode(client) + ")");
                }
                clientCache.put(clientCacheKey, client);
                allCreatedClients.add(client);
            } else if (log.isDebugEnabled()) {
                log.debug("Got a client from the cache." + System.identityHashCode(client));
            }

            client.publish(topicString, message);
            if (log.isDebugEnabled()) {
                log.debug("A message was sent.");
            }
            if (closeConnection) {
                discard(client, clientCache, clientCacheKey);
            }
        } finally {
            // Without this, a failed connect/publish leaves the result with no end time.
            result.sampleEnd();
        }
    }

    /**
     * Closes {@code client} (if connected) and removes it from both the thread's cache and
     * the global tracking list, so discarded clients do not accumulate until test end.
     *
     * @param client the client to discard; may be null
     * @param clientCache this thread's client cache
     * @param key the cache key under which the client was stored
     */
    private void discard(IMqttClient client, Map<String, IMqttClient> clientCache, String key) {
        close(client);
        clientCache.remove(key);
        if (client != null) {
            allCreatedClients.remove(client);
        }
    }

    /**
     * Builds the per-thread cache key for a client.
     *
     * @param clientId the MQTT client id
     * @param serverURL the server URI the client connects to
     * @return {@code clientId + "@" + serverURL}
     */
    protected String createCacheKey(String clientId, String serverURL) {
        // Use the argument, not the field, so callers passing a different URI get a matching key.
        return clientId + "@" + serverURL;
    }

    /**
     * Creates an unconnected client for {@link #serverURI}.
     *
     * @return a new client for a {@code tcp:} or {@code ssl:} URI
     * @throws MqttException if the URI uses any other scheme
     */
    protected IMqttClient createClient() throws MqttException {
        if (!serverURI.startsWith("tcp:") && !serverURI.startsWith("ssl:")) {
            throw new MqttException("Unknown protocol: " + serverURI);
        }
        IMqttClient client = new MqttOverTcpIpClient(serverURI, clientId);
        if (log.isDebugEnabled()) {
            log.debug("A client has been created. (" + client.getClientId() + ":" + System.identityHashCode(client) + ")");
        }
        return client;
    }

    /**
     * Closes {@code client} if it is non-null and connected. A failure to close is logged
     * as a warning and otherwise ignored (best effort).
     *
     * @param client the client to close; may be null
     */
    private void close(IMqttClient client) {
        if (client != null && client.isConnected()) {
            try {
                client.close();
                if (log.isDebugEnabled()) {
                    log.debug("A client has been closed. (" + client.getClientId() + ":" + System.identityHashCode(client) + ")");
                }
            } catch (MqttException e) {
                log.warn("Closing the connection failed.", e);
            }
        }
    }

    /**
     * Converts {@link #messageBody} into payload bytes.
     *
     * @return the Base64-decoded body when {@code base64Encoded} is set; otherwise the body
     *         encoded with {@code characterEncoding}, or the platform default charset when
     *         that is null or empty
     */
    protected byte[] createMessageBody() {
        if (base64Encoded) {
            return Base64.decodeBase64(messageBody);
        } else {
            if (characterEncoding == null || characterEncoding.length() == 0) {
                return messageBody.getBytes();
            } else {
                return messageBody.getBytes(Charset.forName(characterEncoding));
            }
        }
    }

    /** Closes every still-tracked client and forgets them all, atomically w.r.t. other threads. */
    @Override
    public void testEnded() {
        log.debug("testEnded " + System.identityHashCode(this) + ", " + Thread.currentThread().getId());
        synchronized (allCreatedClients) {
            for (IMqttClient c : allCreatedClients) {
                close(c);
            }
            // Clear under the same lock so a client added concurrently is never dropped unclosed.
            allCreatedClients.clear();
        }
    }

    @Override
    public void testEnded(String arg0) {
        log.debug("testEnded (" + arg0 + ") " + System.identityHashCode(this) + ", " + Thread.currentThread().getId());
        testEnded();
    }

    @Override
    public void testStarted() {
        log.debug("testStarted " + System.identityHashCode(this) + ", " + Thread.currentThread().getId());
    }

    @Override
    public void testStarted(String arg0) {
        log.debug("testStarted (" + arg0 + ") " + System.identityHashCode(this) + ", " + Thread.currentThread().getId());
        testStarted();
    }

    @Override
    public String toString() {
        return "MqttSampler [serverURI=" + serverURI + ", " +
                "clientId=" + clientId + ", " +
                "reuseConnection=" + reuseConnection + ", " +
                "closeConnection=" + closeConnection + ", " +
                "topicString=" + topicString + ", " +
                "qos=" + qos + ", " +
                "retained=" + retained + ", " +
                "base64Encoded=" + base64Encoded + ", " +
                "characterEncoding=" + characterEncoding + ", " +
                "messageBody=" + messageBody + "]";
    }

    public String getServerURI() {
        return serverURI;
    }

    public void setServerURI(String serverURI) {
        this.serverURI = serverURI;
    }

    public String getClientId() {
        return clientId;
    }

    public void setClientId(String clientId) {
        this.clientId = clientId;
    }

    public boolean getReuseConnection() {
        return reuseConnection;
    }

    public void setReuseConnection(boolean reuseConnection) {
        this.reuseConnection = reuseConnection;
    }

    public boolean getCloseConnection() {
        return closeConnection;
    }

    public void setCloseConnection(boolean closeConnection) {
        this.closeConnection = closeConnection;
    }

    public String getTopicString() {
        return topicString;
    }

    public void setTopicString(String topicString) {
        this.topicString = topicString;
    }

    public int getQos() {
        return qos;
    }

    public void setQos(int qos) {
        this.qos = qos;
    }

    public boolean getRetained() {
        return retained;
    }

    public void setRetained(boolean retained) {
        this.retained = retained;
    }

    public boolean isBase64Encoded() {
        return base64Encoded;
    }

    public void setBase64Encoded(boolean base64Encoded) {
        this.base64Encoded = base64Encoded;
    }

    public String getCharacterEncoding() {
        return characterEncoding;
    }

    public void setCharacterEncoding(String characterEncoding) {
        this.characterEncoding = characterEncoding;
    }

    public String getMessageBody() {
        return messageBody;
    }

    public void setMessageBody(String messageBody) {
        this.messageBody = messageBody;
    }
}