import math

import torch
from torch import nn

# Save to the d2l package.
class DotProductAttention(nn.Module):
    def __init__(self, dropout, **kwargs):
        super(DotProductAttention, self).__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)

    # query: (batch_size, #queries, d)
    # key: (batch_size, #kv_pairs, d)
    # value: (batch_size, #kv_pairs, dim_v)
    # valid_length: either (batch_size, ) or (batch_size, xx)
    def forward(self, query, key, value, valid_length=None):
        d = query.shape[-1]
        # Swap the last two dimensions of key so that bmm computes the
        # query-key dot products, then scale by sqrt(d).
        scores = torch.bmm(query, key.transpose(1, 2)) / math.sqrt(d)
        # masked_softmax (defined earlier in this notebook) ignores positions
        # beyond valid_length before normalizing.
        attention_weights = self.dropout(masked_softmax(scores, valid_length))
        print("attention_weight\n", attention_weights)
        return torch.bmm(attention_weights, value)
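As a quick shape check (a sketch in the spirit of the notebook, assuming the masked_softmax helper from earlier in this section is in scope), two batches with one 2-dimensional query each attend over ten key-value pairs, of which only the first 2 and 6 are valid:

atten = DotProductAttention(dropout=0.5)
keys = torch.ones(2, 10, 2)
values = torch.arange(40, dtype=torch.float32).view(1, 10, 4).repeat(2, 1, 1)
# The output has shape (batch_size, #queries, dim_v) = (2, 1, 4).
atten(torch.ones(2, 1, 2), keys, values, torch.FloatTensor([2, 6]))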
import collections
import sys
import zipfile
from io import BytesIO

import requests
import torch
from torch.utils import data
class Vocab(object):  # This class is saved in d2l.
    def __init__(self, tokens, min_freq=0, use_special_tokens=False):
        # Sort by frequency and token
        counter = collections.Counter(tokens)
        token_freqs = sorted(counter.items(), key=lambda x: x[0])
        token_freqs.sort(key=lambda x: x[1], reverse=True)
        if use_special_tokens:
            # padding, begin of sentence, end of sentence, unknown
            self.pad, self.bos, self.eos, self.unk = (0, 1, 2, 3)
            tokens = ['<pad>', '<bos>', '<eos>', '<unk>']
        else:
            self.unk = 0
            tokens = ['<unk>']
        tokens += [token for token, freq in token_freqs if freq >= min_freq]
        self.idx_to_token = []
        self.token_to_idx = dict()
        for token in tokens:
            self.idx_to_token.append(token)
            self.token_to_idx[token] = len(self.idx_to_token) - 1

    def __len__(self):
        return len(self.idx_to_token)

    def __getitem__(self, tokens):
        if not isinstance(tokens, (list, tuple)):
            return self.token_to_idx.get(tokens, self.unk)
        else:
            return [self.__getitem__(token) for token in tokens]

    def to_tokens(self, indices):
        if not isinstance(indices, (list, tuple)):
            return self.idx_to_token[indices]
        else:
            return [self.idx_to_token[index] for index in indices]
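A small usage sketch (the tokens and the commented index values are illustrative, not part of the original notebook): build a vocabulary and convert between tokens and indices.

vocab = Vocab(['go', '.', 'go', 'now', '.'], min_freq=1, use_special_tokens=True)
len(vocab)                   # 7: four special tokens plus '.', 'go', 'now'
vocab[['go', 'now']]         # [5, 6]
vocab.to_tokens([4, 5])      # ['.', 'go']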
def load_data_nmt(batch_size, max_len, num_examples=1000):
    """Load an NMT dataset, return its vocabularies and data iterator."""
    # Preprocess
    def preprocess_raw(text):
        # Replace non-breaking spaces with ordinary spaces, lowercase, and
        # insert a space before punctuation that does not already follow one.
        text = text.replace('\u202f', ' ').replace('\xa0', ' ')
        out = ''
        for i, char in enumerate(text.lower()):
            if char in (',', '!', '.') and text[i - 1] != ' ':
                out += ' '
            out += char
        return out
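    # Illustrative example (not in the original cell): preprocess_raw("Go now.")
    # would return "go now ." -- a space is inserted before the final period
    # because it does not already follow a space.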
    with open('/home/kesci/input/fraeng6506/fra.txt', 'r') as f:
        raw_text = f.read()

    text = preprocess_raw(raw_text)
    # Tokenize
    source, target = [], []
    for i, line in enumerate(text.split('\n')):
        if i >= num_examples:
            break
        parts = line.split('\t')
        if len(parts) >= 2:
            source.append(parts[0].split(' '))
            target.append(parts[1].split(' '))
    # Build vocab
    def build_vocab(tokens):
        tokens = [token for line in tokens for token in line]
        return Vocab(tokens, min_freq=3, use_special_tokens=True)

    src_vocab, tgt_vocab = build_vocab(source), build_vocab(target)
    # Convert to index arrays
    def pad(line, max_len, padding_token):
        if len(line) > max_len:
            return line[:max_len]
        return line + [padding_token] * (max_len - len(line))
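    # Illustrative examples (not in the original cell):
    #   pad([3, 4], 4, 0)          -> [3, 4, 0, 0]
    #   pad([1, 2, 3, 4, 5], 4, 0) -> [1, 2, 3, 4]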
    def build_array(lines, vocab, max_len, is_source):
        lines = [vocab[line] for line in lines]
        if not is_source:
            lines = [[vocab.bos] + line + [vocab.eos] for line in lines]
        array = torch.tensor([pad(line, max_len, vocab.pad) for line in lines])
        valid_len = (array != vocab.pad).sum(1)
        return array, valid_len
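    # Illustrative example (not in the original cell): for a target line
    # tokenized as ['bonjour', '.'] and max_len = 5, build_array wraps it as
    # [<bos>, bonjour, ., <eos>, <pad>] before converting to indices, so the
    # corresponding valid_len entry is 4 (padding is not counted).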