///////////////////////////////////////////////////////////////////////////////
//
// Copyright (c) 2000-2018 Ericsson Telecom AB
//
// All rights reserved. This program and the accompanying materials
// are made available under the terms of the Eclipse Public License v1.0
// which accompanies this distribution, and is available at
// http://www.eclipse.org/legal/epl-v10.html
///////////////////////////////////////////////////////////////////////////////
//
//  File:               TLS_HandlerFunction.cc
//  Description:        external functions for TLS
//  Rev:                R3A
//  Prodnr:             CNL 113 839
//
#include "TLS_Handler.hh"
#include <openssl/ssl.h>
#include <openssl/rand.h>
#include <openssl/err.h>
#include <openssl/bio.h>

// Set to 1 by the first TLS_object constructor after it has run the OpenSSL
// library setup, so that the setup is performed only once (no locking).
static int lib_initialized=0;


// State of one TLS/DTLS endpoint. OpenSSL does no socket I/O here: bytes
// received from the peer are written into in_bio, and the bytes OpenSSL
// wants to send are collected from out_bio by the caller.
class TLS_object {
public:
  SSL_CTX* ctx;  // main ssl context
  SSL* ssl;      // the SSL* which represents a "connection"
  BIO* in_bio;   // memory BIO receiving the peer's bytes
  BIO* out_bio;  // memory BIO OpenSSL writes outgoing bytes into
  CHARSTRING passwd; // password of the private key; its buffer is registered
                     // as OpenSSL callback userdata, so it must outlive ctx
  INTEGER  user_idx; // stores the user supplied idx value
  bool  handshake_started; // true once accept/connect state has been chosen
  CHARSTRING pskIdentity;  // PSK identity ("" when not configured)
  CHARSTRING pskKey;       // hex-encoded PSK ("" when not configured)
  TLS_object();
  ~TLS_object();
  
  // Builds ctx/ssl/BIOs from the descriptor; 0 on success, -1 on failure.
  int init(const TLS__Handler::TLS__descriptor&);
  
  // Reports (and drains) the OpenSSL error queue as TTCN warnings.
  void log_error();

private:
  // The destructor frees ssl and ctx, so a member-wise copy would free them
  // twice. Copying is disabled: declared but intentionally never defined.
  TLS_object(const TLS_object&);
  TLS_object& operator=(const TLS_object&);
};

// Registry of live TLS objects. The object_id handed out to the TTCN-3 side
// is an index into obj_list; a freed slot is set to NULL and reused by the
// next get_new_obj() call.
TLS_object** obj_list=NULL;
int obj_list_size=0;  // number of allocated slots in obj_list
int obj_count=0;      // number of occupied (non-NULL) slots in obj_list


// Creates a new TLS_object, stores it in the first free slot of obj_list
// and returns that slot's index (the object_id). The table is enlarged by
// exactly one slot, and only when every existing slot is occupied.
int get_new_obj(){
  if(obj_count==obj_list_size){
    const int grown_size=obj_list_size+1;
    obj_list=(TLS_object**)Realloc(obj_list,sizeof(TLS_object*)*grown_size);
    obj_list[obj_list_size]=NULL;   // the single new slot starts out free
    obj_list_size=grown_size;
  }
  int slot=0;
  while(slot<obj_list_size && obj_list[slot]!=NULL){
    ++slot;
  }
  obj_list[slot]=new TLS_object;
  ++obj_count;
  return slot;
}

void rm_obj(int id){
  delete obj_list[id];
  obj_list[id]=NULL;
  obj_count--;
  if(obj_count==0){
    Free(obj_list);
    obj_list=NULL;
    obj_list_size=0;
  }
}

// Starts with no OpenSSL objects; init() creates them. The first object
// ever constructed also performs the one-time OpenSSL library setup
// (guarded by lib_initialized, without locking).
TLS_object::TLS_object()
  : ctx(NULL), ssl(NULL), in_bio(NULL), out_bio(NULL), handshake_started(false)
{
  if(lib_initialized==0){
    SSL_library_init();          // initialize library
    SSL_load_error_strings();    // readable error messages
    lib_initialized=1;
  }
}

// Releases the SSL object first and then its context. The memory BIOs are
// not freed here: once init() has attached them with SSL_set_bio(), they
// are owned (and freed) by ssl.
TLS_object::~TLS_object(){
  if(ssl!=NULL) SSL_free(ssl);
  if(ctx!=NULL) SSL_CTX_free(ctx);
}


void TLS_object::log_error(){
  unsigned long e=ERR_get_error();
  if(!e) {
    TTCN_warning("There is no SSL error at the moment.\n");
  }
  while (e) {
    TTCN_warning("SSL error queue content:");
    TTCN_warning("  Library:  %s", ERR_lib_error_string(e));
    TTCN_warning("  Function: %s", ERR_func_error_string(e));
    TTCN_warning("  Reason:   %s", ERR_reason_error_string(e));
    e=ERR_get_error();
  }
  
}

// Callback function used by OpenSSL when a password is needed to decrypt
// the private key file.
// `userdata` is the NUL-terminated password registered as callback userdata.
// The password and its terminating NUL are copied into `buf`, whose capacity
// is `num` bytes, and the password length is returned. 0 is returned when
// no password was registered or when it does not fit into `buf`.
int ssl_password_cb(char *buf, int num, int /*rwflag*/,void *userdata) {
  if (userdata == NULL) {
    return 0;
  }
  const char* password = static_cast<const char*>(userdata);
  const size_t length = strlen(password);
  const int needed = static_cast<int>(length) + 1;   // room for the NUL too
  if (num < needed) {
    return 0;
  }
  memcpy(buf, password, length + 1);
  return static_cast<int>(length);
}

unsigned int psk_server_cb(SSL *ssl, const char *identity,
	      unsigned char *psk, unsigned int max_psk_len)
{
  int ret;

  int idx = *(int *)SSL_get_app_data(ssl);
  TLS_object* obj=obj_list[idx];
  if (!identity)return 0;
  if (strcmp(identity,obj->pskIdentity ) != 0) return 0;
  if (strlen(obj->pskKey)>=(max_psk_len*2)) return 0;

  /* convert the PSK key to binary */
  ret = strlen(obj->pskKey)/2;
  memcpy(psk,str2oct(obj->pskKey),ret);
  
  if (ret<=0) return 0;
  return ret;
}

unsigned int psk_client_cb(SSL *ssl, const char *hint, char *identity,
	unsigned int max_identity_len, unsigned char *psk, unsigned int max_psk_len)
{
  int ret;

  int idx = *(int *)SSL_get_app_data(ssl);
  TLS_object* obj=obj_list[idx];

  if (!hint){
    TTCN_warning("NULL received PSK identity hint, continuing anyway");
  }
  ret = snprintf(identity, max_identity_len, "%s", (const char *)(obj->pskIdentity));
  if (ret < 0 || (unsigned int)ret > max_identity_len) return 0;
  if (strlen(obj->pskKey)>=(max_psk_len*2)) return 0;

  ret = strlen(obj->pskKey)/2;
  /* convert the PSK key to binary */ 
  memcpy(psk,str2oct(obj->pskKey),ret);
 
  if (ret<=0) return 0;
  return ret;
}

// How should we call the generic (version-flexible) SSL/TLS method?
// It depends on the OpenSSL version. When SSLv23_method is not a macro
// (OpenSSL before 1.1.0, where it is the actual function), TLS_method is
// mapped onto it; newer versions provide TLS_method themselves.
#ifndef SSLv23_method
  #define TLS_method  SSLv23_method
#endif


// How should we call the generic DTLS method?
// It depends on the OpenSSL version. DTLS 1.2 (together with the
// SSL_OP_NO_DTLSv1_2 option) is supported from 1.0.2. Without it, only the
// DTLS 1.0 method is available.
#ifndef SSL_OP_NO_DTLSv1_2
  #define DTLS_method DTLSv1_method
#endif


// Builds the OpenSSL state described by `descr`:
//  - an SSL_CTX (DTLS or generic TLS method);
//  - the optional certificate, key, password, CA list, cipher list and PSK
//    settings, and the peer verification mode;
//  - an SSL object restricted to the requested protocol version range;
//  - the two memory BIOs that carry the raw records.
// Returns 0 on success, or -1 on failure or when ctx already exists. After
// a failure, whatever ctx/ssl was created is released by the destructor.
int TLS_object::init(const TLS__Handler::TLS__descriptor& descr){
  if(ctx!=NULL){  // already initialized
    return -1;
  }
  
// create the context
  if(descr.tls__method()==TLS__Handler::TLS__method::TLS__method__DTLS){
    ctx=SSL_CTX_new(DTLS_method());
  }
  else {
    ctx=SSL_CTX_new(TLS_method());
  }

  if(ctx==NULL){
    log_error();
    return -1;
  }
  
// load certificate file
  if(descr.ssl__certificate__file().ispresent()){
    const char* cf=(const char*)descr.ssl__certificate__file()();
    if(SSL_CTX_use_certificate_chain_file(ctx, cf)!=1)
    {
      TTCN_warning("Can't read certificate file %s", cf);
      log_error();
      return -1;
    }
  }

// set private key passwd
// OpenSSL keeps only the userdata pointer, so the password text is stored
// in the passwd member. That member is destroyed after the destructor body
// has freed ctx.
  if(descr.ssl__password().ispresent()){
    passwd=descr.ssl__password()();
    const char* cf=(const char*)passwd;
    SSL_CTX_set_default_passwd_cb(ctx, ssl_password_cb);
    SSL_CTX_set_default_passwd_cb_userdata(ctx, (void *)cf);
  }

// set private key file
  if(descr.ssl__key__file().ispresent()){
    const char* cf=(const char*)descr.ssl__key__file()();
    if(SSL_CTX_use_PrivateKey_file(ctx, cf, SSL_FILETYPE_PEM)!=1)
    {
      TTCN_warning("Can't read key file %s", cf);
      log_error();
      return -1;
    }
  }

// Load trusted CA list
  if(descr.ssl__trustedCAlist__file().ispresent()){
    const char* cf=(const char*)descr.ssl__trustedCAlist__file()();
    if (SSL_CTX_load_verify_locations(ctx, cf, NULL)!=1)
    {
      TTCN_warning("Can't read trustedCAlist file %s", cf);
      log_error();
      return -1;
    }
  }

// Set cipher list
  if(descr.ssl__cipher__list().ispresent()){
    const char* cf=(const char*)descr.ssl__cipher__list()();
    if (SSL_CTX_set_cipher_list(ctx, cf)!=1)
    {
      TTCN_warning("Cipher list restriction failed for %s", cf);
      log_error();
      return -1;
    }
  }
  
//PSK
#if OPENSSL_VERSION_NUMBER >= 0x1000000fL
  if(descr.psk__hint().ispresent()){
    const char* hint=(const char*)descr.psk__hint()();
    if(SSL_CTX_use_psk_identity_hint(ctx, hint)!=1){
      TTCN_warning("Can't set PSK identity hint %s", hint);
      log_error();
      return -1;
    }
  }
  if(descr.psk__identity().ispresent()){
    pskIdentity=descr.psk__identity()();
  }
  else pskIdentity="";
  if(descr.psk__key().ispresent()){
    pskKey=descr.psk__key()();
  }
  else pskKey="";
  if(descr.psk__for__server().ispresent() && descr.psk__for__server()()){
    SSL_CTX_set_psk_server_callback(ctx, psk_server_cb);
  }
  else{
    SSL_CTX_set_psk_client_callback(ctx, psk_client_cb);
  }
#else
    TTCN_warning("The used OpenSSL doesn't support the PSK");
#endif

// set other side verification mode
// By default the verification is enabled; it is switched off on request
// or when a PSK identity and key are configured.
  if(descr.ssl__verify__certificate().ispresent() && !descr.ssl__verify__certificate()()){
    SSL_CTX_set_verify(ctx, SSL_VERIFY_NONE, NULL);
#if OPENSSL_VERSION_NUMBER >= 0x1000000fL
  } else if(pskKey!="" && pskIdentity!="") {
    SSL_CTX_set_verify(ctx, SSL_VERIFY_NONE, NULL);
#endif
  } else {
    SSL_CTX_set_verify(ctx, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, NULL);
  }
  SSL_CTX_set_read_ahead(ctx, 1);


// create the ssl object
  ssl=SSL_new(ctx);
  if(ssl==NULL){
      log_error();
      return -1;
  }

// set the minimum supported SSL/TLS/DTLS version
// (OpenSSL 1.1.0 has SSL_set_min_proto_version, but earlier versions must
// be supported too, so the individual protocols are switched off instead)
  if(descr.min__supported__version().ispresent()){
    switch(descr.min__supported__version()()){
// DTLS part
      case TLS__Handler::TLS__Supported__proto__versions::TLS__DTLS1__2__VERSION:
#ifdef SSL_OP_NO_DTLSv1
        SSL_set_options(ssl,SSL_OP_NO_DTLSv1);
#endif
        break;

// SSL/TLS part: each case also disables every older protocol
      case TLS__Handler::TLS__Supported__proto__versions::TLS__TLS1__2__VERSION:
#ifdef  SSL_OP_NO_TLSv1_1
        SSL_set_options(ssl,SSL_OP_NO_TLSv1_1);
#endif
        // no break
      case TLS__Handler::TLS__Supported__proto__versions::TLS__TLS1__1__VERSION:
#ifdef SSL_OP_NO_TLSv1
        SSL_set_options(ssl,SSL_OP_NO_TLSv1);
#endif
        // no break
      case TLS__Handler::TLS__Supported__proto__versions::TLS__TLS1__VERSION:
#ifdef SSL_OP_NO_SSLv3
        SSL_set_options(ssl,SSL_OP_NO_SSLv3);
#endif
        // no break
      case TLS__Handler::TLS__Supported__proto__versions::TLS__SSL3__VERSION:
#ifdef SSL_OP_NO_SSLv2
        SSL_set_options(ssl,SSL_OP_NO_SSLv2);
#endif
        // no break
      default:
      ;
    }
  }

// set the maximum supported SSL/TLS/DTLS version
// (guarded by the presence of the max field that the switch reads)
  if(descr.max__supported__version().ispresent()){
    switch(descr.max__supported__version()()){
// DTLS part
      case TLS__Handler::TLS__Supported__proto__versions::TLS__DTLS1__VERSION:
#ifdef SSL_OP_NO_DTLSv1_2
        SSL_set_options(ssl,SSL_OP_NO_DTLSv1_2);
#endif
        break;

// SSL/TLS part: each case also disables every newer protocol
      case TLS__Handler::TLS__Supported__proto__versions::TLS__SSL3__VERSION:
#ifdef SSL_OP_NO_TLSv1
        SSL_set_options(ssl,SSL_OP_NO_TLSv1);
#endif
        // no break
      case TLS__Handler::TLS__Supported__proto__versions::TLS__TLS1__VERSION:
#ifdef  SSL_OP_NO_TLSv1_1
        SSL_set_options(ssl,SSL_OP_NO_TLSv1_1);
#endif
        // no break
      case TLS__Handler::TLS__Supported__proto__versions::TLS__TLS1__1__VERSION:
#ifdef  SSL_OP_NO_TLSv1_2
        SSL_set_options(ssl,SSL_OP_NO_TLSv1_2);
#endif
        // no break
      case TLS__Handler::TLS__Supported__proto__versions::TLS__TLS1__2__VERSION:
        // OpenSSL versions that know TLS 1.3 would otherwise still offer it
#ifdef  SSL_OP_NO_TLSv1_3
        SSL_set_options(ssl,SSL_OP_NO_TLSv1_3);
#endif
        // no break
      default:
      ;
    }
  }

// create the memory bios
  in_bio = BIO_new(BIO_s_mem());
  if(in_bio == NULL) {
    TTCN_warning("Error: cannot allocate read bio.\n");
    log_error();
    return -1;
  }
 
  BIO_set_mem_eof_return(in_bio, -1); /* see: https://www.openssl.org/docs/crypto/BIO_s_mem.html */
 
  out_bio = BIO_new(BIO_s_mem());
  if(out_bio == NULL) {
    TTCN_warning("Error: cannot allocate write bio.\n");
    log_error();
    // in_bio is not attached to ssl yet, so nothing else would free it
    BIO_free(in_bio);
    in_bio=NULL;
    return -1;
  }
 
  BIO_set_mem_eof_return(out_bio, -1); /* see: https://www.openssl.org/docs/crypto/BIO_s_mem.html */
 
  // from here on ssl owns both BIOs and frees them in SSL_free()
  SSL_set_bio(ssl, in_bio,out_bio);

  return 0;
}


namespace TLS__Handler{
bool is_valid_objid(int idx){
  return (idx>=0) && (idx<obj_list_size) && (obj_list[idx]);
}

TLS__op__result TLS__New__object(TLS__Handler::TLS__descriptor const& descr, INTEGER& object__id, INTEGER const& user__idx){
  int idx=get_new_obj();
  if(obj_list[idx]->init(descr)<0){
    // something went wrong
    // drop the object
    rm_obj(idx);
    return TLS__op__result::TLS__ERROR;
  }
  obj_list[idx]->user_idx=user__idx;
  object__id=idx;
  
  return TLS__op__result::TLS__OK;

}

TLS__op__result TLS__get__user__idx(INTEGER const& object__id, INTEGER& user__idx){
  if(is_valid_objid(object__id)){
    user__idx=obj_list[object__id]->user_idx;
    return TLS__op__result::TLS__OK;
  }
  TTCN_warning("Invalid object_id");
  return TLS__op__result::TLS__ERROR;
}

TLS__op__result TLS__set__user__idx(INTEGER const& object__id, INTEGER const& user__idx){
  if(is_valid_objid(object__id)){
    obj_list[object__id]->user_idx=user__idx;
    return TLS__op__result::TLS__OK;
  }
  TTCN_warning("Invalid object_id");
  return TLS__op__result::TLS__ERROR;
}

TLS__op__result TLS__Delete__object(INTEGER const& object__id){
  if(is_valid_objid(object__id)){
    rm_obj(object__id);
    return TLS__op__result::TLS__OK;
  }
  TTCN_warning("Invalid object_id");
  return TLS__op__result::TLS__ERROR;
}

TLS__op__result TLS__Handshake(INTEGER const& object__id, 
                               BOOLEAN const& is__server, 
                               OCTETSTRING const& input__stream, 
                               OCTETSTRING& output__stream){
  if(is_valid_objid(object__id)){
    int idx=object__id;
//printf("idx %d\r\n",idx);    
    TLS__op__result ret_code=TLS__op__result::TLS__ERROR;
    TLS_object* obj=obj_list[idx];
    
    //used in psk cb functions
    SSL_set_app_data(obj->ssl,(const char*)&idx);

    if(!(obj->handshake_started)){
      if((bool)is__server) {
        SSL_set_accept_state(obj->ssl);
      } 
      else {
        SSL_set_connect_state(obj->ssl);
      }
      
      obj->handshake_started = true;    
    }
//printf("obj->handshake_started %d\r\n",obj->handshake_started);    
    if(input__stream.lengthof()>0){
//printf("BIO_write\r\n");    
      BIO_write(obj->in_bio, (const unsigned char*)input__stream, input__stream.lengthof());
    }
    int res=SSL_do_handshake(obj->ssl);
//printf("res %d\r\n",res);    
    if(res==1){
      // Handshake completed
      ret_code=TLS__op__result::TLS__OK;
    } else {
      // Check the error code
      switch(SSL_get_error(obj->ssl,res)){
        case SSL_ERROR_WANT_READ:
//printf("SSL_ERROR_WANT_READ %d\r\n",res);    
          ret_code=TLS__op__result::TLS__NEED__MORE__DATA;
          break;
        case SSL_ERROR_WANT_WRITE:
//printf("SSL_ERROR_WANT_WRITE %d\r\n",res);    
          ret_code=TLS__op__result::TLS__DATA__TO__SEND;
          break;
        default:
//printf("other %d\r\n",res);    
          ret_code=TLS__op__result::TLS__ERROR;
          TTCN_warning("TLS_error");
          obj->log_error();
          break;
      }
    }
    if((ret_code!=TLS__op__result::TLS__ERROR)){
      int pending = BIO_ctrl_pending(obj->out_bio);
      if(pending>0){
        TTCN_Buffer buff;
        size_t s=pending;
        unsigned char* ptr;
        buff.get_end(ptr,s);
        int rl=BIO_read(obj->out_bio, ptr, s);
        buff.increase_length(rl);
        buff.get_string(output__stream);
        ret_code=TLS__op__result::TLS__DATA__TO__SEND;
      } else {
        output__stream=OCTETSTRING(0,NULL);
      }
    }
    return ret_code;
  }
  TTCN_warning("Invalid object_id");
  return TLS__op__result::TLS__ERROR;
  
}
TLS__op__result TLS__Write(INTEGER const& object__id, OCTETSTRING const& user__data, OCTETSTRING const& input__stream, OCTETSTRING& output__stream){
  if(is_valid_objid(object__id)){
    int idx=object__id;
    TLS_object* obj=obj_list[idx];
    TLS__op__result ret_code=TLS__op__result::TLS__ERROR;
    if(input__stream.lengthof()>0){
      BIO_write(obj->in_bio, (const unsigned char*)input__stream, input__stream.lengthof());
    }
    
    int res=SSL_write(obj->ssl,(const unsigned char*)user__data, user__data.lengthof());
     if(res>0){
      ret_code=TLS__op__result::TLS__OK;
    } else {
      // Check the error code
      switch(SSL_get_error(obj->ssl,res)){
        case SSL_ERROR_WANT_READ:
          ret_code=TLS__op__result::TLS__NEED__MORE__DATA;
          break;
        case SSL_ERROR_WANT_WRITE:
          ret_code=TLS__op__result::TLS__DATA__TO__SEND;
          break;
        default:
          ret_code=TLS__op__result::TLS__ERROR;
          TTCN_warning("TLS_error");
          obj->log_error();
          break;
      }
    }
    if((ret_code!=TLS__op__result::TLS__ERROR)){
      int pending = BIO_ctrl_pending(obj->out_bio);
      if(pending>0){
        TTCN_Buffer buff;
        size_t s=pending;
        unsigned char* ptr;
        buff.get_end(ptr,s);
        int rl=BIO_read(obj->out_bio, ptr, s);
        buff.increase_length(rl);
        buff.get_string(output__stream);
//        ret_code=TLS__op__result::TLS__DATA__TO__SEND;
      } else {
        output__stream=OCTETSTRING(0,NULL);
      }
    }
    return ret_code;
   
  }
  TTCN_warning("Invalid object_id");
  return TLS__op__result::TLS__ERROR;
}
TLS__op__result TLS__Read(INTEGER const& object__id, OCTETSTRING& user__data, OCTETSTRING const& input__stream, OCTETSTRING& output__stream){
  if(is_valid_objid(object__id)){
    int idx=object__id;
    TLS_object* obj=obj_list[idx];
    TLS__op__result ret_code=TLS__op__result::TLS__ERROR;
    if(input__stream.lengthof()>0){
      BIO_write(obj->in_bio, (const unsigned char*)input__stream, input__stream.lengthof());
    }
    unsigned char buff[8096];
    int res=SSL_read(obj->ssl,buff, sizeof(buff));
    if(res>0){
      user__data=OCTETSTRING(res,buff);
      ret_code=TLS__op__result::TLS__OK;
    } else {
      user__data=OCTETSTRING(0,NULL);
      // Check the error code
      switch(SSL_get_error(obj->ssl,res)){
        case SSL_ERROR_WANT_READ:
          ret_code=TLS__op__result::TLS__NEED__MORE__DATA;
          break;
        case SSL_ERROR_WANT_WRITE:
          ret_code=TLS__op__result::TLS__DATA__TO__SEND;
          break;
        default:
          ret_code=TLS__op__result::TLS__ERROR;
          TTCN_warning("TLS_error");
          obj->log_error();
          break;
      }
    }
    if((ret_code!=TLS__op__result::TLS__ERROR)){
      int pending = BIO_ctrl_pending(obj->out_bio);
      if(pending>0){
        TTCN_Buffer buff;
        size_t s=pending;
        unsigned char* ptr;
        buff.get_end(ptr,s);
        int rl=BIO_read(obj->out_bio, ptr, s);
        buff.increase_length(rl);
        buff.get_string(output__stream);
//        ret_code=TLS__op__result::TLS__DATA__TO__SEND;
      } else {
        output__stream=OCTETSTRING(0,NULL);
      }
    }
    return ret_code;
   
  }
  TTCN_warning("Invalid object_id");
  return TLS__op__result::TLS__ERROR;
}

}

