What is Candidate Sampling

Say we have a multiclass or multi-label problem where each training example (x_i, T_i) consists of a context x_i and a small (multi)set of target classes T_i out of a large universe L of possible classes. For example, the problem might be to predict the next word (or the set of future words) in a sentence given the previous words.

We wish to learn a compatibility function F(x, y) which says something about the compatibility of a class y with a context x, for example the probability of the class given the context.

"Exhaustive" training methods such as softmax and logistic regression require us to compute F(x, y) for every class y ∈ L for every training example. When |L| is very large, this can be prohibitively expensive.

"Candidate Sampling" training methods instead construct a training task in which, for each training example (x_i, T_i), we only need to evaluate F(x, y) for a small set of candidate classes C_i ⊂ L. Typically, the set of candidates C_i is the union of the target classes with a randomly chosen sample of (other) classes S_i ⊂ L:

  C_i = T_i ∪ S_i

The random choice of S_i may or may not depend on x_i and/or T_i.

The training algorithm takes the form of a neural network, where the layer representing F(x, y) is trained by back-propagation from a loss function.
Table of Candidate Sampling Algorithms

For each algorithm, POS_i and NEG_i are the positive and negative training classes associated with training example (x_i, T_i), G(x, y) is the input to the training loss, and the last column is what F(x, y) gets trained to approximate.

Algorithm         | POS_i       | NEG_i     | Input to loss G(x, y)  | Loss     | F(x, y) approximates
------------------|-------------|-----------|------------------------|----------|---------------------------------------------
NCE               | T_i         | S_i       | F(x, y) − log(Q(y|x))  | Logistic | log(P(y|x))
Negative Sampling | T_i         | S_i       | F(x, y)                | Logistic | log(P(y|x) / Q(y|x))
Sampled Logistic  | T_i         | S_i − T_i | F(x, y) − log(Q(y|x))  | Logistic | log-odds(y|x) = log(P(y|x) / (1 − P(y|x)))
Full Logistic     | T_i         | L − T_i   | F(x, y)                | Logistic | log-odds(y|x) = log(P(y|x) / (1 − P(y|x)))
Full Softmax      | T_i = {t_i} | L − T_i   | F(x, y)                | Softmax  | log(P(y|x)) + K(x)
Sampled Softmax   | T_i = {t_i} | S_i − T_i | F(x, y) − log(Q(y|x))  | Softmax  | log(P(y|x)) + K(x)

Q(y|x) is defined as the probability (or expected count), according to the sampling algorithm, of the class y in the (multi)set of sampled classes given the context x. K(x) is an arbitrary function that does not depend on the candidate class. Since softmax involves a normalization, adding such a function does not affect the computed probabilities.

The two training losses are:

  logistic training loss = Σ_i ( Σ_{y ∈ POS_i} log(1 + exp(−G(x_i, y))) + Σ_{y ∈ NEG_i} log(1 + exp(G(x_i, y))) )

  softmax training loss = Σ_i ( −G(x_i, t_i) + log( Σ_{y ∈ POS_i ∪ NEG_i} exp(G(x_i, y)) ) )

NCE and Negative Sampling generalize to the case where T_i is a multiset. In this case, P(y|x) denotes the expected count of y in T_i. Similarly, NCE, Negative Sampling, and Sampled Logistic generalize to the case where S_i is a multiset. In this case, Q(y|x) denotes the expected count of y in S_i.
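As a concrete illustration, the two generic training losses above can be sketched for a single training example as follows. This is a minimal NumPy sketch; the function names and the per-example calling convention are our own, not from the table.

```python
import numpy as np

def logistic_loss(g_pos, g_neg):
    """Per-example logistic loss: log(1 + exp(-G)) summed over positives
    plus log(1 + exp(G)) summed over negatives."""
    g_pos, g_neg = np.asarray(g_pos, float), np.asarray(g_neg, float)
    return np.sum(np.log1p(np.exp(-g_pos))) + np.sum(np.log1p(np.exp(g_neg)))

def softmax_loss(g_target, g_all):
    """Per-example softmax loss: -G(x, t) + log(sum over POS+NEG of exp(G))."""
    g_all = np.asarray(g_all, float)
    m = g_all.max()  # subtract the max before exponentiating, for stability
    return -g_target + m + np.log(np.sum(np.exp(g_all - m)))
```

For example, with one positive and one negative score both equal to 0, the logistic loss is 2·log(2); with two equal candidate scores the softmax loss is log(2).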
Sampled Softmax

(A faster way to train a softmax classifier)

Reference: http://arxiv.org/abs/1412.2007

Assume that we have a single-label problem. Each training example (x_i, {t_i}) consists of a context and one target class. We write P(y|x) for the probability that the one target class is y given that the context is x.

We would like to train a function F(x, y) to produce softmax logits, that is, relative log probabilities of the class given the context:

  F(x, y) ← log(P(y|x)) + K(x)

where K(x) is an arbitrary function that does not depend on y.

In full softmax training, for every training example (x_i, {t_i}), we would need to compute the logits F(x_i, y) for all classes y ∈ L. This can get expensive if the universe of classes L is very large.

In Sampled Softmax, for each training example (x_i, {t_i}), we pick a small set S_i ⊂ L of sampled classes according to a chosen sampling function Q(y|x). Each class y ∈ L is included in S_i independently with probability Q(y|x_i):

  P(S_i = S | x_i) = Π_{y ∈ S} Q(y|x_i) · Π_{y ∈ (L − S)} (1 − Q(y|x_i))

We create a set of candidate classes C_i containing the union of the target class and the sampled classes:

  C_i = S_i ∪ {t_i}

Our training task is to figure out, given this set C_i, which of the classes in C_i is the target class.

For each class y ∈ C_i, we want to compute the posterior probability that y is the target class given our knowledge of x_i and C_i. We call this P(t_i = y | x_i, C_i). Applying Bayes' rule:

  P(t_i = y | x_i, C_i) = P(t_i = y, C_i | x_i) / P(C_i | x_i)
                        = P(t_i = y | x_i) · P(C_i | t_i = y, x_i) / P(C_i | x_i)
                        = P(y | x_i) · P(C_i | t_i = y, x_i) / P(C_i | x_i)
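The independent-inclusion sampling of S_i and the construction of C_i can be sketched as a toy example. The uniform choice of Q and the sizes here are assumptions for illustration only.

```python
import numpy as np

rng = np.random.default_rng(0)
num_classes = 1000                             # assumed |L| for the toy example
q = np.full(num_classes, 20 / num_classes)     # uniform Q(y|x); expected |S| = 20

target = 7                                     # the one target class t
sampled = np.flatnonzero(rng.random(num_classes) < q)   # S: independent inclusion
candidates = np.union1d(sampled, [target])              # C = S ∪ {t}
```

Note that with this scheme |S| is random; its expected size is the sum of the inclusion probabilities.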
Now, to compute P(C_i | t_i = y, x_i), we note that for this to happen, S_i may or may not contain y, must contain all other elements of C_i, and must not contain any classes not in C_i. So:

  P(t_i = y | x_i, C_i)
    = P(y|x_i) · Π_{y′ ∈ C_i − {y}} Q(y′|x_i) · Π_{y′ ∈ (L − C_i)} (1 − Q(y′|x_i)) / P(C_i | x_i)
    = [P(y|x_i) / Q(y|x_i)] · Π_{y′ ∈ C_i} Q(y′|x_i) · Π_{y′ ∈ (L − C_i)} (1 − Q(y′|x_i)) / P(C_i | x_i)
    = [P(y|x_i) / Q(y|x_i)] / K(x_i, C_i)

where K(x_i, C_i) is a function that does not depend on y. So:

  log(P(t_i = y | x_i, C_i)) = log(P(y|x_i)) − log(Q(y|x_i)) + K′(x_i, C_i)

where K′(x_i, C_i) again does not depend on y. These are the relative logits that should feed into a softmax classifier predicting which of the candidates in C_i is the true one.

Since we are trying to train the function F(x, y) to approximate log(P(y|x)), we take the layer in our network representing F(x, y), subtract log(Q(y|x)), and pass the result to a softmax classifier predicting which candidate is the true one:

  Training Softmax Input = F(x, y) − log(Q(y|x))

Back-propagating the gradients from that classifier trains F to give us what we want.
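The training step just described, subtract log(Q) from the logits and apply softmax cross-entropy over the candidates, can be sketched as follows. This is a minimal sketch, not a production implementation; the function name and input layout are assumptions.

```python
import numpy as np

def sampled_softmax_loss(f_candidates, q_candidates, target_index):
    """Softmax cross-entropy over the candidate set C on the corrected
    logits F(x, y) - log(Q(y|x)). Returns -log P(t = y | x, C)."""
    logits = np.asarray(f_candidates, float) - np.log(q_candidates)
    logits -= logits.max()  # shift for numerical stability; cancels in softmax
    log_probs = logits - np.log(np.sum(np.exp(logits)))
    return -log_probs[target_index]
```

With equal scores and a uniform Q over three candidates, the corrected logits are equal and the loss is log(3), as expected for an uninformative classifier.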
Noise Contrastive Estimation (NCE)

Reference: http://www.jmlr.org/proceedings/papers/v9/gutmann10a/gutmann10a.pdf

Each training example (x_i, T_i) consists of a context and a small multiset of target classes. In practice, T_i may always be a set or even a single class, but we use a multiset here for generality.

We use the following as shorthand for the expected count of a class in the multiset of target classes for a context. In the case of sets with no duplicates, this is the probability of the class given the context:

  P(y|x) := E(T_i(y) | x)

We would like to train a function F(x, y) to approximate the log expected count of the class given the context, or, in the case of sets, the log probability of the class given the context:

  F(x, y) ← log(P(y|x))

For each example (x_i, T_i), we pick a multiset of sampled classes S_i. In practice, it probably makes sense to pick a set, but we use a multiset here for generality. Our sampling algorithm may or may not depend on x_i, but may not depend on T_i. We construct a multiset of candidates consisting of the sum of the target classes and the sampled classes:

  C_i = T_i + S_i

Our training task is to distinguish the true candidates from the sampled candidates. We have one positive training (meta-)example for each element of T_i and one negative training (meta-)example for each element of S_i.

We introduce the shorthand Q(y|x) to denote the expected count, according to our sampling algorithm, of a particular class in the multiset of sampled classes. If S_i never contains duplicates, then this is a probability:

  Q(y|x) := E(S_i(y) | x)

  log-odds(y came from T vs S | x) = log(P(y|x) / Q(y|x)) = log(P(y|x)) − log(Q(y|x))

The first term, log(P(y|x)), is what we would like to train F(x, y) to estimate.
We have a layer in our model which represents F(x, y). We add to it the second term, −log(Q(y|x)), which we compute analytically, and we pass the result to a logistic regression loss whose label indicates whether y came from T_i as opposed to S_i:

  Logistic Regression Input = F(x, y) − log(Q(y|x))

The back-propagation signal trains F(x, y) to approximate what we want it to.
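The NCE input just described can be sketched as follows: subtract log(Q(y|x)) from F(x, y) and apply a binary logistic loss with label 1 for y ∈ T_i and 0 for y ∈ S_i. This is a minimal sketch; the function name and the label convention are our own.

```python
import numpy as np

def nce_loss(f, log_q, labels):
    """Binary logistic loss on G = F(x, y) - log(Q(y|x)).

    labels: 1 for classes drawn from T (positives), 0 for classes from S."""
    g = np.asarray(f, float) - np.asarray(log_q, float)
    signs = np.where(np.asarray(labels) == 1, 1.0, -1.0)
    # log(1 + exp(-G)) for positives, log(1 + exp(+G)) for negatives
    return np.sum(np.log1p(np.exp(-signs * g)))
```

With F = log(Q) everywhere, G = 0 and each candidate contributes log(2), the loss of a coin-flip classifier.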
Negative Sampling

Reference: http://papers.nips.cc/paper/5021-distributed-representations-of-words-and-phrases-and-their-compositionality.pdf

Negative Sampling is a simplified variant of Noise Contrastive Estimation in which we neglect to subtract off log(Q(y|x)) during training. As a result, F(x, y) is trained to approximate log(P(y|x)) − log(Q(y|x)).

It is noteworthy that in Negative Sampling we are optimizing F(x, y) to approximate something that depends on the sampling distribution Q. This makes the results highly dependent on the choice of sampling distribution, which is not true for the other algorithms described here.
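The claim that the optimum drifts with Q can be checked numerically. The sketch below (with made-up values of P and Q) minimizes the expected logistic loss P·log(1 + exp(−F)) + Q·log(1 + exp(F)) over a single scalar score F and compares the optimum with log(P/Q).

```python
import numpy as np

p, q = 0.3, 0.05   # assumed expected counts P(y|x) and Q(y|x) for one class
f = 0.0
for _ in range(2000):
    # gradient of the expected per-class logistic loss with respect to F
    grad = -p / (1.0 + np.exp(f)) + q / (1.0 + np.exp(-f))
    f -= 0.5 * grad
print(f, np.log(p / q))
```

Setting the gradient to zero gives p = q·exp(F), i.e. F = log(P/Q), so the converged value matches log(P/Q) rather than log(P): the estimator depends on the sampling distribution.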
Sampled Logistic

Sampled Logistic is a variant of Noise Contrastive Estimation where we discard, without replacement, all sampled classes that happen to also be target classes. This requires T_i to be a set, as opposed to a multiset, though S_i may be a multiset. As a result, we learn an estimator of the log-odds of a class, as opposed to the log probability of a class. The math changes from the NCE math as follows:

  log-odds(y came from T vs (S − T) | x) = log( P(y|x) / (Q(y|x) (1 − P(y|x))) )
                                         = log( P(y|x) / (1 − P(y|x)) ) − log(Q(y|x))

The first term, log(P(y|x) / (1 − P(y|x))), is what we would like to train F(x, y) to estimate.

We have a layer in our model which represents F(x, y). We add to it the second term, −log(Q(y|x)), which we compute analytically, and we pass the result to a logistic regression loss predicting whether y came from T_i vs (S_i − T_i):

  Logistic Regression Input = F(x, y) − log(Q(y|x))

The back-propagation signal trains the F(x, y) layer to approximate what we want it to:

  F(x, y) ≈ log( P(y|x) / (1 − P(y|x)) )
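The one mechanical difference from NCE, discarding sampled classes that collide with target classes before forming the positive and negative examples, can be sketched as follows (helper name and inputs are our own):

```python
def sampled_logistic_pairs(targets, sampled):
    """Return (positives, negatives) = (T, S - T).

    targets: the set T of target classes; sampled: the (multi)set S.
    Sampled classes that are also targets are discarded, so the negatives
    are drawn from S - T rather than from all of S."""
    t = set(targets)
    return sorted(t), [y for y in sampled if y not in t]

pos, neg = sampled_logistic_pairs({3, 17}, [5, 17, 17, 42])
# pos is T; neg keeps only the sampled classes outside T
```

The retained scores would then feed the same F(x, y) − log(Q(y|x)) logistic input as in NCE.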
Context-Specific vs. Generic Sampling

In the methods discussed, the sampling algorithm is allowed to depend on the context. It is possible that for some models context-specific sampling will be very useful, in that we can generate context-dependent hard negatives and provide a more useful training signal. The authors have to this point focused on generic sampling algorithms, such as uniform sampling and unigram sampling, which do not make use of the context. The reason is described in the next section.

Batchwise Sampling

We have focused on models which use the same set S of sampled classes across a whole batch of training examples. This seems counterintuitive: shouldn't convergence be faster if we use different sampled classes for each training example? The reason for using the same sampled classes across a batch is computational. In many of our models, F(x, y) is computed as the dot product of a feature vector for the context (the top hidden layer of a neural network) and an embedding vector for the class. Computing the dot products of many feature vectors with many embedding vectors is a matrix multiplication, which is highly efficient on modern hardware, especially on GPUs. Batching like this often allows us to use hundreds or thousands of sampled classes without noticeable slowdown.

Another way to see it is that the overhead of fetching a class embedding across devices is greater than the time it takes to compute its dot products with hundreds or even thousands of feature vectors. So if we are going to use a sampled class with one context, it is virtually free to use it with all of the other contexts in the batch as well.
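The batching argument above can be made concrete: with a single sampled-class set shared across the batch, every score F(x, y) comes from one matrix multiplication. The shapes below are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, dim, num_sampled = 64, 128, 512

h = rng.standard_normal((batch, dim))         # context feature vectors (top hidden layer)
w = rng.standard_normal((num_sampled, dim))   # embeddings of the shared sampled classes

# One matmul yields F(x, y) for every (context, sampled class) pair in the batch.
scores = h @ w.T
```

If each example instead had its own sampled set, the same work would require a separate gather and dot-product pass per example, which is exactly the overhead the shared set avoids.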