热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

类型错误:线性():参数“输入”(位置1)必须是张量,而不是str

所以我一直在尝试研究我在github上发现的一些bert示例,这是我第一次尝试使用bert并查看它是如何工作的。使用的呼吸即时消息如下:https://github.com

所以我一直在尝试研究我在 github 上发现的一些 bert 示例,这是我第一次尝试使用 bert 并查看它是如何工作的。使用的呼吸即时消息如下:https : //github.com/prateekjoshi565/Fine-Tuning-BERT/blob/master/Fine_Tuning_BERT_for_Spam_Classification.ipynb

我使用了不同的数据集,但是我遇到了问题 TypeError: linear(): argument 'input' (position 1) must be Tensor, not str" 老实说,我不知道我做错了什么。有没有人可以帮助我?

我一直在使用的代码如下:

# convert class weights to tensor
weights= torch.tensor(class_wts,dtype=torch.float)
weights = weights.to(device)
# loss function
cross_entropy = nn.NLLLoss(weight=weights)
# number of training epochs
epochs = 10
def train():
model.train()
total_loss, total_accuracy = 0, 0
# empty list to save model predictions
total_preds=[]
# iterate over batches
for step,batch in enumerate(train_dataloader): # progress update after every 50 batches.
if step % 50 == 0 and not step == 0:
print(' Batch {:>5,} of {:>5,}.'.format(step, len(train_dataloader)))
# push the batch to gpu
batch = [r.to(device) for r in batch]
sent_id, mask, labels = batch
# clear previously calculated gradients
model.zero_grad()
# get model predictions for the current batch
preds = model(sent_id, mask)
# compute the loss between actual and predicted values
loss = cross_entropy(preds, labels)
# add on to the total loss
total_loss = total_loss + loss.item()
# backward pass to calculate the gradients
loss.backward()
# clip the the gradients to 1.0. It helps in preventing the exploding gradient problem
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
# update parameters
optimizer.step()
# model predictions are stored on GPU. So, push it to CPU
preds=preds.detach().cpu().numpy()
# append the model predictions
total_preds.append(preds)
# compute the training loss of the epoch
avg_loss = total_loss / len(train_dataloader)
# predictions are in the form of (no. of batches, size of batch, no. of classes).
# reshape the predictions in form of (number of samples, no. of classes)
total_preds = np.concatenate(total_preds, axis=0)
#returns the loss and predictions
return avg_loss, total_preds
def evaluate():
print("nEvaluating...")
# deactivate dropout layers
model.eval()
total_loss, total_accuracy = 0, 0
# empty list to save the model predictions
total_preds = []
# iterate over batches
for step,batch in enumerate(val_dataloader): # Progress update every 50 batches.
if step % 50 == 0 and not step == 0:
# Calculate elapsed time in minutes.
elapsed = format_time(time.time() - t0)

# Report progress.
print(' Batch {:>5,} of {:>5,}.'.format(step, len(val_dataloader)))
# push the batch to gpu
batch = [t.to(device) for t in batch]
sent_id, mask, labels = batch
# deactivate autograd
with torch.no_grad():
# model predictions
preds = model(sent_id, mask)
# compute the validation loss between actual and predicted values
loss = cross_entropy(preds,labels)
total_loss = total_loss + loss.item()
preds = preds.detach().cpu().numpy()
total_preds.append(preds)
# compute the validation loss of the epoch
avg_loss = total_loss / len(val_dataloader)
# reshape the predictions in form of (number of samples, no. of classes)
total_preds = np.concatenate(total_preds, axis=0)
return avg_loss, total_preds
# set initial loss to infinite
best_valid_loss = float('inf')
# empty lists to store training and validation loss of each epoch
train_losses=[]
valid_losses=[]
#for each epoch
for epoch in range(epochs):

print('n Epoch {:} / {:}'.format(epoch + 1, epochs)) #train model
train_loss, _ = train() #evaluate model
valid_loss, _ = evaluate() #save the best model
if valid_loss best_valid_loss = valid_loss
torch.save(model.state_dict(), 'saved_weights.pt') # append training and validation loss
train_losses.append(train_loss)
valid_losses.append(valid_loss) print(f'nTraining Loss: {train_loss:.3f}')
print(f'Validation Loss: {valid_loss:.3f}')

我收到的回溯是:

Epoch 1 / 10
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
in ()
12
13 #train model
---> 14 train_loss, _ = train()
15
16 #evaluate model
5 frames
in train()
24
25 # get model predictions for the current batch
---> 26 preds = model(sent_id, mask)
27
28 # compute the loss between actual and predicted values
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
887 result = self._slow_forward(*input, **kwargs)
888 else:
--> 889 result = self.forward(*input, **kwargs)
890 for hook in itertools.chain(
891 _global_forward_hooks.values(),
in forward(self, sent_id, mask)
28 _, cls_hs = self.bert(sent_id, attention_mask=mask)
29
---> 30 x = self.fc1(cls_hs)
31
32 x = self.relu(x)
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
887 result = self._slow_forward(*input, **kwargs)
888 else:
--> 889 result = self.forward(*input, **kwargs)
890 for hook in itertools.chain(
891 _global_forward_hooks.values(),
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/linear.py in forward(self, input)
92
93 def forward(self, input: Tensor) -> Tensor:
---> 94 return F.linear(input, self.weight, self.bias)
95
96 def extra_repr(self) -> str:
/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py in linear(input, weight, bias)
1751 if has_torch_function_variadic(input, weight):
1752 return handle_torch_function(linear, (input, weight), input, weight, bias=bias)
-> 1753 return torch._C._nn.linear(input, weight, bias)
1754
1755
TypeError: linear(): argument 'input' (position 1) must be Tensor, not str

回答

我也一直在研究这个 repo。受到此链接上提供的答案的启发。有一个可能名为 Bert_Arch 的类继承了 nn.Module,这个类有一个名为 forward 的重写方法。在 forward 方法中,只需将参数 'return_dict=False' 添加到 self.bert() 方法调用中。像这样:

_, cls_hs = self.bert(sent_id, attention_mask=mask, return_dict=False)

这对我有用。






推荐阅读
  • 微软头条实习生分享深度学习自学指南
    本文介绍了一位微软头条实习生自学深度学习的经验分享,包括学习资源推荐、重要基础知识的学习要点等。作者强调了学好Python和数学基础的重要性,并提供了一些建议。 ... [详细]
  • 提升Python编程效率的十点建议
    本文介绍了提升Python编程效率的十点建议,包括不使用分号、选择合适的代码编辑器、遵循Python代码规范等。这些建议可以帮助开发者节省时间,提高编程效率。同时,还提供了相关参考链接供读者深入学习。 ... [详细]
  • 本文介绍了在开发Android新闻App时,搭建本地服务器的步骤。通过使用XAMPP软件,可以一键式搭建起开发环境,包括Apache、MySQL、PHP、PERL。在本地服务器上新建数据库和表,并设置相应的属性。最后,给出了创建new表的SQL语句。这个教程适合初学者参考。 ... [详细]
  • 向QTextEdit拖放文件的方法及实现步骤
    本文介绍了在使用QTextEdit时如何实现拖放文件的功能,包括相关的方法和实现步骤。通过重写dragEnterEvent和dropEvent函数,并结合QMimeData和QUrl等类,可以轻松实现向QTextEdit拖放文件的功能。详细的代码实现和说明可以参考本文提供的示例代码。 ... [详细]
  • 本文分享了一个关于在C#中使用异步代码的问题,作者在控制台中运行时代码正常工作,但在Windows窗体中却无法正常工作。作者尝试搜索局域网上的主机,但在窗体中计数器没有减少。文章提供了相关的代码和解决思路。 ... [详细]
  • 开发笔记:加密&json&StringIO模块&BytesIO模块
    篇首语:本文由编程笔记#小编为大家整理,主要介绍了加密&json&StringIO模块&BytesIO模块相关的知识,希望对你有一定的参考价值。一、加密加密 ... [详细]
  • 本文介绍了九度OnlineJudge中的1002题目“Grading”的解决方法。该题目要求设计一个公平的评分过程,将每个考题分配给3个独立的专家,如果他们的评分不一致,则需要请一位裁判做出最终决定。文章详细描述了评分规则,并给出了解决该问题的程序。 ... [详细]
  • 本文介绍了计算机网络的定义和通信流程,包括客户端编译文件、二进制转换、三层路由设备等。同时,还介绍了计算机网络中常用的关键词,如MAC地址和IP地址。 ... [详细]
  • 拥抱Android Design Support Library新变化(导航视图、悬浮ActionBar)
    转载请注明明桑AndroidAndroid5.0Loollipop作为Android最重要的版本之一,为我们带来了全新的界面风格和设计语言。看起来很受欢迎࿰ ... [详细]
  • CF:3D City Model(小思维)问题解析和代码实现
    本文通过解析CF:3D City Model问题,介绍了问题的背景和要求,并给出了相应的代码实现。该问题涉及到在一个矩形的网格上建造城市的情景,每个网格单元可以作为建筑的基础,建筑由多个立方体叠加而成。文章详细讲解了问题的解决思路,并给出了相应的代码实现供读者参考。 ... [详细]
  • FeatureRequestIsyourfeaturerequestrelatedtoaproblem?Please ... [详细]
  • 本文讨论了在openwrt-17.01版本中,mt7628设备上初始化启动时eth0的mac地址总是随机生成的问题。每次随机生成的eth0的mac地址都会写到/sys/class/net/eth0/address目录下,而openwrt-17.01原版的SDK会根据随机生成的eth0的mac地址再生成eth0.1、eth0.2等,生成后的mac地址会保存在/etc/config/network下。 ... [详细]
  • 本文介绍了Python爬虫技术基础篇面向对象高级编程(中)中的多重继承概念。通过继承,子类可以扩展父类的功能。文章以动物类层次的设计为例,讨论了按照不同分类方式设计类层次的复杂性和多重继承的优势。最后给出了哺乳动物和鸟类的设计示例,以及能跑、能飞、宠物类和非宠物类的增加对类数量的影响。 ... [详细]
  • 模板引擎StringTemplate的使用方法和特点
    本文介绍了模板引擎StringTemplate的使用方法和特点,包括强制Model和View的分离、Lazy-Evaluation、Recursive enable等。同时,还介绍了StringTemplate语法中的属性和普通字符的使用方法,并提供了向模板填充属性的示例代码。 ... [详细]
  • 本文介绍了在满足特定条件时如何在输入字段中使用默认值的方法和相应的代码。当输入字段填充100或更多的金额时,使用50作为默认值;当输入字段填充有-20或更多(负数)时,使用-10作为默认值。文章还提供了相关的JavaScript和Jquery代码,用于动态地根据条件使用默认值。 ... [详细]
author-avatar
黄体测字_335
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有