#ifndef ___THREADING_H_INCLUDED___ #define ___THREADING_H_INCLUDED___ #include <windows.h> #include "mem_manager.h" /////////////////////////////////////////////////////////////////////////////////////////////// //---------------------------- CLASS ---------------------------------------------------------- class CThread; class CSimpleThreadPool; /////////////////////////////////////////////////////////////////////////////////////////////// // The core interface of this threading implementation. class IRunnable { friend class CThread; friend class CSimpleThreadPool; protected: virtual void run() = 0; }; /////////////////////////////////////////////////////////////////////////////////////////////// //---------------------------- CLASS ---------------------------------------------------------- // structure holds Thread Local Storage descriptor struct TLSDescriptor { DWORD descriptor; TLSDescriptor() { descriptor = TlsAlloc(); if (descriptor == -1) throw "TlsAlloc() failed to create descriptor"; }; ~TLSDescriptor() { TlsFree(descriptor); }; }; /////////////////////////////////////////////////////////////////////////////////////////////// //---------------------------- CLASS ---------------------------------------------------------- // Very basic thread class, implementing basic threading API class CThread: public IRunnable { private: static TLSDescriptor m_TLSDesc; volatile bool m_bIsInterrupted; volatile bool m_bIsRunning; int m_nThreadPriority; IRunnable *m_RunObj; QMutex m_qMutex; // See ::CreateThread(...) within the start() method. This is // the thread's API function to be executed. Method executes // run() method of the CThread instance passed as parameter. static DWORD WINAPI StartThreadST(LPVOID PARAM) { CThread *_this = (CThread *) PARAM; if (_this != NULL) { _this->m_qMutex.Lock(); // Set the pointer to the instance of the passed CThread // in the current Thread's Local Storage. Also see // currentThread() method. TlsSetValue(CThread::m_TLSDesc.descriptor, (LPVOID) _this); _this->run(); _this->m_bIsRunning = false; _this->m_qMutex.Unlock(); } return 0; }; protected: // It is not possible to instantiate CThread objects directly. Possible only by // specifying a IRunnable object to execute its run() method. CThread(int nPriority = THREAD_PRIORITY_NORMAL): m_qMutex() { this->m_bIsInterrupted = false; this->m_bIsRunning = false; this->m_nThreadPriority = nPriority; this->m_RunObj = NULL; }; // this implementation of the run() will execute the passed IRunnable // object (if not null). Inheriting class is responsible for using this // method or overriding it with a different one. virtual void run() { if (this->m_RunObj != NULL) this->m_RunObj->run(); }; public: CThread(IRunnable *RunTask, int nPriority = THREAD_PRIORITY_NORMAL): m_qMutex() { this->m_bIsInterrupted = false; this->m_bIsRunning = false; this->m_nThreadPriority = nPriority; if (this != RunTask) this->m_RunObj = RunTask; else throw "Self referencing not allowed."; }; virtual ~CThread() { this->interrupt(); // wait until thread ends this->join(); }; // Method returns the instance of a CThread responsible // for the context of the current thread. static CThread& currentThread() { CThread *thr = (CThread *) TlsGetValue(CThread::m_TLSDesc.descriptor); if (thr == NULL) throw "Call is not within a CThread context."; return *thr; }; // Method signals thread to stop execution. void interrupt() { this->m_bIsInterrupted = true; }; // Check if thread was signaled to stop. bool isInterrupted() { return this->m_bIsInterrupted; }; // Method will wait for thread's termination. void join() { this->m_qMutex.Lock(); this->m_qMutex.Unlock(); }; // Method starts the Thread. If thread is already started/running, method // will simply return. void start() { HANDLE hThread; LPTHREAD_START_ROUTINE pStartRoutine = &CThread::StartThreadST; if (this->m_qMutex.TryLock()) { if (!this->m_bIsRunning) { this->m_bIsRunning = true; this->m_bIsInterrupted = false; hThread = ::CreateThread(NULL, 0, pStartRoutine, (PVOID) this, 0, NULL); if (hThread == NULL) { this->m_bIsRunning = false; this->m_qMutex.Unlock(); throw "Failed to call CreateThread(). Thread not started."; } ::SetThreadPriority(hThread, this->m_nThreadPriority); ::CloseHandle(hThread); } this->m_qMutex.Unlock(); } }; }; TLSDescriptor CThread::m_TLSDesc; /////////////////////////////////////////////////////////////////////////////////////////////// //---------------------------- CLASS ---------------------------------------------------------- // Helper class to submit tasks to the CSimpleThreadPool class CPriorityTask { private: int m_nPriority; IRunnable *m_pRunObj; public: CPriorityTask(const CPriorityTask &t) { m_pRunObj = t.m_pRunObj; m_nPriority = t.m_nPriority; }; CPriorityTask() { m_pRunObj = NULL; m_nPriority = 0; }; CPriorityTask(IRunnable *pRunObj, int nPriority = 0) { m_pRunObj = pRunObj; m_nPriority = nPriority; }; int getPriority() const { return m_nPriority; }; IRunnable *getTask() const { return m_pRunObj; }; ~CPriorityTask() {}; CPriorityTask& operator=(const CPriorityTask &t) { m_nPriority = t.m_nPriority; m_pRunObj = t.m_pRunObj; return *this; }; }; //Overload the < operator. bool operator< (const CPriorityTask& pt1, const CPriorityTask& pt2) { return pt1.getPriority() < pt2.getPriority(); } //Overload the > operator. bool operator> (const CPriorityTask& pt1, const CPriorityTask& pt2) { return pt1.getPriority() > pt2.getPriority(); } /////////////////////////////////////////////////////////////////////////////////////////////// //---------------------------- CLASS ---------------------------------------------------------- // A class containing a collection of CThreadTask's. // Every CThreadTask will execute same CSimpleThreadPool::run() method. class CSimpleThreadPool: public IRunnable { private: QMutex m_qMutex; vector<CThread*> m_arrThreadTasks; mpriority_queue<CPriorityTask> m_PQueue; // Method will return a task from the queue, // if there are no tasks in the queue, method will return NULL. IRunnable *get() { IRunnable *ret = NULL; m_qMutex.Lock(); if (!m_PQueue.empty()) { CPriorityTask t = m_PQueue.top(); m_PQueue.pop(); ret = t.getTask(); } m_qMutex.Unlock(); return ret; }; public: // How many threads are in the collection. int threads() const { m_arrThreadTasks.size(); }; // Method starts pool's threads. void startAll() { for (unsigned int i = 0; i < m_arrThreadTasks.size(); i++) { m_arrThreadTasks[i]->start(); } }; // Constructor creates the thread pool and sets capacity for the task queue. CSimpleThreadPool(unsigned int nThreadsCount, unsigned int nQueueCapacity = 16): m_qMutex(), m_PQueue() { unsigned int i; CThread *thTask = NULL; if (nThreadsCount <= 0) throw "Invalid number of threads supplied."; if (nQueueCapacity <= 0) throw "Invalid capacity supplied."; // Set initial capacity of the tasks Queue. m_PQueue.reserve(nQueueCapacity); // Initialize thread pool. for (i = 0; i < nThreadsCount; i++) { thTask = new CThread(this); if (thTask != NULL) m_arrThreadTasks.push_back(thTask); } }; // Submit a new task to the pool void submit(IRunnable *pRunObj, int nPriority = 0) { if (this == pRunObj) throw "Self referencing not allowed."; m_qMutex.Lock(); m_PQueue.push(CPriorityTask(pRunObj, nPriority)); m_qMutex.Unlock(); }; // Method will execute task's run() method within its CThread context. virtual void run() { IRunnable *task; while (!CThread::currentThread().isInterrupted()) { // Get a task from the queue. task = get(); // Execute the task. if (task != NULL) task->run(); ::Sleep(2); } }; virtual ~CSimpleThreadPool() { vector<CThread*>::iterator itPos = m_arrThreadTasks.begin(); for (; itPos < m_arrThreadTasks.end(); itPos++) delete *itPos; m_arrThreadTasks.clear(); while (!m_PQueue.empty()) { m_PQueue.pop(); } }; // Method signals threads to stop and waits for termination void shutdown() { for (unsigned int i = 0; i < m_arrThreadTasks.size(); i++) { m_arrThreadTasks[i]->interrupt(); m_arrThreadTasks[i]->join(); } }; }; #endif