Python传参传的到底是值还是地址?

Python经验总结

Posted by Tianhao Alex Huang on 2023-03-19
Estimated Reading Time 6 Minutes
Words 1.5k In Total
Viewed Times

起因

在看一篇论文的代码,发现代码里是这么写的:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
trainer = FlashbackTrainer(setting.lambda_t, setting.lambda_s, setting.lambda_loc, setting.lambda_user,
setting.use_weight, transition_graph, spatial_graph, friend_graph, setting.use_graph_user,
setting.use_spatial_graph, interact_graph) # 0.01, 100 or 1000
h0_strategy = create_h0_strategy(setting.hidden_dim, setting.is_lstm) # 10 True or False
trainer.prepare(poi_loader.locations(), poi_loader.user_count(), setting.hidden_dim, setting.rnn_factory,
setting.device)
evaluation_test = Evaluation(dataset_test, dataloader_test, poi_loader.user_count(), h0_strategy, trainer, setting, log)
print('{} {}'.format(trainer, setting.rnn_factory))

# training loop
optimizer = torch.optim.Adam(trainer.parameters(), lr=setting.learning_rate, weight_decay=setting.weight_decay)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[20, 40, 60, 80], gamma=0.2)

bar = tqdm(total=setting.epochs)
bar.set_description('Training')

for e in range(setting.epochs): # 100
h = h0_strategy.on_init(setting.batch_size, setting.device)
dataset.shuffle_users() # shuffle users before each epoch!

losses = []
epoch_start = time.time()
for i, (x, t, t_slot, s, y, y_t, y_t_slot, y_s, reset_h, active_users) in enumerate(dataloader):
# reset hidden states for newly added users
for j, reset in enumerate(reset_h):
if reset:
if setting.is_lstm:
hc = h0_strategy.on_reset(active_users[0][j])
h[0][0, j] = hc[0]
h[1][0, j] = hc[1]
else:
h[0, j] = h0_strategy.on_reset(active_users[0][j])

x = x.squeeze().to(setting.device)
t = t.squeeze().to(setting.device)
t_slot = t_slot.squeeze().to(setting.device)
s = s.squeeze().to(setting.device)

y = y.squeeze().to(setting.device)
y_t = y_t.squeeze().to(setting.device)
y_t_slot = y_t_slot.squeeze().to(setting.device)
y_s = y_s.squeeze().to(setting.device)
active_users = active_users.to(setting.device)

optimizer.zero_grad()
forward_start = time.time()
loss = trainer.loss(x, t, t_slot, s, y, y_t, y_t_slot, y_s, h, active_users)

# print('One forward: ', time.time() - forward_start)

start = time.time()
loss.backward(retain_graph=True)

# torch.nn.utils.clip_grad_norm_(trainer.parameters(), 5)
end = time.time()
# print('反向传播需要{}s'.format(end - start))
losses.append(loss.item())
optimizer.step()

# schedule learning rate:
scheduler.step()
bar.update(1)
epoch_end = time.time()
log_string(log, 'One training need {:.2f}s'.format(
epoch_end - epoch_start))
# statistics:
if (e + 1) % 1 == 0:
epoch_loss = np.mean(losses)
# print(f'Epoch: {e + 1}/{setting.epochs}')
# print(f'Used learning rate: {scheduler.get_last_lr()[0]}')
# print(f'Avg Loss: {epoch_loss}')
log_string(log, f'Epoch: {e + 1}/{setting.epochs}')
log_string(log, f'Used learning rate: {scheduler.get_last_lr()[0]}')
log_string(log, f'Avg Loss: {epoch_loss}')

# if (e + 1) >= 21: # 第25轮效果最好, 直接评估这一轮 (e + 1) % setting.validate_epoch == 0 or
# if (e + 1) == 23 or (e + 1) == 43:
if (e + 1) % setting.validate_epoch == 0:
log_string(log, f'~~~ Test Set Evaluation (Epoch: {e + 1}) ~~~')
print(f'~~~ Test Set Evaluation (Epoch: {e + 1}) ~~~')
evl_start = time.time()
evaluation_test.evaluate(e)
evl_end = time.time()
# print('评估需要{:.2f}'.format(evl_end - evl_start))
log_string(log, 'One evaluate need {:.2f}s'.format(evl_end - evl_start))

bar.close()

重点来看这几行:

1
2
3
4
5
6
trainer = FlashbackTrainer(setting.lambda_t, setting.lambda_s, setting.lambda_loc, setting.lambda_user,
setting.use_weight, transition_graph, spatial_graph, friend_graph, setting.use_graph_user,
setting.use_spatial_graph, interact_graph) # 0.01, 100 or 1000
h0_strategy = create_h0_strategy(setting.hidden_dim, setting.is_lstm) # 10 True or False
trainer.prepare(poi_loader.locations(), poi_loader.user_count(), setting.hidden_dim, setting.rnn_factory, setting.device)
evaluation_test = Evaluation(dataset_test, dataloader_test, poi_loader.user_count(), h0_strategy, trainer, setting, log)

这里evaluation_test在实例化Evaluation类时传入了trainer这个参数,而trainer是FlashbackTrainer类。此时,trainer还未训练,参数全部都是初始化的。

而在后面的循环之中才开始训练trainer,这里就让我产生了一个疑惑,这里的参数传进去的到底是值还是地址呢?我猜测是地址,如果是值的话,每一轮的evaluate将没有意义,因为他一直在evaluate最开始的参数。

解决

所以我写了一个代码来验证这一猜想:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
class A(object):
def __init__(self):
self.a = 0
self.b = 0

def change(self):
self.b += 1
self.a -= 1


class B(object):
def __init__(self, A):
self.A = A

def print_A(self):
print("A.a: ", self.A.a)
print("A.b: ", self.A.b)


if __name__ == '__main__':
a = A()
b = B(a)
b.print_A()
a.change()
b.print_A()

结果不出我所料:

1
2
3
4
A.a:  0
A.b: 0
A.a: -1
A.b: 1

证明将类作为传入其实传的是地址而不是类的值。

更进一步

这一问题引发了我对Python传参机制的好奇,Google之,发现一篇高质量blog:深度好文! Python函数参数传递:到底是值传递还是传用传递?

Python的基本数据类型?

首先我们要知道Python有6大基本数据类型:

  • 数值类型: int, float, bool, complex
  • 字符串: str, 元字符串(即r+str, 不转义字符)
  • 元组: tuple
  • 列表: list
  • 字典: dict
  • 集合: set

可变数据类型 or 不可变数据类型?

在这6种数据类型中,可变数据类型有:

  • 列表: list
  • 字典: dict
  • 集合: set

不可变数据类型有:

  • 数值类型: int, float, bool, complex
  • 字符串: str, 元字符串(即r+str, 不转义字符)
  • 元组: tuple

值传递 or 引用传递?

  • 值传递(Pass-By-Value)
    被调函数的形参作为被调函数的局部变量来处理,简单来说就是在栈中新开开辟了内存空间来存放主调函数传过来的实参的。此时函数中的形参实际为主调函数实参的副本。因此,被调函数中对形参的操作并不会影响外部实参的值。

  • 引用传递(Pass-By-Reference)
    被调函数的形参同样作为被调函数的局部变量来处理,但此时在栈中开辟的内存空间存放的是主调函数传过来的实参的地址。此时,被调函数对形参的操作实际为间接寻址,即通过堆栈中存放的地址来访问主调函数中的实参变量。因此,被调函数中对形参的操作影响外部实参的值。

那么,什么时候是值传递,什么时候是引用传递呢?

Python解释器会查看对象引用(内存地址)指示的那个值的类型,如果变量指示一个可变的值,就会按引用调用语义。如果所指示的数据的类型是不可变的,则会应用按值调用语义。

因此,对于可变数据:

  • 列表: list
  • 字典: dict
  • 集合: set

总是会按引用传入函数,函数代码组中对变量数据结构的任何改变都会反映到调用代码中。

而对于可变不数据:

  • 数值类型: int, float, bool, complex
  • 字符串: str, 元字符串(即r+str, 不转义字符)
  • 元组: tuple

总是会按值传入函数,函数中对变量的任何修改是这个函数私有的,不会反映到调用代码中。


如果您喜欢此博客或发现它对您有用,则欢迎对此发表评论。 也欢迎您共享此博客,以便更多人可以参与。 如果博客中使用的图像侵犯了您的版权,请与作者联系以将其删除。 谢谢 !