/*******************************************************************************
 * Copyright (c) 2009, 2011 Nokia and others.
 * All rights reserved. This program and the accompanying materials
 * are made available under the terms of the Eclipse Public License v1.0
 * which accompanies this distribution, and is available at
 * http://www.eclipse.org/legal/epl-v10.html
 *
 * Contributors:
 * Nokia - Initial API and implementation
 *******************************************************************************/
#include <sstream>
#include <stdio.h>
#include <assert.h>

#include "stdafx.h"
#include "WinThread.h"
#include "WinProcess.h"
#include "AgentUtils.h"
#include "EventClientNotifier.h"
#include "Logger.h"
#include "WinDebugMonitor.h"
#include "ResumeContextAction.h"
#include "ProtocolConstants.h"
#include "RunControlService.h"
#include "BreakpointsService.h"

// Registry of all live WinThread objects, keyed by (process ID, thread ID) as
// reported in the DEBUG_EVENT that created them. Entries are added by the
// constructor, removed by the destructor, and queried by GetThreadByID().
std::map<std::pair<int, int>, WinThread*> WinThread::threadIDMap_;

/*
 * Create the context object for a thread reported by a debug event.
 *
 * debugEvent is expected to be a CREATE_PROCESS_DEBUG_EVENT (the process's
 * initial thread) or a CREATE_THREAD_DEBUG_EVENT; the thread handle, start
 * address and thread-local base are taken from the matching member of
 * debugEvent.u. The new object registers itself as a child of its process and
 * in the static (process ID, thread ID) lookup map, then sets its properties.
 */
WinThread::WinThread(WinProcess& process, DEBUG_EVENT& debugEvent)
  : ThreadContext(debugEvent.dwThreadId, process.GetID(), CreateInternalID(debugEvent.dwThreadId, process.GetID())),
	threadLookupPair_(debugEvent.dwProcessId, debugEvent.dwThreadId),
	parentProcess_(process)
{
	process.AddChild(this);

	threadIDMap_[threadLookupPair_] = this;

	threadContextValid_ = false;

	// Give the event-derived fields defined values first: for any other event
	// code they were previously left uninitialized, and Initialize() below
	// would have formatted a garbage start address.
	handle_ = NULL;
	startAddress_ = 0;
	localBase_ = NULL;

	if (debugEvent.dwDebugEventCode == CREATE_PROCESS_DEBUG_EVENT) {
		handle_ = debugEvent.u.CreateProcessInfo.hThread;
		startAddress_
				= (unsigned long) debugEvent.u.CreateProcessInfo.lpStartAddress;
		localBase_ = debugEvent.u.CreateProcessInfo.lpThreadLocalBase;
	} else if (debugEvent.dwDebugEventCode == CREATE_THREAD_DEBUG_EVENT) {
		handle_ = debugEvent.u.CreateThread.hThread;
		startAddress_
				= (unsigned long) debugEvent.u.CreateThread.lpStartAddress;
		localBase_ = debugEvent.u.CreateThread.lpThreadLocalBase;
	}
	isSuspended_ = false;
	isTerminating_ = false;
	isUserSuspended_ = false;

	// Zero the whole exception record so no field is ever read uninitialized,
	// then mark it as a user suspend just to ensure that new threads are
	// resumed with DBG_CONTINUE. This is normally set/changed in HandleException.
	::memset(&exceptionInfo_, 0, sizeof(exceptionInfo_));
	exceptionInfo_.ExceptionRecord.ExceptionCode = USER_SUSPEND_THREAD;

	Initialize();
}

// Initialize thread specific properties: the thread's name (its start address
// in hex, e.g. "0x00401000") and the run-control capabilities it advertises.
void WinThread::Initialize() {
	char buf[32];
	// Pass an explicit unsigned long so the argument always matches the %lx
	// conversion whatever integer type startAddress_ is declared with; the
	// previous "%08x" paired an int conversion with an unsigned long value.
	::_snprintf(buf, sizeof(buf), "0x%08lx", static_cast<unsigned long>(startAddress_));
	// _snprintf() does not NUL-terminate the buffer when it truncates.
	buf[sizeof(buf) - 1] = '\0';
	SetProperty(PROP_NAME, PropertyValue(buf));

	// Only plain resume and step-into are offered (see Resume() / SingleStep()).
	int supportedResumeModes = (1 << RM_RESUME) | (1 << RM_STEP_INTO);
	SetProperty(PROP_CAN_RESUME, PropertyValue(supportedResumeModes));

	SetProperty(PROP_CAN_TERMINATE, PropertyValue(true));
	SetProperty(PROP_CAN_SUSPEND, PropertyValue(true));
}

// The thread ID exposed by this class is simply the OS thread ID.
int WinThread::GetThreadID() {
	const int osThreadID = GetOSID();
	return osThreadID;
}

/*
 * Unregister this thread from its parent process and from the static lookup map.
 *
 * The map entry is removed only if it still points at this object. Windows may
 * hand the same thread ID to a new thread once the old one has exited, so a
 * newer WinThread can already have replaced the entry under the same
 * (process ID, thread ID) key; erasing by key alone would drop that newer entry.
 */
WinThread::~WinThread(void) {
	parentProcess_.RemoveChild(this);

	std::map<std::pair<int, int>, WinThread*>::iterator entry = threadIDMap_.find(threadLookupPair_);
	if (entry != threadIDMap_.end() && entry->second == this)
		threadIDMap_.erase(entry);

	// Destructor of parent classes will be called which will
	// delete all children contexts (registers, etc).
}

/*
 * Return the current program counter (EIP) from the cached thread context,
 * which must already have been fetched (threadContextValid_).
 *
 * exceptionInfo_.ExceptionRecord.ExceptionAddress is intentionally not used:
 * it is the address of the instruction that raised the exception, whereas the
 * PC register usually points past it, and the PC value is what is needed here.
 */
ContextAddress WinThread::GetPCAddress() const {
	assert(threadContextValid_);
	const ContextAddress pc = threadContextInfo_.Eip;
	return pc;
}

const char* WinThread::GetSuspendReason() const {
	const char* reason = REASON_EXCEPTION;

	switch (exceptionInfo_.ExceptionRecord.ExceptionCode) {
	case USER_SUSPEND_THREAD:
		return REASON_USER_REQUEST;
	case EXCEPTION_SINGLE_STEP:
		return REASON_STEP;
	case EXCEPTION_BREAKPOINT:
		return REASON_BREAKPOINT;
	}

	return reason;
}

/*
 * Status to pass to ContinueDebugEvent() for this thread's current stop.
 * Stops the debugger caused itself (user suspend, single step, breakpoint) are
 * continued with DBG_CONTINUE; any other exception is passed on with
 * DBG_EXCEPTION_NOT_HANDLED so the debuggee gets the chance to handle it.
 */
DWORD WinThread::GetContinueStatus() const {
	const DWORD code = exceptionInfo_.ExceptionRecord.ExceptionCode;
	const bool causedByDebugger = (code == USER_SUSPEND_THREAD)
			|| (code == EXCEPTION_SINGLE_STEP)
			|| (code == EXCEPTION_BREAKPOINT);

	return causedByDebugger ? DBG_CONTINUE : DBG_EXCEPTION_NOT_HANDLED;
}

/*
 * Human-readable description of the last exception, or an empty string for
 * single-step, breakpoint and user-suspend stops.
 */
std::string WinThread::GetExceptionMessage() const {
	switch (exceptionInfo_.ExceptionRecord.ExceptionCode) {
	case EXCEPTION_SINGLE_STEP:
	case EXCEPTION_BREAKPOINT:
	case USER_SUSPEND_THREAD:
		return "";
	default:
		return WinDebugMonitor::GetDebugExceptionDescription(exceptionInfo_);
	}
}

// Flag the thread as stopped and drop the cached CONTEXT, so the next
// EnsureValidContextInfo() call re-reads it from the OS.
void WinThread::MarkSuspended() {
	threadContextValid_ = false;
	isSuspended_ = true;
}

/*
 * Handle an EXCEPTION_DEBUG_EVENT for this thread: record the exception, mark
 * the thread stopped and refresh its context, then report the stop.
 *
 * If any of Dr6's B0-B3 flags is set, a hardware debug register fired and
 * HandleHardwareBreak() reports it; Dr6 is then reset to 0xFFFF0FF0 (B0-B3
 * clear) and written back so the flags are not seen again next time. Otherwise
 * the PC is fixed up for agent-planted software breakpoints and a regular
 * suspended event is sent.
 */
void WinThread::HandleException(DEBUG_EVENT& debugEvent) {
	exceptionInfo_ = debugEvent.u.Exception;
	MarkSuspended();
	EnsureValidContextInfo();

	if ((threadContextInfo_.Dr6 & 0xF) == 0) {
		AdjustPC();
		EventClientNotifier::SendContextSuspended(this,
				GetPCAddress(), GetSuspendReason(), GetExceptionMessage());
		return;
	}

	HandleHardwareBreak();

	threadContextInfo_.Dr6 = 0xFFFF0FF0;
	::SetThreadContext(handle_, &threadContextInfo_);
}

/*
 * Check if the program is stopped due to a software breakpoint
 * installed by the agent, if yes, move PC back by one byte.
 */
void WinThread::AdjustPC() {
	// Bail out if the agent does not install & manage
	// breakpoints (namely the EDC host uses generic
	// software breakpoint mechanism).
	if (! BreakpointsService::ServiceInstalled())
		return;

	/*
	 * Check
	 * 1. Did we stop due to a breakpoint exception ?
	 *   -- This is to prevent adjusting PC for other exceptions such as
	 *      divide-by-zero & invalid code.
	 * 2. is there a software breakpoint at the byte right before the PC?
	 *   -- this is to exclude the case of user-inserted "int 3" instruction.
	 */
	if (exceptionInfo_.ExceptionRecord.ExceptionCode != EXCEPTION_BREAKPOINT)
		return;

	ContextAddress pc = GetPCAddress();
	pc--;
	if (NULL != BreakpointsService::FindBreakpointByAddress(parentProcess_.GetProcessHandle(), pc)) {
		SetRegisterValue("EIP", 4, (char*)&pc);
	}
}

/*
 * Report that an executable module was loaded (isLoaded) or unloaded at
 * baseAddress, after marking the thread stopped and refreshing its context.
 *
 * On load, the module's properties are built and remembered in the parent
 * process's executables-by-address table. On unload, the remembered
 * properties are reported with PROP_MODULE_LOADED=false and the entry is
 * dropped from that table.
 */
void WinThread::HandleExecutableEvent(bool isLoaded, const std::string& exePath,
		unsigned long baseAddress, unsigned long codeSize) {
	MarkSuspended();
	EnsureValidContextInfo();

	Properties props;
	if (!isLoaded) {
		props = parentProcess_.GetExecutablesByAddress()[baseAddress];
		assert(!props.empty());
		props[PROP_MODULE_LOADED] = PropertyValue(false);

		// Forget the module now: if another executable is later loaded at the
		// same address and then unloaded, a stale entry here would make us send
		// a bogus unloaded event.
		parentProcess_.GetExecutablesByAddress().erase(baseAddress);
	} else {
		const int base = (int) baseAddress;
		props[PROP_ID] = PropertyValue(base);
		props[PROP_FILE] = PropertyValue(exePath);
		props[PROP_NAME] = PropertyValue(AgentUtils::GetFileNameFromPath(exePath));
		props[PROP_MODULE_LOADED] = PropertyValue(isLoaded);
		props[PROP_IMAGE_BASE_ADDRESS] = PropertyValue(base);
		props[PROP_CODE_SIZE] = PropertyValue((int) codeSize);
		parentProcess_.GetExecutablesByAddress()[baseAddress] = props;
	}

	EventClientNotifier::SendExecutableEvent(this, threadContextInfo_.Eip, props);
}

// Reports the isSuspended_ flag (set by MarkSuspended()).
bool WinThread::IsSuspended() const
{
	const bool stopped = isSuspended_;
	return stopped;
}

// Fallback for headers that do not define CONTEXT_ALL: request every CONTEXT
// section (control, integer, segment, floating point, debug and extended
// registers) when reading/writing the thread context below.
#ifndef CONTEXT_ALL
#define CONTEXT_ALL             (CONTEXT_CONTROL | CONTEXT_INTEGER | CONTEXT_SEGMENTS | \
	CONTEXT_FLOATING_POINT | CONTEXT_DEBUG_REGISTERS | \
	CONTEXT_EXTENDED_REGISTERS)
#endif

void WinThread::EnsureValidContextInfo() {
	if (!threadContextValid_ && IsSuspended()) {
		threadContextInfo_.ContextFlags = CONTEXT_ALL;
		if (::GetThreadContext(handle_, &threadContextInfo_) != 0) {
			registerValueCache_.clear();
			// Cache general registers
			registerValueCache_["EAX"]
			  = AgentUtils::IntToHexString(threadContextInfo_.Eax);
			registerValueCache_["ECX"]
			  = AgentUtils::IntToHexString(threadContextInfo_.Ecx);
			registerValueCache_["EDX"]
			  = AgentUtils::IntToHexString(threadContextInfo_.Edx);
			registerValueCache_["EBX"]
			  = AgentUtils::IntToHexString(threadContextInfo_.Ebx);
			registerValueCache_["ESP"]
			  = AgentUtils::IntToHexString(threadContextInfo_.Esp);
			registerValueCache_["EBP"]
			  = AgentUtils::IntToHexString(threadContextInfo_.Ebp);
			registerValueCache_["ESI"]
			  = AgentUtils::IntToHexString(threadContextInfo_.Esi);
			registerValueCache_["EDI"]
			  = AgentUtils::IntToHexString(threadContextInfo_.Edi);
			registerValueCache_["EIP"]
			  = AgentUtils::IntToHexString(threadContextInfo_.Eip);
			registerValueCache_["GS"]
			  = AgentUtils::IntToHexString(threadContextInfo_.SegGs);
			registerValueCache_["FS"]
			  = AgentUtils::IntToHexString(threadContextInfo_.SegFs);
			registerValueCache_["ES"]
			  = AgentUtils::IntToHexString(threadContextInfo_.SegEs);
			registerValueCache_["DS"]
			  = AgentUtils::IntToHexString(threadContextInfo_.SegDs);
			registerValueCache_["CS"]
			  = AgentUtils::IntToHexString(threadContextInfo_.SegCs);
			registerValueCache_["EFL"]
			  = AgentUtils::IntToHexString(threadContextInfo_.EFlags);
			registerValueCache_["SS"]
			  = AgentUtils::IntToHexString(threadContextInfo_.SegSs);

			threadContextValid_ = true;
		}
	}
}

void WinThread::SetContextInfo() {
	if (IsSuspended()) {
		threadContextInfo_.ContextFlags = CONTEXT_ALL;
		// Set general registers values
		threadContextInfo_.Eax
		  = AgentUtils::HexStringToInt(registerValueCache_["EAX"]);
		threadContextInfo_.Ecx
		  = AgentUtils::HexStringToInt(registerValueCache_["ECX"]);
		threadContextInfo_.Edx
		  = AgentUtils::HexStringToInt(registerValueCache_["EDX"]);
		threadContextInfo_.Ebx
		  = AgentUtils::HexStringToInt(registerValueCache_["EBX"]);
		threadContextInfo_.Esp
		  = AgentUtils::HexStringToInt(registerValueCache_["ESP"]);
		threadContextInfo_.Ebp
		  = AgentUtils::HexStringToInt(registerValueCache_["EBP"]);
		threadContextInfo_.Esi
		  = AgentUtils::HexStringToInt(registerValueCache_["ESI"]);
		threadContextInfo_.Edi
		  = AgentUtils::HexStringToInt(registerValueCache_["EDI"]);
		threadContextInfo_.Eip
		  = AgentUtils::HexStringToInt(registerValueCache_["EIP"]);
		threadContextInfo_.SegGs
		  = AgentUtils::HexStringToInt(registerValueCache_["GS"]);
		threadContextInfo_.SegFs
		  = AgentUtils::HexStringToInt(registerValueCache_["FS"]);
		threadContextInfo_.SegEs
		  = AgentUtils::HexStringToInt(registerValueCache_["ES"]);
		threadContextInfo_.SegDs
		  = AgentUtils::HexStringToInt(registerValueCache_["DS"]);
		threadContextInfo_.SegCs
		  = AgentUtils::HexStringToInt(registerValueCache_["CS"]);
		threadContextInfo_.EFlags
		  = AgentUtils::HexStringToInt(registerValueCache_["EFL"]);
		threadContextInfo_.SegSs
		  = AgentUtils::HexStringToInt(registerValueCache_["SS"]);
		::SetThreadContext(handle_, &threadContextInfo_);
	}
}

/*
 * Find the WinThread registered for the given process and thread IDs.
 * Returns NULL when no such thread is known.
 */
WinThread* WinThread::GetThreadByID(int processID, int threadID) {
	std::map<std::pair<int, int>, WinThread*>::iterator match =
			threadIDMap_.find(std::make_pair(processID, threadID));
	if (match != threadIDMap_.end())
		return match->second;
	return NULL;
}

std::vector<std::string> WinThread::GetRegisterValues(
		const std::vector<std::string>& registerIDs) {
	std::vector<std::string> registerValues;

	if (IsSuspended()) {
		EnsureValidContextInfo();

		std::vector<std::string>::const_iterator itVectorData;
		for (itVectorData = registerIDs.begin(); itVectorData
				!= registerIDs.end(); itVectorData++) {
			std::string registerID = *itVectorData;
			std::string registerValue = registerValueCache_[registerID];
			registerValues.push_back(registerValue);
		}
	}

	return registerValues;
}

/*
 * Get pointer to register value cache for a given register.
 * Return NULL if the register is not found.
 */
void* WinThread::GetRegisterValueBuffer(const std::string& regName) const {
	void* v = NULL;

	if (regName == "EAX")
		v = (void*)&threadContextInfo_.Eax;
	else if (regName == "EBX")
		v = (void*)&threadContextInfo_.Ebx;
	else if (regName == "ECX")
		v = (void*)&threadContextInfo_.Ecx;
	else if (regName == "EDX")
		v = (void*)&threadContextInfo_.Edx;
	else if (regName == "ESP")
		v = (void*)&threadContextInfo_.Esp;
	else if (regName == "EBP")
		v = (void*)&threadContextInfo_.Ebp;
	else if (regName == "ESI")
		v = (void*)&threadContextInfo_.Esi;
	else if (regName == "EDI")
		v = (void*)&threadContextInfo_.Edi;
	else if (regName == "EIP")
		v = (void*)&threadContextInfo_.Eip;
	else if (regName == "EFL")
		v = (void*)&threadContextInfo_.EFlags;
	else if (regName == "GS")
		v = (void*)&threadContextInfo_.SegGs;
	else if (regName == "FS")
		v = (void*)&threadContextInfo_.SegFs;
	else if (regName == "ES")
		v = (void*)&threadContextInfo_.SegEs;
	else if (regName == "DS")
		v = (void*)&threadContextInfo_.SegDs;
	else if (regName == "CS")
		v = (void*)&threadContextInfo_.SegCs;
	else if (regName == "SS")
		v = (void*)&threadContextInfo_.SegSs;
	else {
		assert(false);
	}

	return v;
}

/*
 * Read one register.
 * Return binary data buffer, which caller should free by calling delete[].
 */
char* WinThread::GetRegisterValue(const std::string& regName, int regSize) {

	char* ret = NULL;

	if (IsSuspended()) {
		EnsureValidContextInfo();

		ret = new char[regSize];

		void* v = GetRegisterValueBuffer(regName);
		assert(v != NULL);

		::memcpy((void*)ret, v, regSize);
	}

	return ret;
}

bool WinThread::SetRegisterValue(const std::string& regName, int regSize, char* val) {

	if (! IsSuspended())
		return false;

	void* v = GetRegisterValueBuffer(regName);
	assert(v != NULL);

	::memcpy(v, (void*)val, regSize);
	return ::SetThreadContext(handle_, &threadContextInfo_);
}

void WinThread::SetRegisterValues(const std::vector<std::string>& registerIDs,
		const std::vector<std::string>& registerValues) {
	if (IsSuspended()) {
		std::vector<std::string>::const_reverse_iterator itVectorData;
		int idx = registerValues.size();
		for (itVectorData = registerIDs.rbegin(); itVectorData
				!= registerIDs.rend(); itVectorData++) {
			std::string registerID = *itVectorData;
			registerValueCache_[registerID] = registerValues[--idx];
		}

		SetContextInfo();
	}
}

int WinThread::ReadMemory(const ReadWriteMemoryParams& params) throw (AgentException) {
	return parentProcess_.ReadMemory(params);
}

int WinThread::WriteMemory(const ReadWriteMemoryParams& params) throw (AgentException) {
	return parentProcess_.WriteMemory(params);
}

void WinThread::Terminate(const AgentActionParams& params) throw (AgentException) {
	parentProcess_.Terminate(params);
}

/*
 * Suspend the OS thread, refresh its context and mark it as user-suspended
 * with a USER_SUSPEND_THREAD exception code.
 * Returns the value from SuspendThread() (the previous suspend count, or
 * (DWORD)-1 on failure; failure is not checked here).
 *
 * NOTE(review): the exception code is overwritten even if the thread was
 * already stopped at an exception, which changes GetContinueStatus() for that
 * stop; confirm callers only do this when intended (e.g. termination).
 */
DWORD WinThread::Suspend() {
	const DWORD previousCount = ::SuspendThread(handle_);

	MarkSuspended();
	EnsureValidContextInfo();

	isUserSuspended_ = true;
	exceptionInfo_.ExceptionRecord.ExceptionCode = USER_SUSPEND_THREAD; // "Suspended"

	return previousCount;
}

void WinThread::Suspend(const AgentActionParams& params) throw (AgentException)
{
	DWORD suspendCount = Suspend();
	if (! isTerminating_)	// don't send Suspend event if we are terminating.
		EventClientNotifier::SendContextSuspended(this,
				GetPCAddress(), GetSuspendReason(), GetExceptionMessage());
	Logger::getLogger().Log(Logger::LOG_NORMAL, "WinThread::Suspend",
			"suspendCount: %d", suspendCount);

	params.reportSuccessForAction();
}

/*
 * Undo a suspension applied by Suspend(): resume the OS thread and clear the
 * user-suspended flag. Does nothing if the thread is not suspended or was
 * stopped by something other than Suspend().
 *
 * NOTE(review): isSuspended_ is not cleared here; confirm another path resets
 * it once the thread is actually running again.
 */
void WinThread::Resume() {
	if (IsSuspended() && isUserSuspended_) {
		::ResumeThread(handle_);
		isUserSuspended_ = false;
	}
}

/*
 * Client "resume" request.
 * - Not suspended: nothing to do, report success.
 * - Suspended via Suspend(): resume the OS thread directly, report success.
 * - Otherwise (stopped at a debug event): post a ResumeContextAction to the
 *   debug monitor, which is then responsible for reporting the result.
 */
void WinThread::Resume(const AgentActionParams& params) throw (AgentException) {
	if (IsSuspended() && !isUserSuspended_) {
		parentProcess_.GetMonitor()->PostAction(new ResumeContextAction(
			params, parentProcess_, *this, RM_RESUME));
		return;
	}

	if (IsSuspended())
		Resume();
	params.reportSuccessForAction();
}

/*
 * Enable single instruction step by setting Trap Flag (TF) bit.
 */
void WinThread::EnableSingleStep() {
#define FLAG_TRACE_BIT 0x100
	// The bit will be auto-cleared after next resume.
	threadContextInfo_.EFlags |= FLAG_TRACE_BIT;
	::SetThreadContext(handle_, &threadContextInfo_);
}

void WinThread::SingleStep(const AgentActionParams& params) throw (AgentException) {
	parentProcess_.GetMonitor()->PostAction(new ResumeContextAction(
			params, parentProcess_, *this, RM_STEP_INTO));
}

/*
 * Mark the thread as terminating. If it is currently stopped, suspend it via
 * Suspend(params), which reports success but sends no Suspended event because
 * isTerminating_ is already set, and continue its pending debug event with
 * DBG_CONTINUE.
 */
void WinThread::PrepareForTermination(const AgentActionParams& params) throw (AgentException) {
	// Must be set before Suspend(params) so that no Suspended event is sent.
	isTerminating_ = true;

	if (!IsSuspended())
		return;

	Suspend(params);
	::ContinueDebugEvent(parentProcess_.GetOSID(), GetOSID(), DBG_CONTINUE);
}

/*
 * Program hardware debug address register dRegs (DR0-DR3) with addr and enable
 * it in DR7. dr7bits goes into that register's 4-bit R/W+LEN field (DR7 bits
 * 16-19, 20-23, 24-27 or 28-31), and its local-enable bit L0-L3 is set.
 *
 * A running thread is suspended around the update and resumed afterwards. An
 * unrecognized dRegs leaves the context untouched.
 */
void WinThread::SetDebugRegister(WinHWBkptMgr::DRMask dRegs, ContextAddress addr, WinHWBkptMgr::DRFlags dr7bits) {
	bool suspended = isSuspended_;
	if (!suspended)
		Suspend();
	// The whole CONTEXT is written back below, so make sure it is current
	// (no-op if Suspend() just refreshed it).
	EnsureValidContextInfo();

	// Widen before shifting: a promoted int shifted left by 28 can overflow.
	const DWORD fieldBits = (DWORD) dr7bits;
	bool validRegister = true;
	DWORD& DR7 = threadContextInfo_.Dr7;
	switch (dRegs) {
	case DR0:
		threadContextInfo_.Dr0 = addr;
		DR7 = (DR7 & ~0x000F0003) | (fieldBits<<16) | 0x01;
		break;
	case DR1:
		threadContextInfo_.Dr1 = addr;
		DR7 = (DR7 & ~0x00F0000C) | (fieldBits<<20) | 0x04;
		break;
	case DR2:
		threadContextInfo_.Dr2 = addr;
		DR7 = (DR7 & ~0x0F000030) | (fieldBits<<24) | 0x10;
		break;
	case DR3:
		threadContextInfo_.Dr3 = addr;
		// L3/G3 are bits 6-7 (0xC0). The mask used to be 0x30, which cleared
		// DR2's enable bits instead and silently disabled a DR2 watchpoint.
		DR7 = (DR7 & ~0xF00000C0) | (fieldBits<<28) | 0x40;
		break;
	default:
		validRegister = false;
		break;
	}

	if (validRegister)
		::SetThreadContext(handle_, &threadContextInfo_);

	// Always undo our own Suspend(), even for an invalid dRegs; that path
	// used to return early and leave the thread suspended.
	if (!suspended)
		Resume();
}

/*
 * Disable hardware debug register singleDReg (DR0-DR3) by clearing its
 * enable bits (L/G) and its R/W+LEN field in DR7.
 *
 * A running thread is suspended around the update and resumed afterwards. An
 * unrecognized singleDReg leaves the context untouched.
 */
void WinThread::ClearDebugRegister(WinHWBkptMgr::DRMask singleDReg) {
	bool suspended = isSuspended_;
	if (!suspended)
		Suspend();
	// The whole CONTEXT is written back below, so make sure it is current
	// (no-op if Suspend() just refreshed it).
	EnsureValidContextInfo();

	bool validRegister = true;
	DWORD& DR7 = threadContextInfo_.Dr7;
	switch (singleDReg)	{
	case DR0:		DR7 &= ~0x000F0003;		break;
	case DR1:		DR7 &= ~0x00F0000C;		break;
	case DR2:		DR7 &= ~0x0F000030;		break;
	case DR3:		DR7 &= ~0xF00000C0;		break;
	default:		validRegister = false;	break;
	}

	// A single SetThreadContext() suffices (the call used to be duplicated).
	if (validRegister)
		::SetThreadContext(handle_, &threadContextInfo_);

	// Always undo our own Suspend(), even for an invalid register; that path
	// used to return early and leave the thread suspended.
	if (!suspended)
		Resume();
}

void WinThread::HandleHardwareBreak() {
	const HANDLE& procHandle = parentProcess_.GetProcessHandle();
	ContextAddress* dr = (ContextAddress*)&(threadContextInfo_.Dr0);
	for (int i = 0; i < 4; ++i) {
		if (threadContextInfo_.Dr6 & (1 << i)) {
			TBreakpoint* wp
			  = BreakpointsService::FindWatchpointByAddress(procHandle, dr[i]);
			if (wp) {
				std::ostringstream wpAddr; wpAddr << std::hex << wp->address;
				EventClientNotifier::SendContextSuspended(this,
						GetPCAddress(), REASON_WATCHPOINT, wpAddr.str());
				break;
			}
		}
	}
}
