用 CRTP 技巧实现通用的 IUnknown 实现
2024-03-17 17:33:41

最近在写一个 COM 组件，没有使用 ATL 模板，而是用纯 C++ 实现。
碰到一件麻烦事：每个接口的实现类都要重复实现 AddRef、Release 和 QueryInterface 这三个方法。

比如有两个接口：

// Two example COM interfaces; each derives from IUnknown.
// (Interface-specific pure virtual methods omitted in this example.)
struct IA : IUnknown { /* IA's own methods */ };
struct IB : IUnknown { /* IB's own methods */ };

那么每个实现类都要各自实现一遍 IUnknown 接口：

// Implementation class for IA: it has to spell out the IUnknown boilerplate
// (bodies omitted here) and keep its own reference counter.
class A : public IA
{
public:
    virtual HRESULT STDMETHODCALLTYPE QueryInterface(REFIID riid, void** ppv); // body omitted
    virtual ULONG STDMETHODCALLTYPE AddRef();                                  // body omitted
    virtual ULONG STDMETHODCALLTYPE Release();                                 // body omitted

private:
    ULONG _ref_count;
};

// Implementation class for IB: the very same IUnknown boilerplate as in A
// (bodies omitted here), duplicated once more.
class B : public IB
{
public:
    virtual HRESULT STDMETHODCALLTYPE QueryInterface(REFIID riid, void** ppv); // body omitted
    virtual ULONG STDMETHODCALLTYPE AddRef();                                  // body omitted
    virtual ULONG STDMETHODCALLTYPE Release();                                 // body omitted

private:
    ULONG _ref_count;
};

在查阅了一些资料后，我发现 CRTP（Curiously Recurring Template Pattern，奇异递归模板模式）技巧可以很好地解决引用计数代码复用的需求：

#pragma once

#include <vector>
#include <Windows.h>

template<class T>
class RefCounted
{
public:
RefCounted()
{
_ref_count = 1;
//::InterlockedIncrement(&g_obj_count);
}

virtual ~RefCounted()
{
//::InterlockedDecrement(&g_obj_count);
}

HRESULT Query(REFIID riid, void** ppv)
{
if (riid == IID_IUnknown)
{
*ppv = static_cast<T*>(this);
static_cast<IUnknown*>(*ppv)->AddRef();
return S_OK;
}

HRESULT hr = E_NOINTERFACE;
for (std::size_t i = 0; i < _interfaces.size(); ++i)
{
if (riid == _interfaces[i].first)
{
*ppv = _interfaces[i].second;
static_cast<IUnknown*>(*ppv)->AddRef();
hr = S_OK;
break;
}
}

return hr;
}

ULONG Increment()
{
return InterlockedIncrement(reinterpret_cast<volatile LONG*>(&_ref_count));
}

ULONG Decrement()
{
const ULONG ref_count = InterlockedDecrement(reinterpret_cast<volatile LONG*>(&_ref_count));
if (ref_count == 0)
delete static_cast<T*>(this);
return ref_count;
}

protected:
void AddInterface(REFIID iid, void* p)
{
_interfaces.push_back(std::make_pair(iid, p));
}

private:
std::vector<std::pair<IID, void*>> _interfaces;
volatile ULONG _ref_count;
};

当派生类需要引用计数时，这样使用：

// Example COM object exposing two interfaces, IRead and IWrite (declared
// elsewhere). RefCounted is inherited privately so its helpers do not become
// part of MyFile's public surface.
class MyFile :
    private RefCounted<MyFile>,
    public IRead,
    public IWrite
{
    using super = RefCounted<MyFile>;

    // Lets RefCounted<MyFile> downcast `this` to MyFile* (static_cast<T*>)
    // even though it is a private base; without this friendship such a
    // downcast is ill-formed.
    friend class RefCounted<MyFile>;

public:
    MyFile()
    {
        // Register the supported interfaces. Pass each interface subobject
        // pointer (via static_cast), not `this` itself.
        super::AddInterface(__uuidof(IRead), static_cast<IRead*>(this));
        super::AddInterface(__uuidof(IWrite), static_cast<IWrite*>(this));
    }

    /* IUnknown: forward to RefCounted. This single set of overrides
       implements the IUnknown methods inherited via both IRead and IWrite. */
    HRESULT STDMETHODCALLTYPE QueryInterface(REFIID riid, void** ppv) override { return super::Query(riid, ppv); }
    ULONG STDMETHODCALLTYPE AddRef() override { return super::Increment(); }
    ULONG STDMETHODCALLTYPE Release() override { return super::Decrement(); }

    /* IRead implementation goes here */
    /* IWrite implementation goes here */
};

通过 CRTP 技巧，引用计数和接口查询的实现被集中到了基类中。派生类只需注册所支持的接口，并用三行代码转发 IUnknown 的方法，其余精力都可以放在业务接口的实现上。