
http://www.ma-xy.com
1.6 对抗生成网络 GAN 第一章 深度学习
def xavier_init(size):
    """Build a random-normal initial value for a weight tensor.

    Args:
        size: shape of the tensor; size[0] is treated as the fan-in.

    Returns:
        A tensor of shape `size` drawn from N(0, stddev^2) with
        stddev = 1 / sqrt(size[0] / 2)  (i.e. sqrt(2 / fan_in)).
    """
    fan_in = size[0]
    scale = 1. / tf.sqrt(fan_in / 2.)
    return tf.random_normal(shape=size, stddev=scale)
# --- Discriminator network parameters ---
# Input placeholders: samples X and conditioning vectors y.
# X's width uses X_dim (previously a hard-coded 784) so it always agrees with
# the X_dim + y_dim input rows of D_W1 and the X_dim-wide generator output.
X = tf.placeholder(tf.float32, shape=[None, X_dim])
y = tf.placeholder(tf.float32, shape=[None, y_dim])

# Hidden layer weights: (X_dim + y_dim) -> h_dim
D_W1 = tf.Variable(xavier_init([X_dim + y_dim, h_dim]))
D_b1 = tf.Variable(tf.zeros(shape=[h_dim]))
# Output layer weights: h_dim -> 1 logit
D_W2 = tf.Variable(xavier_init([h_dim, 1]))
D_b2 = tf.Variable(tf.zeros(shape=[1]))

# Collected discriminator parameters (presumably the optimizer's var_list;
# the optimizer setup is not visible here).
theta_D = [D_W1, D_W2, D_b1, D_b2]
def discriminator(x, y):
    """Score (x, y) pairs with the one-hidden-layer conditional discriminator.

    x and y are joined along axis 1, sent through a ReLU hidden layer
    (D_W1, D_b1), then through a linear output layer (D_W2, D_b2).

    Returns:
        (D_prob, D_logit): the sigmoid of the output logit, and the raw logit.
    """
    joined = tf.concat(values=[x, y], axis=1)
    hidden = tf.nn.relu(tf.matmul(joined, D_W1) + D_b1)
    logit = tf.matmul(hidden, D_W2) + D_b2
    prob = tf.nn.sigmoid(logit)
    return prob, logit
# --- Generator network parameters ---
# Noise input placeholder, Z_dim values per row.
Z = tf.placeholder(dtype=tf.float32, shape=[None, Z_dim])

# Hidden layer weights: (Z_dim + y_dim) -> h_dim
G_W1 = tf.Variable(initial_value=xavier_init([Z_dim + y_dim, h_dim]))
G_b1 = tf.Variable(initial_value=tf.zeros(shape=[h_dim]))
# Output layer weights: h_dim -> X_dim
G_W2 = tf.Variable(initial_value=xavier_init([h_dim, X_dim]))
G_b2 = tf.Variable(initial_value=tf.zeros(shape=[X_dim]))

# Collected generator parameters.
theta_G = [G_W1, G_W2, G_b1, G_b2]
def generator(z, y):
    """Map noise z and condition y to an X_dim-wide output in (0, 1).

    z and y are joined along axis 1, passed through a ReLU hidden layer
    (G_W1, G_b1) and a linear layer (G_W2, G_b2), then squashed by a sigmoid.
    """
    hidden = tf.nn.relu(tf.matmul(tf.concat(values=[z, y], axis=1), G_W1) + G_b1)
    return tf.nn.sigmoid(tf.matmul(hidden, G_W2) + G_b2)
def sample_Z(m, n):
    """Return an (m, n) array of noise drawn uniformly from [-1, 1)."""
    low, high = -1., 1.
    return np.random.uniform(low=low, high=high, size=(m, n))
def plot(samples, nrows=4, ncols=4, image_shape=(28, 28)):
    """Draw samples as grayscale images on an nrows x ncols grid.

    Args:
        samples: iterable of arrays, each reshapeable to `image_shape`.
        nrows, ncols: grid dimensions. The defaults reproduce the original 4 x 4
            layout; enlarge them to show more than 16 samples.
        image_shape: shape each sample is reshaped to (default 28 x 28).

    Returns:
        The matplotlib Figure. It is not saved or closed here, so the caller
        must do that.
    """
    fig = plt.figure(figsize=(ncols, nrows))
    gs = gridspec.GridSpec(nrows, ncols)
    gs.update(wspace=0.05, hspace=0.05)

    for i, sample in enumerate(samples):
        # Draw on the explicit Axes rather than pyplot's implicit "current" one.
        ax = fig.add_subplot(gs[i])
        ax.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        ax.imshow(sample.reshape(image_shape), cmap='Greys_r')

    return fig
# Build the generator output, then score real and generated samples with the
# same discriminator weights, both conditioned on y.
G_sample = generator(Z, y)
D_real, D_logit_real = discriminator(X, y)
D_fake, D_logit_fake = discriminator(G_sample, y)

# Discriminator loss: sigmoid cross-entropy that pushes real logits toward
# label 1 and fake logits toward label 0, each averaged with reduce_mean.
D_loss_real = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(
        logits=D_logit_real,
        labels=tf.ones_like(D_logit_real),
    )
)
D_loss_fake = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(
        logits=D_logit_fake,
        labels=tf.zeros_like(D_logit_fake),
    )
)
D_loss = D_loss_real + D_loss_fake
http://www.ma-xy.com 92 http://www.ma-xy.com