1   
  2   
  3  """ 
  4  Tests for thread usage in lxml.etree. 
  5  """ 
  6   
  7  import re 
  8  import sys 
  9  import os.path 
 10  import unittest 
 11  import threading 
 12   
 13  this_dir = os.path.dirname(__file__) 
 14  if this_dir not in sys.path: 
 15      sys.path.insert(0, this_dir)  
 16   
 17  from common_imports import etree, HelperTestCase, BytesIO, _bytes 
 18   
 19  try: 
 20      from Queue import Queue 
 21  except ImportError: 
 22      from queue import Queue  
 23   
 24   
 26      """Threading tests""" 
 27      etree = etree 
 28   
 30          thread = threading.Thread(target=func) 
 31          thread.start() 
 32          thread.join() 
  33   
 35          sync = threading.Event() 
 36          lock = threading.Lock() 
 37          counter = dict(started=0, finished=0, failed=0) 
 38   
 39          def sync_start(func): 
 40              with lock: 
 41                  started = counter['started'] + 1 
 42                  counter['started'] = started 
 43              if started < count + (main_func is not None): 
 44                  sync.wait(4)   
 45                  assert sync.is_set() 
 46              sync.set()   
 47              try: 
 48                  func() 
 49              except: 
 50                  with lock: 
 51                      counter['failed'] += 1 
 52                  raise 
 53              else: 
 54                  with lock: 
 55                      counter['finished'] += 1 
  56   
 57          threads = [threading.Thread(target=sync_start, args=(func,)) for _ in range(count)] 
 58          for thread in threads: 
 59              thread.start() 
 60          if main_func is not None: 
 61              sync_start(main_func) 
 62          for thread in threads: 
 63              thread.join() 
 64   
 65          self.assertEqual(0, counter['failed']) 
 66          self.assertEqual(counter['finished'], counter['started']) 
  67   
 78   
 79          self._run_thread(run_thread) 
 80          self.assertEqual(xml, tostring(main_root)) 
 81   
 83          XML = self.etree.XML 
 84          style = XML(_bytes('''\ 
 85  <xsl:stylesheet version="1.0" 
 86      xmlns:xsl="http://www.w3.org/1999/XSL/Transform"> 
 87    <xsl:template match="*"> 
 88      <foo><xsl:copy><xsl:value-of select="/a/b/text()" /></xsl:copy></foo> 
 89    </xsl:template> 
 90  </xsl:stylesheet>''')) 
 91          st = etree.XSLT(style) 
 92   
 93          result = [] 
 94   
 95          def run_thread(): 
 96              root = XML(_bytes('<a><b>B</b><c>C</c></a>')) 
 97              result.append( st(root) ) 
  98   
 99          self._run_thread(run_thread) 
100          self.assertEqual('''\ 
101  <?xml version="1.0"?> 
102  <foo><a>B</a></foo> 
103  ''', 
104                            str(result[0])) 
105   
121   
122          self._run_thread(run_thread) 
123          self.assertEqual(_bytes('<a><b>B</b><c>C</c><foo><a>B</a></foo></a>'), 
124                            tostring(root)) 
125   
127           
128           
129          XML = self.etree.XML 
130          tostring = self.etree.tostring 
131          style = self.etree.XSLT(XML(_bytes('''\ 
132      <xsl:stylesheet version="1.0" 
133          xmlns:xsl="http://www.w3.org/1999/XSL/Transform"> 
134        <xsl:template match="*"> 
135          <root class="abc"> 
136            <xsl:copy-of select="@class" /> 
137            <xsl:attribute name="class">xyz</xsl:attribute>  
138          </root> 
139        </xsl:template> 
140      </xsl:stylesheet>'''))) 
141   
142          result = [] 
143          def run_thread(): 
144              root = XML(_bytes('<ROOT class="ABC" />')) 
145              result.append( style(root).getroot() ) 
 146   
147          self._run_thread(run_thread) 
148          self.assertEqual(_bytes('<root class="xyz"/>'), 
149                            tostring(result[0])) 
150   
152          XML = self.etree.XML 
153          tostring = self.etree.tostring 
154          root = XML(_bytes('<a><b>B</b><c>C</c></a>')) 
155   
156          stylesheets = [] 
157   
158          def run_thread(): 
159              style = XML(_bytes('''\ 
160      <xsl:stylesheet 
161          xmlns:xsl="http://www.w3.org/1999/XSL/Transform" 
162          version="1.0"> 
163        <xsl:output method="xml" /> 
164        <xsl:template match="/"> 
165           <div id="test"> 
166             <xsl:apply-templates/> 
167           </div> 
168        </xsl:template> 
169      </xsl:stylesheet>''')) 
170              stylesheets.append( etree.XSLT(style) ) 
 171   
172          self._run_thread(run_thread) 
173   
174          st = stylesheets[0] 
175          result = tostring( st(root) ) 
176   
177          self.assertEqual(_bytes('<div id="test">BC</div>'), 
178                            result) 
179   
203   
204          self.etree.clear_error_log() 
205          threads = [] 
206          for thread_no in range(1, 10): 
207              t = threading.Thread(target=parse_error_test, 
208                                   args=(thread_no,)) 
209              threads.append(t) 
210              t.start() 
211   
212          parse_error_test(0) 
213   
214          for t in threads: 
215              t.join() 
216   
232   
233          def run_parse(): 
234              thread_root = self.etree.parse(BytesIO(xml)).getroot() 
235              result.append(thread_root[0]) 
236              result.append(thread_root[-1]) 
237   
238          def run_move_main(): 
239              result.append(fragment[0]) 
240   
241          def run_build(): 
242              result.append( 
243                  Element("{myns}foo", attrib={'{test}attr':'val'})) 
244              SubElement(result, "{otherns}tasty") 
245   
246          def run_xslt(): 
247              style = XML(_bytes('''\ 
248      <xsl:stylesheet version="1.0" 
249          xmlns:xsl="http://www.w3.org/1999/XSL/Transform"> 
250        <xsl:template match="*"> 
251          <xsl:copy><foo><xsl:value-of select="/a/b/text()" /></foo></xsl:copy> 
252        </xsl:template> 
253      </xsl:stylesheet>''')) 
254              st = etree.XSLT(style) 
255              result.append( st(root).getroot() ) 
256   
257          for test in (run_XML, run_parse, run_move_main, run_xslt, run_build): 
258              tostring(result) 
259              self._run_thread(test) 
260   
261          self.assertEqual( 
262              _bytes('<ns0:root xmlns:ns0="myns" att="someval"><b>B</b>' 
263                     '<c xmlns="test">C</c><b>B</b><c xmlns="test">C</c><tags/>' 
264                     '<a><foo>B</foo></a>' 
265                     '<ns0:foo xmlns:ns1="test" ns1:attr="val"/>' 
266                     '<ns1:tasty xmlns:ns1="otherns"/></ns0:root>'), 
267              tostring(result)) 
268   
269          def strip_first(): 
270              root = Element("newroot") 
271              root.append(result[0]) 
272   
273          while len(result): 
274              self._run_thread(strip_first) 
275   
276          self.assertEqual( 
277              _bytes('<ns0:root xmlns:ns0="myns" att="someval"/>'), 
278              tostring(result)) 
279   
281          SubElement = self.etree.SubElement 
282          names = list('abcdefghijklmnop') 
283          runs_per_name = range(50) 
284          result_matches = re.compile( 
285              br'<thread_root>' 
286              br'(?:<[a-p]{5} thread_attr_[a-p]="value" thread_attr2_[a-p]="value2"\s?/>)+' 
287              br'</thread_root>').match 
288   
289          def testrun(): 
290              for _ in range(3): 
291                  root = self.etree.Element('thread_root') 
292                  for name in names: 
293                      tag_name = name * 5 
294                      new = [] 
295                      for _ in runs_per_name: 
296                          el = SubElement(root, tag_name, {'thread_attr_' + name: 'value'}) 
297                          new.append(el) 
298                      for el in new: 
299                          el.set('thread_attr2_' + name, 'value2') 
300                  s = etree.tostring(root) 
301                  self.assertTrue(result_matches(s)) 
 302   
303           
304          self._run_threads(10, testrun) 
305   
306           
307          self._run_threads(10, testrun, main_func=testrun) 
308   
310          XML = self.etree.XML 
311          root = XML(_bytes('<root><a>A</a><b xmlns="test">B</b><c/></root>')) 
312          child_count = len(root) 
313          def testrun(): 
314              for i in range(10000): 
315                  el = root[i%child_count] 
316                  del el 
 317          self._run_threads(10, testrun) 
318   
320          XML = self.etree.XML 
321   
322          class TestElement(etree.ElementBase): 
323              pass 
 324   
325          class MyLookup(etree.CustomElementClassLookup): 
326              repeat = range(100) 
327              def lookup(self, t, d, ns, name): 
328                  count = 0 
329                  for i in self.repeat: 
330                       
331                      count += 1 
332                  return TestElement 
333   
334          parser = self.etree.XMLParser() 
335          parser.set_element_class_lookup(MyLookup()) 
336   
337          root = XML(_bytes('<root><a>A</a><b xmlns="test">B</b><c/></root>'), 
338                     parser) 
339   
340          child_count = len(root) 
341          def testrun(): 
342              for i in range(1000): 
343                  el = root[i%child_count] 
344                  del el 
345          self._run_threads(10, testrun) 
346   
347   
349      """Threading tests based on a thread worker pipeline. 
350      """ 
351      etree = etree 
352      item_count = 40 
353   
354 -    class Worker(threading.Thread): 
 355 -        def __init__(self, in_queue, in_count, **kwargs): 
 356              threading.Thread.__init__(self) 
357              self.in_queue = in_queue 
358              self.in_count = in_count 
359              self.out_queue = Queue(in_count) 
360              self.__dict__.update(kwargs) 
 361   
363              get, put = self.in_queue.get, self.out_queue.put 
364              handle = self.handle 
365              for _ in range(self.in_count): 
366                  put(handle(get())) 
 367   
369              raise NotImplementedError() 
 373              return _fromstring(xml) 
 446          item_count = self.item_count 
447          xml = self.xml.replace(b'thread', b'THREAD')   
448   
449           
450          in_queue, start, last = self._build_pipeline( 
451              item_count, 
452              self.ParseWorker, 
453              self.RotateWorker, 
454              self.ReverseWorker, 
455              self.ParseAndExtendWorker, 
456              self.Validate, 
457              self.ParseAndInjectWorker, 
458              self.SerialiseWorker, 
459              xml=xml) 
460   
461           
462          put = start.in_queue.put 
463          for _ in range(item_count): 
464              put(xml) 
465   
466           
467          start.start() 
468           
469          last.join(60)   
470          self.assertEqual(item_count, last.out_queue.qsize()) 
471           
472          get = last.out_queue.get 
473          results = [get() for _ in range(item_count)] 
474   
475          comparison = results[0] 
476          for i, result in enumerate(results[1:]): 
477              self.assertEqual(comparison, result) 
 478   
480          item_count = self.item_count 
481          xml = self.xml.replace(b'thread', b'GLOBAL')   
482          XML = self.etree.XML 
483           
484          in_queue, start, last = self._build_pipeline( 
485              item_count, 
486              self.RotateWorker, 
487              self.ReverseWorker, 
488              self.ParseAndExtendWorker, 
489              self.Validate, 
490              self.SerialiseWorker, 
491              xml=xml) 
492   
493           
494          put = start.in_queue.put 
495          for _ in range(item_count): 
496              put(XML(xml)) 
497   
498           
499          start.start() 
500           
501          last.join(60)   
502          self.assertEqual(item_count, last.out_queue.qsize()) 
503           
504          get = last.out_queue.get 
505          results = [get() for _ in range(item_count)] 
506   
507          comparison = results[0] 
508          for i, result in enumerate(results[1:]): 
509              self.assertEqual(comparison, result) 
  510   
511   
517   
if __name__ == '__main__':
    # This module is not run standalone; point the user at the test runner.
    usage_hint = 'to test use test.py %s' % (__file__,)
    print(usage_hint)
520