[jsinterp] Handle NaN in bitwise operators

* also add _NaN
* also pull function naming from yt-dlp
This commit is contained in:
dirkf 2023-05-11 20:59:30 +01:00
parent 11cc3f3ad0
commit a85a875fef
2 changed files with 43 additions and 9 deletions

View File

@ -18,6 +18,7 @@ class TestJSInterpreter(unittest.TestCase):
def test_basic(self): def test_basic(self):
jsi = JSInterpreter('function x(){;}') jsi = JSInterpreter('function x(){;}')
self.assertEqual(jsi.call_function('x'), None) self.assertEqual(jsi.call_function('x'), None)
self.assertEqual(repr(jsi.extract_function('x')), 'F<x>')
jsi = JSInterpreter('function x3(){return 42;}') jsi = JSInterpreter('function x3(){return 42;}')
self.assertEqual(jsi.call_function('x3'), 42) self.assertEqual(jsi.call_function('x3'), 42)
@ -505,6 +506,16 @@ class TestJSInterpreter(unittest.TestCase):
jsi = JSInterpreter('function x(){return 1236566549 << 5}') jsi = JSInterpreter('function x(){return 1236566549 << 5}')
self.assertEqual(jsi.call_function('x'), 915423904) self.assertEqual(jsi.call_function('x'), 915423904)
def test_bitwise_operators_madness(self):
    # Each snippet feeds a non-numeric operand (null, undefined, NaN)
    # to a shift operator; the expected value is the result when that
    # operand is treated as 0
    cases = (
        ('function x(){return null << 5}', 0),
        ('function x(){return undefined >> 5}', 0),
        ('function x(){return 42 << NaN}', 42),
    )
    for code, expected in cases:
        interpreter = JSInterpreter(code)
        self.assertEqual(interpreter.call_function('x'), expected)
def test_32066(self): def test_32066(self):
jsi = JSInterpreter("function x(){return Math.pow(3, 5) + new Date('1970-01-01T08:01:42.000+08:00') / 1000 * -239 - -24205;}") jsi = JSInterpreter("function x(){return Math.pow(3, 5) + new Date('1970-01-01T08:01:42.000+08:00') / 1000 * -239 - -24205;}")
self.assertEqual(jsi.call_function('x'), 70) self.assertEqual(jsi.call_function('x'), 70)

View File

@ -1,12 +1,13 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from functools import update_wrapper
import itertools import itertools
import json import json
import math import math
import operator import operator
import re import re
from functools import update_wrapper
from .utils import ( from .utils import (
error_to_compat_str, error_to_compat_str,
ExtractorError, ExtractorError,
@ -24,6 +25,22 @@ from .compat import (
) )
# name JS functions
class function_with_repr(object):
    """Callable wrapper that gives `func` a chosen `repr()`.

    Adapted from yt_dlp/utils.py. Calls are forwarded unchanged to the
    wrapped callable, whose metadata is copied onto the wrapper with
    `update_wrapper`.

    @param func   the callable to wrap; kept as `self.func`
    @param repr_  the string returned by `repr()`; when omitted (None) a
                  `module.name` description of `func` is used instead, so
                  that `__repr__` always returns a string
    """

    def __init__(self, func, repr_=None):
        update_wrapper(self, func)
        self.func, self.__repr = func, repr_

    def __call__(self, *args, **kwargs):
        return self.func(*args, **kwargs)

    def __repr__(self):
        if self.__repr is not None:
            return self.__repr
        # no explicit label: describe the wrapped callable itself
        func = self.func
        name = getattr(func, '__qualname__', None) or getattr(func, '__name__', None)
        if name is None:
            return repr(func)
        module = getattr(func, '__module__', None)
        return '%s.%s' % (module, name) if module else name
# name JS operators
def wraps_op(op): def wraps_op(op):
def update_and_rename_wrapper(w): def update_and_rename_wrapper(w):
@ -35,10 +52,13 @@ def wraps_op(op):
return update_and_rename_wrapper return update_and_rename_wrapper
# Shared NaN object for the whole module: NaN never compares equal to
# itself, so a membership test such as `x in (..., _NaN)` can only
# recognise a NaN that is this very object
_NaN = float('nan')
def _js_bit_op(op): def _js_bit_op(op):
def zeroise(x): def zeroise(x):
return 0 if x in (None, JS_Undefined) else x return 0 if x in (None, JS_Undefined, _NaN) else x
@wraps_op(op) @wraps_op(op)
def wrapped(a, b): def wrapped(a, b):
@ -52,7 +72,7 @@ def _js_arith_op(op):
@wraps_op(op) @wraps_op(op)
def wrapped(a, b): def wrapped(a, b):
if JS_Undefined in (a, b): if JS_Undefined in (a, b):
return float('nan') return _NaN
return op(a or 0, b or 0) return op(a or 0, b or 0)
return wrapped return wrapped
@ -60,13 +80,13 @@ def _js_arith_op(op):
def _js_div(a, b):
    """JS `/` operator.

    Returns NaN when either operand is `undefined` or when both operands
    are falsy (0/0, null/null, ...). A truthy dividend over a falsy
    divisor yields Infinity; otherwise the true quotient is returned,
    with a falsy dividend treated as 0.
    """
    # `a or b` (not `a and b`): only the case where BOTH operands are
    # zero-like is indeterminate; 0 / n is 0 and n / 0 is Infinity
    if JS_Undefined in (a, b) or not (a or b):
        return _NaN
    return operator.truediv(a or 0, b) if b else float('inf')
def _js_mod(a, b):
    """JS `%` operator.

    Returns NaN when either operand is `undefined` or the divisor is
    falsy. Otherwise returns the remainder, with a falsy dividend
    treated as 0.
    """
    if JS_Undefined in (a, b) or not b:
        return _NaN
    a = a or 0
    # Python's % takes the sign of the divisor, whereas the JS remainder
    # takes the sign of the dividend: work on magnitudes, then restore
    # the dividend's sign
    rem = abs(a) % abs(b)
    return -rem if a < 0 else rem
@ -74,7 +94,7 @@ def _js_exp(a, b):
if not b: if not b:
return 1 # even 0 ** 0 !! return 1 # even 0 ** 0 !!
elif JS_Undefined in (a, b): elif JS_Undefined in (a, b):
return float('nan') return _NaN
return (a or 0) ** b return (a or 0) ** b
@ -285,6 +305,8 @@ class JSInterpreter(object):
def _named_object(self, namespace, obj):
    """Store `obj` in `namespace` under a freshly generated name.

    The name is `self._OBJ_NAME` followed by an incrementing counter.
    A callable that is not already a `function_with_repr` is wrapped in
    one, labelled `F<counter>`. Returns the generated name.
    """
    self.__named_object_counter += 1
    serial = self.__named_object_counter
    key = '%s%d' % (self._OBJ_NAME, serial)
    needs_label = callable(obj) and not isinstance(obj, function_with_repr)
    if needs_label:
        obj = function_with_repr(obj, 'F<%s>' % (serial, ))
    namespace[key] = obj
    return key
@ -693,7 +715,7 @@ class JSInterpreter(object):
elif expr == 'undefined': elif expr == 'undefined':
return JS_Undefined, should_return return JS_Undefined, should_return
elif expr == 'NaN': elif expr == 'NaN':
return float('NaN'), should_return return _NaN, should_return
elif md.get('return'): elif md.get('return'):
return local_vars[m.group('name')], should_return return local_vars[m.group('name')], should_return
@ -953,7 +975,9 @@ class JSInterpreter(object):
return self.build_arglist(func_m.group('args')), code return self.build_arglist(func_m.group('args')), code
def extract_function(self, funcname):
    """Return the callable built for JS function `funcname`.

    The result of `extract_function_code` is passed on to
    `extract_function_from_code`, and the resulting callable is wrapped
    so that its repr reads `F<funcname>`.
    """
    code_parts = self.extract_function_code(funcname)
    func = self.extract_function_from_code(*code_parts)
    label = 'F<%s>' % (funcname, )
    return function_with_repr(func, label)
def extract_function_from_code(self, argnames, code, *global_stack): def extract_function_from_code(self, argnames, code, *global_stack):
local_vars = {} local_vars = {}
@ -988,7 +1012,6 @@ class JSInterpreter(object):
def build_function(self, argnames, code, *global_stack): def build_function(self, argnames, code, *global_stack):
global_stack = list(global_stack) or [{}] global_stack = list(global_stack) or [{}]
argnames = tuple(argnames) argnames = tuple(argnames)
# import pdb; pdb.set_trace()
def resf(args, kwargs={}, allow_recursion=100): def resf(args, kwargs={}, allow_recursion=100):
global_stack[0].update( global_stack[0].update(