Computing similarity between two 2-D tensors (matrices)

Date: 2023-11-01 19:26:49

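Note: torch.cosine_similarity compares the two matrices row by row along the chosen dim, so they must have the same shape (or shapes that broadcast to a common shape); otherwise the call raises an error.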
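To make the shape requirement concrete, here is a minimal sketch using small hand-written tensors. The values and the variable names (`a`, `b`, `c`, `row_sim`, `mismatch_failed`) are illustrative assumptions and are not taken from the BERT script.

```python
import torch

# Two matrices with the same shape (3 rows, 2 columns).
a = torch.tensor([[1.0, 0.0], [1.0, 1.0], [0.0, 2.0]])
b = torch.tensor([[2.0, 0.0], [1.0, -1.0], [0.0, -3.0]])

# dim=1 compares row i of `a` with row i of `b`, giving one value per row.
row_sim = torch.cosine_similarity(a, b, dim=1)  # shape (3,), roughly [1, 0, -1]

# A matrix with a different number of rows cannot be compared row by row.
c = torch.randn(4, 2)
try:
    torch.cosine_similarity(a, c, dim=1)
    mismatch_failed = False
except RuntimeError:
    # 3 rows vs 4 rows cannot be broadcast to a common shape.
    mismatch_failed = True

print(row_sim, mismatch_failed)
```

The same situation arises with BERT outputs whenever two batches are padded to different sequence lengths.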

import torch
from transformers import BertConfig, BertModel, BertTokenizer


def bert_output(texts, name):
    # Tokenize each sentence and build token ids, segment ids and attention masks.
    tokens, segments, input_masks = [], [], []
    for text in texts:
        tokenized_text = tokenizer.tokenize(text)
        indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
        tokens.append(indexed_tokens)
        segments.append([0] * len(indexed_tokens))
        input_masks.append([1] * len(indexed_tokens))

    max_len = max(len(single) for single in tokens)  # length of the longest sentence
    # Pad every sentence in this batch to max_len with zeros.
    for j in range(len(tokens)):
        padding = [0] * (max_len - len(tokens[j]))
        tokens[j] += padding
        segments[j] += padding
        input_masks[j] += padding

    # device = torch.cuda.current_device()
    tokens_tensor = torch.tensor(tokens)
    segments_tensors = torch.tensor(segments)
    input_masks_tensors = torch.tensor(input_masks)

    # output = model(tokens_tensor)
    # Pass the tensors by keyword: in transformers the second positional argument
    # of BertModel is attention_mask, not token_type_ids.
    with torch.no_grad():  # inference only, no gradients needed
        output = model(
            input_ids=tokens_tensor,
            attention_mask=input_masks_tensors,
            token_type_ids=segments_tensors,
        )
    sequence_output = output[0]  # per-token vectors: (batch, seq_len, hidden)
    pooled_output = output[1]    # pooled [CLS] vector: (batch, hidden)
    torch.set_printoptions(edgeitems=30)
    # with open(name, 'a', encoding='utf-8') as f:
    #     # f.write("sequence_output:")
    #     f.write(str(sequence_output))
    #     # f.write('\n')
    #     # f.write("pooled_output:")
    #     # f.write(str(pooled_output))
    return sequence_output  # per-token output, used for slots


if __name__ == '__main__':
    tokenizer = BertTokenizer.from_pretrained('./bert-base-uncased')
    model_config = BertConfig.from_pretrained('./bert-base-uncased')
    model = BertModel.from_pretrained('./bert-base-uncased', config=model_config)

    # texts_atis = ["[CLS] i want to fly from baltimore to dallas round trip [SEP]",
    #               "[CLS] how can i find that out [SEP]",
    #               "[CLS] how many flights does twa have in business class [SEP]"]
    texts_atis = ["[CLS] i want to fly from baltimore to dallas round trip [SEP]"]
    texts_snips = ["[CLS] what the weather in my current spot the [SEP]",
                   "[CLS] what the weather like in the city frewen [SEP]",
                   "[CLS] what the weather supposed to be like today [SEP]"]

    # Output files (only used by the commented-out write above)
    atis = 'atis.txt'
    snips = 'snips.txt'
    atis_out = bert_output(texts_atis, atis)     # BERT output vectors for ATIS
    snips_out = bert_output(texts_snips, snips)  # BERT output vectors for SNIPS

    for atis_2 in atis_out:       # each item is a 2-D matrix (seq_len, hidden)
        # print(list(atis_2.size()))
        for snips_2 in snips_out:  # each item is a 2-D matrix (seq_len, hidden)
            # Row-wise cosine similarity. Both matrices must have the same number
            # of rows (padded token count), otherwise this call raises an error.
            output = torch.cosine_similarity(atis_2, snips_2, dim=1)
            print(output)
            # print(list(snips_2.size()))
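Because the two batches in the script above are padded independently, their matrices can end up with different numbers of rows, which is exactly the case the note warns about. The following is a minimal sketch of two possible workarounds, not the original author's method: an all-pairs token similarity matrix, and a single sentence-level score from mean-pooled vectors. The helper names `pairwise_cosine` and `sentence_cosine` are hypothetical, and random tensors stand in for real BERT outputs.

```python
import torch
import torch.nn.functional as F


def pairwise_cosine(x, y):
    """Cosine similarity between every row of x (n, d) and every row of y (m, d).

    Returns a matrix of shape (n, m); n and m may differ.
    """
    x_norm = F.normalize(x, p=2, dim=1)  # scale each row to unit length
    y_norm = F.normalize(y, p=2, dim=1)
    return x_norm @ y_norm.T


def sentence_cosine(x, y):
    """Cosine similarity between the mean of x's rows and the mean of y's rows."""
    return torch.cosine_similarity(x.mean(dim=0), y.mean(dim=0), dim=0)


torch.manual_seed(0)
# Stand-ins for two sequence outputs with different padded lengths (assumed sizes).
atis_like = torch.randn(12, 768)
snips_like = torch.randn(11, 768)

sim_matrix = pairwise_cosine(atis_like, snips_like)  # shape (12, 11)
sent_sim = sentence_cosine(atis_like, snips_like)    # 0-dimensional tensor
print(sim_matrix.shape, sent_sim.item())
```

With real BERT outputs, the rows that correspond to padding would also enter the mean and the similarity matrix, so it may be worth dropping them with the attention mask first.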
