1   
  2   
  3  """ 
  4  Tests for thread usage in lxml.etree. 
  5  """ 
  6   
  7  import re 
  8  import sys 
  9  import os.path 
 10  import unittest 
 11  import threading 
 12   
 13  this_dir = os.path.dirname(__file__) 
 14  if this_dir not in sys.path: 
 15      sys.path.insert(0, this_dir)  
 16   
 17  from common_imports import etree, HelperTestCase, BytesIO, _bytes 
 18   
 19  try: 
 20      from Queue import Queue 
 21  except ImportError: 
 22      from queue import Queue  
 23   
 24   
 26      """Threading tests""" 
 27      etree = etree 
 28   
 30          thread = threading.Thread(target=func) 
 31          thread.start() 
 32          thread.join() 
  33   
 35          sync = threading.Event() 
 36          lock = threading.Lock() 
 37          counter = dict(started=0, finished=0, failed=0) 
 38   
 39          def sync_start(func): 
 40              with lock: 
 41                  started = counter['started'] + 1 
 42                  counter['started'] = started 
 43              if started < count + (main_func is not None): 
 44                  sync.wait(4)   
 45                  assert sync.is_set() 
 46              sync.set()   
 47              try: 
 48                  func() 
 49              except: 
 50                  with lock: 
 51                      counter['failed'] += 1 
 52                  raise 
 53              else: 
 54                  with lock: 
 55                      counter['finished'] += 1 
  56   
 57          threads = [threading.Thread(target=sync_start, args=(func,)) for _ in range(count)] 
 58          for thread in threads: 
 59              thread.start() 
 60          if main_func is not None: 
 61              sync_start(main_func) 
 62          for thread in threads: 
 63              thread.join() 
 64   
 65          self.assertEqual(0, counter['failed']) 
 66          self.assertEqual(counter['finished'], counter['started']) 
  67   
 78   
 79          self._run_thread(run_thread) 
 80          self.assertEqual(xml, tostring(main_root)) 
 81   
 83          XML = self.etree.XML 
 84          style = XML(_bytes('''\ 
 85  <xsl:stylesheet version="1.0" 
 86      xmlns:xsl="http://www.w3.org/1999/XSL/Transform"> 
 87    <xsl:template match="*"> 
 88      <foo><xsl:copy><xsl:value-of select="/a/b/text()" /></xsl:copy></foo> 
 89    </xsl:template> 
 90  </xsl:stylesheet>''')) 
 91          st = etree.XSLT(style) 
 92   
 93          result = [] 
 94   
 95          def run_thread(): 
 96              root = XML(_bytes('<a><b>B</b><c>C</c></a>')) 
 97              result.append( st(root) ) 
  98   
 99          self._run_thread(run_thread) 
100          self.assertEqual('''\ 
101  <?xml version="1.0"?> 
102  <foo><a>B</a></foo> 
103  ''', 
104                            str(result[0])) 
105   
121   
122          self._run_thread(run_thread) 
123          self.assertEqual(_bytes('<a><b>B</b><c>C</c><foo><a>B</a></foo></a>'), 
124                            tostring(root)) 
125   
127           
128           
129          XML = self.etree.XML 
130          tostring = self.etree.tostring 
131          style = self.etree.XSLT(XML(_bytes('''\ 
132      <xsl:stylesheet version="1.0" 
133          xmlns:xsl="http://www.w3.org/1999/XSL/Transform"> 
134        <xsl:template match="*"> 
135          <root class="abc"> 
136            <xsl:copy-of select="@class" /> 
137            <xsl:attribute name="class">xyz</xsl:attribute>  
138          </root> 
139        </xsl:template> 
140      </xsl:stylesheet>'''))) 
141   
142          result = [] 
143          def run_thread(): 
144              root = XML(_bytes('<ROOT class="ABC" />')) 
145              result.append( style(root).getroot() ) 
 146   
147          self._run_thread(run_thread) 
148          self.assertEqual(_bytes('<root class="xyz"/>'), 
149                            tostring(result[0])) 
150   
152          XML = self.etree.XML 
153          tostring = self.etree.tostring 
154          root = XML(_bytes('<a><b>B</b><c>C</c></a>')) 
155   
156          stylesheets = [] 
157   
158          def run_thread(): 
159              style = XML(_bytes('''\ 
160      <xsl:stylesheet 
161          xmlns:xsl="http://www.w3.org/1999/XSL/Transform" 
162          version="1.0"> 
163        <xsl:output method="xml" /> 
164        <xsl:template match="/"> 
165           <div id="test"> 
166             <xsl:apply-templates/> 
167           </div> 
168        </xsl:template> 
169      </xsl:stylesheet>''')) 
170              stylesheets.append( etree.XSLT(style) ) 
 171   
172          self._run_thread(run_thread) 
173   
174          st = stylesheets[0] 
175          result = tostring( st(root) ) 
176   
177          self.assertEqual(_bytes('<div id="test">BC</div>'), 
178                            result) 
179   
202   
203          self.etree.clear_error_log() 
204          threads = [] 
205          for thread_no in range(1, 10): 
206              t = threading.Thread(target=parse_error_test, 
207                                   args=(thread_no,)) 
208              threads.append(t) 
209              t.start() 
210   
211          parse_error_test(0) 
212   
213          for t in threads: 
214              t.join() 
215   
231   
232          def run_parse(): 
233              thread_root = self.etree.parse(BytesIO(xml)).getroot() 
234              result.append(thread_root[0]) 
235              result.append(thread_root[-1]) 
236   
237          def run_move_main(): 
238              result.append(fragment[0]) 
239   
240          def run_build(): 
241              result.append( 
242                  Element("{myns}foo", attrib={'{test}attr':'val'})) 
243              SubElement(result, "{otherns}tasty") 
244   
245          def run_xslt(): 
246              style = XML(_bytes('''\ 
247      <xsl:stylesheet version="1.0" 
248          xmlns:xsl="http://www.w3.org/1999/XSL/Transform"> 
249        <xsl:template match="*"> 
250          <xsl:copy><foo><xsl:value-of select="/a/b/text()" /></foo></xsl:copy> 
251        </xsl:template> 
252      </xsl:stylesheet>''')) 
253              st = etree.XSLT(style) 
254              result.append( st(root).getroot() ) 
255   
256          for test in (run_XML, run_parse, run_move_main, run_xslt, run_build): 
257              tostring(result) 
258              self._run_thread(test) 
259   
260          self.assertEqual( 
261              _bytes('<ns0:root xmlns:ns0="myns" att="someval"><b>B</b>' 
262                     '<c xmlns="test">C</c><b>B</b><c xmlns="test">C</c><tags/>' 
263                     '<a><foo>B</foo></a>' 
264                     '<ns0:foo xmlns:ns1="test" ns1:attr="val"/>' 
265                     '<ns1:tasty xmlns:ns1="otherns"/></ns0:root>'), 
266              tostring(result)) 
267   
268          def strip_first(): 
269              root = Element("newroot") 
270              root.append(result[0]) 
271   
272          while len(result): 
273              self._run_thread(strip_first) 
274   
275          self.assertEqual( 
276              _bytes('<ns0:root xmlns:ns0="myns" att="someval"/>'), 
277              tostring(result)) 
278   
280          SubElement = self.etree.SubElement 
281          names = list('abcdefghijklmnop') 
282          runs_per_name = range(50) 
283          result_matches = re.compile( 
284              br'<thread_root>' 
285              br'(?:<[a-p]{5} thread_attr_[a-p]="value" thread_attr2_[a-p]="value2"\s?/>)+' 
286              br'</thread_root>').match 
287   
288          def testrun(): 
289              for _ in range(3): 
290                  root = self.etree.Element('thread_root') 
291                  for name in names: 
292                      tag_name = name * 5 
293                      new = [] 
294                      for _ in runs_per_name: 
295                          el = SubElement(root, tag_name, {'thread_attr_' + name: 'value'}) 
296                          new.append(el) 
297                      for el in new: 
298                          el.set('thread_attr2_' + name, 'value2') 
299                  s = etree.tostring(root) 
300                  self.assertTrue(result_matches(s)) 
 301   
302           
303          self._run_threads(10, testrun) 
304   
305           
306          self._run_threads(10, testrun, main_func=testrun) 
307   
309          XML = self.etree.XML 
310          root = XML(_bytes('<root><a>A</a><b xmlns="test">B</b><c/></root>')) 
311          child_count = len(root) 
312          def testrun(): 
313              for i in range(10000): 
314                  el = root[i%child_count] 
315                  del el 
 316          self._run_threads(10, testrun) 
317   
319          XML = self.etree.XML 
320   
        class TestElement(etree.ElementBase):
            """Element class that ``MyLookup.lookup`` returns for every node."""
            pass
 323   
        class MyLookup(etree.CustomElementClassLookup):
            # Custom element class lookup that burns a little time and then
            # always returns ``TestElement``.
            repeat = range(100)
            def lookup(self, t, d, ns, name):
                # ``count`` is never used: the loop only spends time.
                # Presumably this widens the window in which lookups from
                # several threads overlap -- TODO confirm intent.
                count = 0
                for i in self.repeat:
                    # just use some CPU time
                    count += 1
                return TestElement
332   
333          parser = self.etree.XMLParser() 
334          parser.set_element_class_lookup(MyLookup()) 
335   
336          root = XML(_bytes('<root><a>A</a><b xmlns="test">B</b><c/></root>'), 
337                     parser) 
338   
339          child_count = len(root) 
340          def testrun(): 
341              for i in range(1000): 
342                  el = root[i%child_count] 
343                  del el 
344          self._run_threads(10, testrun) 
345   
346   
348      """Threading tests based on a thread worker pipeline. 
349      """ 
350      etree = etree 
351      item_count = 40 
352   
353 -    class Worker(threading.Thread): 
 354 -        def __init__(self, in_queue, in_count, **kwargs): 
 355              threading.Thread.__init__(self) 
356              self.in_queue = in_queue 
357              self.in_count = in_count 
358              self.out_queue = Queue(in_count) 
359              self.__dict__.update(kwargs) 
 360   
362              get, put = self.in_queue.get, self.out_queue.put 
363              handle = self.handle 
364              for _ in range(self.in_count): 
365                  put(handle(get())) 
 366   
368              raise NotImplementedError() 
 372              return _fromstring(xml) 
 445          item_count = self.item_count 
446          xml = self.xml.replace(b'thread', b'THREAD')   
447   
448           
449          in_queue, start, last = self._build_pipeline( 
450              item_count, 
451              self.ParseWorker, 
452              self.RotateWorker, 
453              self.ReverseWorker, 
454              self.ParseAndExtendWorker, 
455              self.Validate, 
456              self.ParseAndInjectWorker, 
457              self.SerialiseWorker, 
458              xml=xml) 
459   
460           
461          put = start.in_queue.put 
462          for _ in range(item_count): 
463              put(xml) 
464   
465           
466          start.start() 
467           
468          last.join(60)   
469          self.assertEqual(item_count, last.out_queue.qsize()) 
470           
471          get = last.out_queue.get 
472          results = [get() for _ in range(item_count)] 
473   
474          comparison = results[0] 
475          for i, result in enumerate(results[1:]): 
476              self.assertEqual(comparison, result) 
 477   
479          item_count = self.item_count 
480          xml = self.xml.replace(b'thread', b'GLOBAL')   
481          XML = self.etree.XML 
482           
483          in_queue, start, last = self._build_pipeline( 
484              item_count, 
485              self.RotateWorker, 
486              self.ReverseWorker, 
487              self.ParseAndExtendWorker, 
488              self.Validate, 
489              self.SerialiseWorker, 
490              xml=xml) 
491   
492           
493          put = start.in_queue.put 
494          for _ in range(item_count): 
495              put(XML(xml)) 
496   
497           
498          start.start() 
499           
500          last.join(60)   
501          self.assertEqual(item_count, last.out_queue.qsize()) 
502           
503          get = last.out_queue.get 
504          results = [get() for _ in range(item_count)] 
505   
506          comparison = results[0] 
507          for i, result in enumerate(results[1:]): 
508              self.assertEqual(comparison, result) 
  509   
510   
516   
if __name__ == '__main__':
    # Running this module directly does not execute the tests; it only
    # prints how to run them through test.py.
    usage = 'to test use test.py %s' % __file__
    print(usage)
519