tinyllm-week1

Day 1: Attention and Multi-Head Attention

Task 1: Implement scaled_dot_product_attention_simple
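Scaled dot-product attention takes a query Q, key K, and value V and computes

Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V

where d_k is the size of the last axis of the query/key. An optional additive mask is added to the scores before the softmax, and a caller-supplied scale can replace the default 1 / sqrt(d_k).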

The implementation:

import math

import mlx.core as mx

# `softmax` is the helper function implemented earlier in the course
# (equivalent to mx.softmax along the given axis).


def scaled_dot_product_attention_simple(
    query: mx.array,
    key: mx.array,
    value: mx.array,
    scale: float | None = None,
    mask: mx.array | None = None,
) -> mx.array:
    # Default to the standard 1 / sqrt(d_k) scaling if no scale is given.
    if scale is None:
        scale = 1.0 / math.sqrt(query.shape[-1])
    scores = mx.matmul(query, key.swapaxes(-2, -1)) * scale
    if mask is not None:
        # The mask is additive: masked positions hold -inf, so their weight becomes 0.
        scores = scores + mask
    return mx.matmul(softmax(scores, axis=-1), value)
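A minimal shape check (the batch, head, sequence, and head-dimension sizes below are made up for illustration, not the values used by the tests):

# Illustrative shapes only: (batch, heads, seq_len, head_dim)
q = mx.random.normal(shape=(1, 4, 8, 16))
k = mx.random.normal(shape=(1, 4, 8, 16))
v = mx.random.normal(shape=(1, 4, 8, 16))
out = scaled_dot_product_attention_simple(q, k, v)
print(out.shape)  # (1, 4, 8, 16) -- same shape as the value tensor

Because the attention math only touches the last two axes, any leading batch or head axes are simply carried through by the batched matmul.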

Test result: (screenshot of the Task 1 test output)

Task 2: Implement SimpleMultiHeadAttention
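Multi-head attention projects the input into num_heads lower-dimensional subspaces, runs scaled dot-product attention in each head independently, then concatenates the heads and applies an output projection:

MultiHead(Q, K, V) = Concat(head_1, ..., head_h) W_O, where head_i = Attention(Q W_Q_i, K W_K_i, V W_V_i)

Instead of h separate projection matrices per input, the implementation uses a single wq/wk/wv matrix each and reshapes the result into (N, num_heads, L, head_dim).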

# `linear` is the course-provided linear-projection helper, like `softmax` above.
class SimpleMultiHeadAttention:
    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        wq: mx.array,
        wk: mx.array,
        wv: mx.array,
        wo: mx.array,
    ):
        # The input is split across heads, so hidden_size must be divisible by num_heads.
        assert hidden_size % num_heads == 0
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        # Dimension of each individual head.
        self.head_dim = hidden_size // num_heads
        # None lets scaled_dot_product_attention_simple fall back to 1 / sqrt(head_dim).
        self.scale = None
        assert wq.shape == (hidden_size, num_heads * self.head_dim)
        assert wk.shape == (hidden_size, num_heads * self.head_dim)
        assert wv.shape == (hidden_size, num_heads * self.head_dim)
        assert wo.shape == (num_heads * self.head_dim, hidden_size)

        self.wq = wq
        self.wk = wk
        self.wv = wv
        self.wo = wo

    def __call__(
        self,
        query: mx.array,
        key: mx.array,
        value: mx.array,
        mask: mx.array | None = None,
    ) -> mx.array:
        assert query.shape == key.shape == value.shape
        # Inputs are (N, L, D): batch size, sequence length, hidden size.
        N, L, _ = query.shape
        # Project, split into heads, and move the head axis forward: (N, num_heads, L, head_dim).
        xq = linear(query, self.wq).reshape(N, L, self.num_heads, self.head_dim).transpose(0, 2, 1, 3)
        xk = linear(key, self.wk).reshape(N, L, self.num_heads, self.head_dim).transpose(0, 2, 1, 3)
        xv = linear(value, self.wv).reshape(N, L, self.num_heads, self.head_dim).transpose(0, 2, 1, 3)

        # Attention runs independently per head over the last two axes.
        x = scaled_dot_product_attention_simple(xq, xk, xv, scale=self.scale, mask=mask)
        # Merge the heads back: (N, L, num_heads * head_dim) == (N, L, hidden_size).
        x = x.transpose(0, 2, 1, 3).reshape(N, L, self.hidden_size)
        # Final output projection.
        return linear(x, self.wo)
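A usage sketch with random stand-in weights and made-up sizes, for a quick shape check only (not the values from the tests):

# Illustrative sizes and random weights.
hidden_size, num_heads, N, L = 16, 4, 2, 5
wq = mx.random.normal(shape=(hidden_size, hidden_size))
wk = mx.random.normal(shape=(hidden_size, hidden_size))
wv = mx.random.normal(shape=(hidden_size, hidden_size))
wo = mx.random.normal(shape=(hidden_size, hidden_size))
attn = SimpleMultiHeadAttention(hidden_size, num_heads, wq, wk, wv, wo)
x = mx.random.normal(shape=(N, L, hidden_size))
print(attn(x, x, x).shape)  # (2, 5, 16) -- output keeps the input shape

Splitting heads with a single reshape plus transpose keeps the whole computation as batched matmuls instead of a Python loop over heads.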

Test result: (screenshot of the Task 2 test output)