import math

import mlx.core as mx

# `linear` and `scaled_dot_product_attention_simple` are assumed to be in
# scope here (defined earlier in the text); hypothetical stand-ins are
# sketched after the class.


class SimpleMultiHeadAttention:
    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        wq: mx.array,
        wk: mx.array,
        wv: mx.array,
        wo: mx.array,
    ):
        assert hidden_size % num_heads == 0
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.head_dim = hidden_size // num_heads
        # Standard attention scaling of 1 / sqrt(head_dim).
        self.scale = 1.0 / math.sqrt(self.head_dim)
        assert wq.shape == (hidden_size, num_heads * self.head_dim)
        assert wk.shape == (hidden_size, num_heads * self.head_dim)
        assert wv.shape == (hidden_size, num_heads * self.head_dim)
        assert wo.shape == (num_heads * self.head_dim, hidden_size)
        self.wq = wq
        self.wk = wk
        self.wv = wv
        self.wo = wo
    def __call__(
        self,
        query: mx.array,
        key: mx.array,
        value: mx.array,
        mask: mx.array | None = None,
    ) -> mx.array:
        assert query.shape == key.shape == value.shape
        N, L, _ = query.shape
        # Project to per-head queries/keys/values and move the head axis
        # ahead of the sequence axis: (N, L, E) -> (N, H, L, D).
        xq = (
            linear(query, self.wq)
            .reshape(N, L, self.num_heads, self.head_dim)
            .transpose(0, 2, 1, 3)
        )
        xk = (
            linear(key, self.wk)
            .reshape(N, L, self.num_heads, self.head_dim)
            .transpose(0, 2, 1, 3)
        )
        xv = (
            linear(value, self.wv)
            .reshape(N, L, self.num_heads, self.head_dim)
            .transpose(0, 2, 1, 3)
        )
        x = scaled_dot_product_attention_simple(
            xq, xk, xv, scale=self.scale, mask=mask
        )
        # Merge the heads back: (N, H, L, D) -> (N, L, E), then apply the
        # output projection.
        x = x.transpose(0, 2, 1, 3).reshape(N, L, self.hidden_size)
        return linear(x, self.wo)
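# A minimal smoke test for the class above. The two helpers below are
# hypothetical stand-ins, not the text's own definitions: `linear` is assumed
# to be a plain matrix product with an (in_features, out_features) weight, and
# `scaled_dot_product_attention_simple` is assumed to be the standard
# softmax(q @ k^T * scale) @ v attention with an optional additive mask. The
# real definitions from earlier in the text may use different conventions.


def linear(x: mx.array, w: mx.array) -> mx.array:
    # Assumed convention: project the last axis with w of shape (in, out).
    return x @ w


def scaled_dot_product_attention_simple(
    q: mx.array,
    k: mx.array,
    v: mx.array,
    scale: float | None = None,
    mask: mx.array | None = None,
) -> mx.array:
    # Standard scaled dot-product attention over the last two axes.
    scale = scale if scale is not None else 1.0 / math.sqrt(q.shape[-1])
    scores = (q @ mx.swapaxes(k, -2, -1)) * scale
    if mask is not None:
        scores = scores + mask  # assumes an additive mask
    return mx.softmax(scores, axis=-1) @ v


if __name__ == "__main__":
    # Random weights of the asserted shapes; note hidden_size == num_heads * head_dim.
    hidden_size, num_heads, N, L = 64, 4, 2, 8
    wq, wk, wv, wo = (
        mx.random.normal((hidden_size, hidden_size)) for _ in range(4)
    )
    attn = SimpleMultiHeadAttention(hidden_size, num_heads, wq, wk, wv, wo)
    x = mx.random.normal((N, L, hidden_size))
    out = attn(x, x, x)
    print(out.shape)  # -> (N, L, hidden_size) = (2, 8, 64)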