1   
  2   
  3  """ 
  4  Tests for thread usage in lxml.etree. 
  5  """ 
  6   
  7  import re 
  8  import sys 
  9  import os.path 
 10  import unittest 
 11  import threading 
 12   
 13  this_dir = os.path.dirname(__file__) 
 14  if this_dir not in sys.path: 
 15      sys.path.insert(0, this_dir)  
 16   
 17  from common_imports import etree, HelperTestCase, BytesIO, _bytes 
 18   
 19  try: 
 20      from Queue import Queue 
 21  except ImportError: 
 22      from queue import Queue  
 23   
 24   
 26      """Threading tests""" 
 27      etree = etree 
 28   
 30          thread = threading.Thread(target=func) 
 31          thread.start() 
 32          thread.join() 
  33   
 35          sync = threading.Event() 
 36          lock = threading.Lock() 
 37          counter = dict(started=0, finished=0, failed=0) 
 38   
 39          def sync_start(func): 
 40              with lock: 
 41                  started = counter['started'] + 1 
 42                  counter['started'] = started 
 43              if started < count + (main_func is not None): 
 44                  sync.wait(4)   
 45                  assert sync.is_set() 
 46              sync.set()   
 47              try: 
 48                  func() 
 49              except: 
 50                  with lock: 
 51                      counter['failed'] += 1 
 52                  raise 
 53              else: 
 54                  with lock: 
 55                      counter['finished'] += 1 
  56   
 57          threads = [threading.Thread(target=sync_start, args=(func,)) for _ in range(count)] 
 58          for thread in threads: 
 59              thread.start() 
 60          if main_func is not None: 
 61              sync_start(main_func) 
 62          for thread in threads: 
 63              thread.join() 
 64   
 65          self.assertEqual(0, counter['failed']) 
 66          self.assertEqual(counter['finished'], counter['started']) 
  67   
 78   
 79          self._run_thread(run_thread) 
 80          self.assertEqual(xml, tostring(main_root)) 
 81   
 83          XML = self.etree.XML 
 84          style = XML(_bytes('''\ 
 85  <xsl:stylesheet version="1.0" 
 86      xmlns:xsl="http://www.w3.org/1999/XSL/Transform"> 
 87    <xsl:template match="*"> 
 88      <foo><xsl:copy><xsl:value-of select="/a/b/text()" /></xsl:copy></foo> 
 89    </xsl:template> 
 90  </xsl:stylesheet>''')) 
 91          st = etree.XSLT(style) 
 92   
 93          result = [] 
 94   
 95          def run_thread(): 
 96              root = XML(_bytes('<a><b>B</b><c>C</c></a>')) 
 97              result.append( st(root) ) 
  98   
 99          self._run_thread(run_thread) 
100          self.assertEqual('''\ 
101  <?xml version="1.0"?> 
102  <foo><a>B</a></foo> 
103  ''', 
104                            str(result[0])) 
105   
121   
122          self._run_thread(run_thread) 
123          self.assertEqual(_bytes('<a><b>B</b><c>C</c><foo><a>B</a></foo></a>'), 
124                            tostring(root)) 
125   
127          style = self.parse('''\ 
128  <xsl:stylesheet version="1.0" 
129      xmlns:xsl="http://www.w3.org/1999/XSL/Transform"> 
130      <xsl:template match="tag" /> 
131      <!-- extend time for parsing + transform --> 
132  ''' + '\n'.join('<xsl:template match="tag%x" />' % i for i in range(200)) + ''' 
133      <xsl:foo /> 
134  </xsl:stylesheet>''') 
135          self.assertRaises(etree.XSLTParseError, 
136                            etree.XSLT, style) 
137   
138          error_logs = [] 
139   
140          def run_thread(): 
141              try: 
142                  etree.XSLT(style) 
143              except etree.XSLTParseError as e: 
144                  error_logs.append(e.error_log) 
145              else: 
146                  self.assertFalse(True, "XSLT parsing should have failed but didn't") 
 147   
148          self._run_threads(16, run_thread) 
149   
150          self.assertEqual(16, len(error_logs)) 
151          last_log = None 
152          for log in error_logs: 
153              self.assertTrue(len(log)) 
154              if last_log is not None: 
155                  self.assertEqual(len(last_log), len(log)) 
156              self.assertEqual(4, len(log)) 
157              for error in log: 
158                  self.assertTrue(':ERROR:XSLT:' in str(error)) 
159              last_log = log 
160   
162          tree = self.parse('<tagFF/>') 
163          style = self.parse('''\ 
164  <xsl:stylesheet version="1.0" 
165      xmlns:xsl="http://www.w3.org/1999/XSL/Transform"> 
166      <xsl:template name="tag0"> 
167          <xsl:message terminate="yes">FAIL</xsl:message> 
168      </xsl:template> 
169      <!-- extend time for parsing + transform --> 
170  ''' + '\n'.join('<xsl:template match="tag%X" name="tag%x"> <xsl:call-template name="tag%x" /> </xsl:template>' % (i, i, i-1) 
171                  for i in range(1, 256)) + ''' 
172  </xsl:stylesheet>''') 
173          self.assertRaises(etree.XSLTApplyError, 
174                            etree.XSLT(style), tree) 
175   
176          error_logs = [] 
177   
178          def run_thread(): 
179              transform = etree.XSLT(style) 
180              try: 
181                  transform(tree) 
182              except etree.XSLTApplyError: 
183                  error_logs.append(transform.error_log) 
184              else: 
185                  self.assertFalse(True, "XSLT parsing should have failed but didn't") 
 186   
187          self._run_threads(16, run_thread) 
188   
189          self.assertEqual(16, len(error_logs)) 
190          last_log = None 
191          for log in error_logs: 
192              self.assertTrue(len(log)) 
193              if last_log is not None: 
194                  self.assertEqual(len(last_log), len(log)) 
195              self.assertEqual(1, len(log)) 
196              for error in log: 
197                  self.assertTrue(':ERROR:XSLT:' in str(error)) 
198              last_log = log 
199   
201           
202           
203          XML = self.etree.XML 
204          tostring = self.etree.tostring 
205          style = self.etree.XSLT(XML(_bytes('''\ 
206      <xsl:stylesheet version="1.0" 
207          xmlns:xsl="http://www.w3.org/1999/XSL/Transform"> 
208        <xsl:template match="*"> 
209          <root class="abc"> 
210            <xsl:copy-of select="@class" /> 
211            <xsl:attribute name="class">xyz</xsl:attribute>  
212          </root> 
213        </xsl:template> 
214      </xsl:stylesheet>'''))) 
215   
216          result = [] 
217          def run_thread(): 
218              root = XML(_bytes('<ROOT class="ABC" />')) 
219              result.append( style(root).getroot() ) 
 220   
221          self._run_thread(run_thread) 
222          self.assertEqual(_bytes('<root class="xyz"/>'), 
223                            tostring(result[0])) 
224   
226          XML = self.etree.XML 
227          tostring = self.etree.tostring 
228          root = XML(_bytes('<a><b>B</b><c>C</c></a>')) 
229   
230          stylesheets = [] 
231   
232          def run_thread(): 
233              style = XML(_bytes('''\ 
234      <xsl:stylesheet 
235          xmlns:xsl="http://www.w3.org/1999/XSL/Transform" 
236          version="1.0"> 
237        <xsl:output method="xml" /> 
238        <xsl:template match="/"> 
239           <div id="test"> 
240             <xsl:apply-templates/> 
241           </div> 
242        </xsl:template> 
243      </xsl:stylesheet>''')) 
244              stylesheets.append( etree.XSLT(style) ) 
 245   
246          self._run_thread(run_thread) 
247   
248          st = stylesheets[0] 
249          result = tostring( st(root) ) 
250   
251          self.assertEqual(_bytes('<div id="test">BC</div>'), 
252                            result) 
253   
276   
277          self.etree.clear_error_log() 
278          threads = [] 
279          for thread_no in range(1, 10): 
280              t = threading.Thread(target=parse_error_test, 
281                                   args=(thread_no,)) 
282              threads.append(t) 
283              t.start() 
284   
285          parse_error_test(0) 
286   
287          for t in threads: 
288              t.join() 
289   
305   
306          def run_parse(): 
307              thread_root = self.etree.parse(BytesIO(xml)).getroot() 
308              result.append(thread_root[0]) 
309              result.append(thread_root[-1]) 
310   
311          def run_move_main(): 
312              result.append(fragment[0]) 
313   
314          def run_build(): 
315              result.append( 
316                  Element("{myns}foo", attrib={'{test}attr':'val'})) 
317              SubElement(result, "{otherns}tasty") 
318   
319          def run_xslt(): 
320              style = XML(_bytes('''\ 
321      <xsl:stylesheet version="1.0" 
322          xmlns:xsl="http://www.w3.org/1999/XSL/Transform"> 
323        <xsl:template match="*"> 
324          <xsl:copy><foo><xsl:value-of select="/a/b/text()" /></foo></xsl:copy> 
325        </xsl:template> 
326      </xsl:stylesheet>''')) 
327              st = etree.XSLT(style) 
328              result.append( st(root).getroot() ) 
329   
330          for test in (run_XML, run_parse, run_move_main, run_xslt, run_build): 
331              tostring(result) 
332              self._run_thread(test) 
333   
334          self.assertEqual( 
335              _bytes('<ns0:root xmlns:ns0="myns" att="someval"><b>B</b>' 
336                     '<c xmlns="test">C</c><b>B</b><c xmlns="test">C</c><tags/>' 
337                     '<a><foo>B</foo></a>' 
338                     '<ns0:foo xmlns:ns1="test" ns1:attr="val"/>' 
339                     '<ns1:tasty xmlns:ns1="otherns"/></ns0:root>'), 
340              tostring(result)) 
341   
342          def strip_first(): 
343              root = Element("newroot") 
344              root.append(result[0]) 
345   
346          while len(result): 
347              self._run_thread(strip_first) 
348   
349          self.assertEqual( 
350              _bytes('<ns0:root xmlns:ns0="myns" att="someval"/>'), 
351              tostring(result)) 
352   
354          SubElement = self.etree.SubElement 
355          names = list('abcdefghijklmnop') 
356          runs_per_name = range(50) 
357          result_matches = re.compile( 
358              br'<thread_root>' 
359              br'(?:<[a-p]{5} thread_attr_[a-p]="value" thread_attr2_[a-p]="value2"\s?/>)+' 
360              br'</thread_root>').match 
361   
362          def testrun(): 
363              for _ in range(3): 
364                  root = self.etree.Element('thread_root') 
365                  for name in names: 
366                      tag_name = name * 5 
367                      new = [] 
368                      for _ in runs_per_name: 
369                          el = SubElement(root, tag_name, {'thread_attr_' + name: 'value'}) 
370                          new.append(el) 
371                      for el in new: 
372                          el.set('thread_attr2_' + name, 'value2') 
373                  s = etree.tostring(root) 
374                  self.assertTrue(result_matches(s)) 
 375   
376           
377          self._run_threads(10, testrun) 
378   
379           
380          self._run_threads(10, testrun, main_func=testrun) 
381   
383          XML = self.etree.XML 
384          root = XML(_bytes('<root><a>A</a><b xmlns="test">B</b><c/></root>')) 
385          child_count = len(root) 
386          def testrun(): 
387              for i in range(10000): 
388                  el = root[i%child_count] 
389                  del el 
 390          self._run_threads(10, testrun) 
391   
393          XML = self.etree.XML 
394   
395          class TestElement(etree.ElementBase): 
396              pass 
 397   
398          class MyLookup(etree.CustomElementClassLookup): 
399              repeat = range(100) 
400              def lookup(self, t, d, ns, name): 
401                  count = 0 
402                  for i in self.repeat: 
403                       
404                      count += 1 
405                  return TestElement 
406   
407          parser = self.etree.XMLParser() 
408          parser.set_element_class_lookup(MyLookup()) 
409   
410          root = XML(_bytes('<root><a>A</a><b xmlns="test">B</b><c/></root>'), 
411                     parser) 
412   
413          child_count = len(root) 
414          def testrun(): 
415              for i in range(1000): 
416                  el = root[i%child_count] 
417                  del el 
418          self._run_threads(10, testrun) 
419   
420   
422      """Threading tests based on a thread worker pipeline. 
423      """ 
424      etree = etree 
425      item_count = 40 
426   
427 -    class Worker(threading.Thread): 
 428 -        def __init__(self, in_queue, in_count, **kwargs): 
 429              threading.Thread.__init__(self) 
430              self.in_queue = in_queue 
431              self.in_count = in_count 
432              self.out_queue = Queue(in_count) 
433              self.__dict__.update(kwargs) 
 434   
436              get, put = self.in_queue.get, self.out_queue.put 
437              handle = self.handle 
438              for _ in range(self.in_count): 
439                  put(handle(get())) 
 440   
442              raise NotImplementedError() 
 446              return _fromstring(xml) 
 519          item_count = self.item_count 
520          xml = self.xml.replace(b'thread', b'THREAD')   
521   
522           
523          in_queue, start, last = self._build_pipeline( 
524              item_count, 
525              self.ParseWorker, 
526              self.RotateWorker, 
527              self.ReverseWorker, 
528              self.ParseAndExtendWorker, 
529              self.Validate, 
530              self.ParseAndInjectWorker, 
531              self.SerialiseWorker, 
532              xml=xml) 
533   
534           
535          put = start.in_queue.put 
536          for _ in range(item_count): 
537              put(xml) 
538   
539           
540          start.start() 
541           
542          last.join(60)   
543          self.assertEqual(item_count, last.out_queue.qsize()) 
544           
545          get = last.out_queue.get 
546          results = [get() for _ in range(item_count)] 
547   
548          comparison = results[0] 
549          for i, result in enumerate(results[1:]): 
550              self.assertEqual(comparison, result) 
 551   
553          item_count = self.item_count 
554          xml = self.xml.replace(b'thread', b'GLOBAL')   
555          XML = self.etree.XML 
556           
557          in_queue, start, last = self._build_pipeline( 
558              item_count, 
559              self.RotateWorker, 
560              self.ReverseWorker, 
561              self.ParseAndExtendWorker, 
562              self.Validate, 
563              self.SerialiseWorker, 
564              xml=xml) 
565   
566           
567          put = start.in_queue.put 
568          for _ in range(item_count): 
569              put(XML(xml)) 
570   
571           
572          start.start() 
573           
574          last.join(60)   
575          self.assertEqual(item_count, last.out_queue.qsize()) 
576           
577          get = last.out_queue.get 
578          results = [get() for _ in range(item_count)] 
579   
580          comparison = results[0] 
581          for i, result in enumerate(results[1:]): 
582              self.assertEqual(comparison, result) 
  583   
584   
590   
if __name__ == '__main__':
    # Direct execution is not supported; point the user at the test runner.
    usage = 'to test use test.py %s' % __file__
    print(usage)
593