Featured image of post C++ 智能指针实现之shared_ptr

C++ 智能指针实现之shared_ptr

智能指针本质上就是利用 RAII 资源管理功能,本文介绍实现 C++中智能指针的 shared_ptr。

前言

智能指针本质上并不神秘,其实就是 RAII 资源管理功能的自然展现而已。本文将介绍如何实现 C++中智能指针的 shared_ptr。

原理简介

多个不同的 shared_ptr 不仅可以共享一个对象,在共享同一对象时也需要同时共享同一个计数。当最后一个指向对象(和共享计数)的 shared_ptr 析构时,它需要删除对象和共享计数。

实现过程

我们先实现共享计数的接口,这个 shared_count 类除构造函数之外有三个方法:一个增加计数,一个减少计数,一个获取计数。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
class shared_count {
public:
  shared_count() : count_(1) {}
  void add_count()
  {
    ++count_;
  }
  long reduce_count()
  {
    return --count_;
  }
  long get_count() const
  {
    return count_;
  }

private:
  long count_;
};

接下来可以实现我们的引用计数智能指针了。首先是构造函数、析构函数和私有成员变量:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
template <typename T>
class shared_ptr {
public:
  explicit shared_ptr(T* ptr = nullptr)
    : ptr_(ptr)
  {
    if (ptr) {
      shared_count_ = new shared_count();
    }
  }
  ~shared_ptr()
  {
    if (ptr_ && !shared_count_->reduce_count()) {
      delete ptr_;
      delete shared_count_;
    }
  }

private:
  T* ptr_;
  shared_count* shared_count_;
};

构造函数会构造一个 shared_count 出来。析构函数在看到 ptr_ 非空时(此时根据代码逻辑, shared_count 也必然非空),需要对引用数减一,并在引用数降到零时彻底删除对象和共享计数。

为了方便实现赋值(及其他一些惯用法),我们需要一个新的 swap 成员函数:

1
2
3
4
5
6
void swap(shared_ptr& rhs)
{
  using std::swap;
  swap(ptr_, rhs.ptr_);
  swap(shared_count_, rhs.shared_count_);
}

赋值函数,拷贝构造和移动构造函数的实现:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
shared_ptr(const shared_ptr& other)
{
  ptr_ = other.ptr_;
  if (ptr_) {
    other.shared_count_->add_count();
    shared_count_ = other.shared_count_;
  }
}
template <typename U>
shared_ptr(const shared_ptr<U>& other) noexcept
{
  ptr_ = other.ptr_;
  if (ptr_) {
    other.shared_count_->add_count();
    shared_count_ = other.shared_count_;
  }
}
template <typename U>
shared_ptr(shared_ptr<U>&& other) noexcept
{
  ptr_ = other.ptr_;
  if (ptr_) {
    shared_count_ = other.shared_count_;
    other.ptr_ = nullptr;
  }
}

除复制指针之外,对于拷贝构造的情况,我们需要在指针非空时把引用数加一,并复制共享计数的指针。对于移动构造的情况,我们不需要调整引用数,直接把 other.ptr_ 置为空,认为 other 不再指向该共享对象即可。

不过,上面的代码有个问题:它不能正确编译。编译器会报错,像:

fatal error: ‘ptr_’ is a private member of ‘shared_ptr

错误原因是模板的各个实例间并不天然就有 friend 关系,因而不能互访私有成员 ptr_shared_count_。我们需要在 shared_ptr 的定义中显式声明:

1
2
template <typename U>
friend class shared_ptr;

返回引用计数值

接下来,创建一个返回引用计数值的函数

1
2
3
4
5
6
7
8
9
long use_count() const
{
  if (ptr_) {
    return shared_count_
      ->get_count();
  } else {
    return 0;
  }
}

指针类型转换

对应于 C++ 里的不同的类型强制转换:

  • static_cast
  • reinterpret_cast
  • const_cast
  • dynamic_cast

智能指针需要实现类似的函数模板。实现本身并不复杂,但为了实现这些转换,我们需要添加构造函数,允许在对智能指针内部的指针对象赋值时,使用一个现有的智能指针的共享计数。如下所示:

1
2
3
4
5
6
7
8
9
template <typename U>
shared_ptr(const shared_ptr<U>& other, T* ptr)
{
    ptr_ = ptr;
    if (ptr_) {
        other.shared_count_->add_count();
        shared_count_ = other.shared_count_;
    }
}

这样我们就可以实现转换所需的函数模板了。下面实现一个 dynamic_pointer_cast 来示例一下:

1
2
3
4
5
6
template <typename T, typename U>
shared_ptr<T> dynamic_pointer_cast(const shared_ptr<U>& other)
{
  T* ptr = dynamic_cast<T*>(other.get());
  return shared_ptr<T>(other, ptr);
}

验证

我们可以用下面的代码来验证一下它的功能正常:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
#include <iostream>
class shape {
public:
  virtual ~shape() {}
};

class circle : public shape {
public:
  ~circle() { std::cout<<"~circle()\n"; }
};

int main()
{
  shared_ptr<circle> ptr1(new circle());
  std::cout << "use count of ptr1 is" << ptr1.use_count() << "\n";
  shared_ptr<shape> ptr2;
  std::cout << "use count of ptr2 was " << ptr2.use_count() << "\n";
  ptr2 = ptr1;
  std::cout << "use count of ptr2 is now " << ptr2.use_count() << "\n";
  if (ptr1) {
    std::cout<<"ptr1 is not empty\n";
  }

  shared_ptr<circle> ptr3 = dynamic_pointer_cast<circle>(ptr2);
  std::cout << "use count of ptr3 is " << ptr3.use_count() << "\n";
}

输出:

1
2
3
4
5
6
use count of ptr1 is1
use count of ptr2 was 0
use count of ptr2 is now 2
ptr1 is not empty
use count of ptr3 is 3
~circle()

完整代码

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
#include <utility>  // std::swap

class shared_count {
public:
  shared_count() noexcept : count_(1) {}
  void add_count() noexcept
  {
    ++count_;
  }
  long reduce_count() noexcept
  {
    return --count_;
  }
  long get_count() const noexcept
  {
    return count_;
  }

private:
  long count_;
};

template <typename T>
class shared_ptr {
public:
  template <typename U>
  friend class shared_ptr;

  explicit shared_ptr(T* ptr = nullptr) : ptr_(ptr)
  {
    if (ptr) {
      shared_count_ = new shared_count();
    }
  }
  ~shared_ptr()
  {
    if (ptr_ && !shared_count_->reduce_count()) {
      delete ptr_;
      delete shared_count_;
    }
  }

  shared_ptr(const shared_ptr& other)
  {
    ptr_ = other.ptr_;
    if (ptr_) {
      other.shared_count_->add_count();
      shared_count_ = other.shared_count_;
    }
  }
  template <typename U>
  shared_ptr(const shared_ptr<U>& other) noexcept
  {
    ptr_ = other.ptr_;
    if (ptr_) {
      other.shared_count_->add_count();
      shared_count_ = other.shared_count_;
    }
  }
  template <typename U>
  shared_ptr(shared_ptr<U>&& other) noexcept
  {
    ptr_ = other.ptr_;
    if (ptr_) {
      shared_count_ = other.shared_count_;
      other.ptr_ = nullptr;
    }
  }
  template <typename U>
  shared_ptr(const shared_ptr<U>& other, T* ptr) noexcept
  {
    ptr_ = ptr;
    if (ptr_) {
      other.shared_count_->add_count();
      shared_count_ = other.shared_count_;
    }
  }
  shared_ptr& operator=(shared_ptr rhs) noexcept
  {
    rhs.swap(*this);
    return *this;
  }

  T* get() const noexcept
  {
    return ptr_;
  }
  long use_count() const noexcept
  {
    if (ptr_) {
      return shared_count_->get_count();
    }
    else {
      return 0;
    }
  }
  void swap(shared_ptr& rhs) noexcept
  {
    using std::swap;
    swap(ptr_, rhs.ptr_);
    swap(shared_count_, rhs.shared_count_);
  }

  T& operator*() const noexcept
  {
    return *ptr_;
  }
  T* operator->() const noexcept
  {
    return ptr_;
  }
  operator bool() const noexcept
  {
    return ptr_;
  }

private:
  T* ptr_;
  shared_count* shared_count_;
};

template <typename T>
void swap(shared_ptr<T>& lhs, shared_ptr<T>& rhs) noexcept
{
  lhs.swap(rhs);
}

template <typename T, typename U>
shared_ptr<T> static_pointer_cast(const shared_ptr<U>& other) noexcept
{
  T* ptr = static_cast<T*>(other.get());
  return shared_ptr<T>(other, ptr);
}

template <typename T, typename U>
shared_ptr<T> reinterpret_pointer_cast(const shared_ptr<U>& other) noexcept
{
  T* ptr = reinterpret_cast<T*>(other.get());
  return shared_ptr<T>(other, ptr);
}

template <typename T, typename U>
shared_ptr<T> const_pointer_cast(const shared_ptr<U>& other) noexcept
{
  T* ptr = const_cast<T*>(other.get());
  return shared_ptr<T>(other, ptr);
}

template <typename T, typename U>
shared_ptr<T> dynamic_pointer_cast(const shared_ptr<U>& other) noexcept
{
  T* ptr = dynamic_cast<T*>(other.get());
  return shared_ptr<T>(other, ptr);
}

在代码里加了不少 noexcept。这对这个智能指针在它的目标场景能正确使用是十分必要的。

总结

我们实现了一个基本完整的带引用计数的shared_ptr智能指针。从而对智能指针有一个比较深入的理解。当然,这里与标准的std::shared_ptr还欠缺一些东西,比如多线程安全、不支持自定义删除器以及和std::weak_ptr的配合。

《现代 C++编程实战》

Licensed under CC BY-NC-SA 4.0
最后更新于 Jan 11, 2024 20:05 +0800