# Copyright (c) Microsoft Corporation. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import asyncio import math import os import re import time import traceback from pathlib import Path from types import TracebackType from typing import ( TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Pattern, Set, TypedDict, TypeVar, Union, cast, ) from urllib.parse import urljoin, urlparse from playwright._impl._api_structures import NameValue from playwright._impl._errors import ( Error, TargetClosedError, TimeoutError, is_target_closed_error, rewrite_error, ) from playwright._impl._glob import glob_to_regex_pattern from playwright._impl._greenlets import RouteGreenlet from playwright._impl._str_utils import escape_regex_flags if TYPE_CHECKING: # pragma: no cover from playwright._impl._api_structures import HeadersArray from playwright._impl._network import Request, Response, Route, WebSocketRoute URLMatch = Union[str, Pattern[str], Callable[[str], bool]] URLMatchRequest = Union[str, Pattern[str], Callable[["Request"], bool]] URLMatchResponse = Union[str, Pattern[str], Callable[["Response"], bool]] RouteHandlerCallback = Union[ Callable[["Route"], Any], Callable[["Route", "Request"], Any] ] WebSocketRouteHandlerCallback = Callable[["WebSocketRoute"], Any] ColorScheme = Literal["dark", "light", "no-preference", "null"] ForcedColors = Literal["active", "none", "null"] Contrast = Literal["more", "no-preference", "null"] ReducedMotion = Literal["no-preference", "null", "reduce"] 
# String-literal option types used across the implementation.
DocumentLoadState = Literal["commit", "domcontentloaded", "load", "networkidle"]
KeyboardModifier = Literal["Alt", "Control", "ControlOrMeta", "Meta", "Shift"]
MouseButton = Literal["left", "middle", "right"]
ServiceWorkersPolicy = Literal["allow", "block"]
HarMode = Literal["full", "minimal"]
HarContentPolicy = Literal["attach", "embed", "omit"]
RouteFromHarNotFoundPolicy = Literal["abort", "fallback"]


class ErrorPayload(TypedDict, total=False):
    """Serialized error; every key is optional (``total=False``)."""

    message: str
    name: str
    stack: str
    value: Optional[Any]


class HarRecordingMetadata(TypedDict, total=False):
    """Path and content policy of a HAR recording; every key is optional."""

    path: str
    content: Optional[HarContentPolicy]


def prepare_record_har_options(params: Dict) -> Dict[str, Any]:
    """Build the HAR recording options from the ``recordHar*`` entries of *params*.

    ``recordHarPath`` is required and is left in *params*. The keys
    ``recordHarUrlFilter``, ``recordHarMode``, ``recordHarContent`` and
    ``recordHarOmitContent`` are removed from *params* (the dict is mutated in
    place) once they have been translated.

    Args:
        params: Option dictionary containing at least ``recordHarPath``.

    Returns:
        A new dict with ``path`` and, when applicable, ``urlGlob`` or
        ``urlRegexSource``/``urlRegexFlags``, ``mode`` and ``content``.

    Raises:
        KeyError: If ``recordHarPath`` is missing.
    """
    out_params: Dict[str, Any] = {"path": str(params["recordHarPath"])}

    if "recordHarUrlFilter" in params:
        url_filter = params.pop("recordHarUrlFilter")
        # A filter is either a glob string or a compiled regex, never both.
        if isinstance(url_filter, str):
            out_params["urlGlob"] = url_filter
        elif isinstance(url_filter, Pattern):
            out_params["urlRegexSource"] = url_filter.pattern
            out_params["urlRegexFlags"] = escape_regex_flags(url_filter)

    if "recordHarMode" in params:
        out_params["mode"] = params.pop("recordHarMode")

    new_content_api = params.pop("recordHarContent", None)
    old_content_api = params.pop("recordHarOmitContent", None)
    # The explicit content policy wins; the boolean option can only map to "omit".
    content = new_content_api or ("omit" if old_content_api else None)
    if content:
        out_params["content"] = content

    return out_params


class ParsedMessageParams(TypedDict):
    type: str
    guid: str
    initializer: Dict


class ParsedMessagePayload(TypedDict, total=False):
    id: int
    guid: str
    method: str
    params: ParsedMessageParams
    result: Any
    error: ErrorPayload


class Document(TypedDict):
    request: Optional[Any]


class FrameNavigatedEvent(TypedDict):
    url: str
    name: str
    newDocument: Optional[Document]
    error: Optional[str]


Env = Dict[str, Union[str, float, bool]]


def url_matches(
    base_url: Optional[str],
    url_string: str,
    match: Optional[URLMatch],
    websocket_url: Optional[bool] = None,
) -> bool:
    """Return whether *url_string* satisfies *match*.

    Args:
        base_url: Base URL used to resolve a glob *match*.
        url_string: The URL under test.
        match: A glob string, a compiled regex or a predicate. A falsy value
            (``None`` or ``""``) matches everything.
        websocket_url: When truthy, an ``http(s)`` *base_url* is converted to
            ``ws(s)`` before the glob is resolved.

    Returns:
        ``True`` if the URL matches. For a predicate, its result is returned
        unchanged.
    """
    if not match:
        return True
    if isinstance(match, str):
        # Globs are resolved against the base URL and compiled to a regex.
        match = re.compile(
            resolve_glob_to_regex_pattern(base_url, match, websocket_url)
        )
    if isinstance(match, Pattern):
        return bool(match.search(url_string))
    return match(url_string)


def resolve_glob_to_regex_pattern(
    base_url: Optional[str], glob: str, websocket_url: Optional[bool] = None
) -> str:
    """Resolve *glob* against *base_url* and return the regex source for it."""
    if websocket_url:
        base_url = to_websocket_base_url(base_url)
    glob = resolve_glob_base(base_url, glob)
    return glob_to_regex_pattern(glob)


def to_websocket_base_url(base_url: Optional[str]) -> Optional[str]:
    """Rewrite an ``http://``/``https://`` base URL to ``ws://``/``wss://``.

    Any other value (including ``None``) is returned unchanged.
    """
    if base_url is not None and re.match(r"^https?://", base_url):
        base_url = re.sub(r"^http", "ws", base_url)
    return base_url


def resolve_glob_base(base_url: Optional[str], match: str) -> str:
    """Resolve the glob *match* against *base_url* while keeping glob syntax intact.

    Globs starting with ``*`` are returned untouched. Otherwise each path
    segment is swapped for a URL-safe placeholder, the result is joined with
    the base URL via ``urljoin``, and the original segments are put back.

    Args:
        base_url: Base URL, or ``None`` for no base.
        match: The glob pattern. An empty glob resolves to the base URL.

    Returns:
        The resolved glob, with ``/`` as the path when the URL has an empty path.
    """
    # `startswith` (rather than `match[0]`) keeps an empty glob from raising
    # IndexError; it then resolves to the base URL itself.
    if match.startswith("*"):
        return match

    token_map: Dict[str, str] = {}

    def map_token(original: str, replacement: str) -> str:
        # Remember which placeholder stands for which original segment.
        if len(original) == 0:
            return ""
        token_map[replacement] = original
        return replacement

    # Escaped `\\?` behaves the same as `?` in our glob patterns.
    match = match.replace(r"\\?", "?")
    # Glob symbols may be escaped in the URL and some of them such as ? affect resolution,
    # so we replace them with safe components first.
    processed_parts = []
    for index, token in enumerate(match.split("/")):
        if token in (".", "..", ""):
            processed_parts.append(token)
            continue
        # Handle special case of http*://, note that the new schema has to be
        # a web schema so that slashes are properly inserted after domain.
        if index == 0 and token.endswith(":"):
            # Using a simple replacement for the scheme part
            processed_parts.append(map_token(token, "http:"))
            continue
        question_index = token.find("?")
        if question_index == -1:
            processed_parts.append(map_token(token, f"$_{index}_$"))
        else:
            # Keep a real `?` in place so that URL resolution still sees a query.
            new_prefix = map_token(token[:question_index], f"$_{index}_$")
            new_suffix = map_token(token[question_index:], f"?$_{index}_$")
            processed_parts.append(new_prefix + new_suffix)

    relative_path = "/".join(processed_parts)
    resolved_url = urljoin(base_url if base_url is not None else "", relative_path)

    # Put the original segments back, one occurrence per placeholder.
    for replacement, original in token_map.items():
        resolved_url = resolved_url.replace(replacement, original, 1)

    return ensure_trailing_slash(resolved_url)


# In Node.js, new URL('http://localhost') returns 'http://localhost/'.
# To ensure the same url matching behavior, do the same.
def ensure_trailing_slash(url: str) -> str:
    """Return *url* with ``/`` as its path when the parsed path is empty."""
    split = url.split("://", maxsplit=1)
    if len(split) == 2:
        # URL parser doesn't like strange/unknown schemes, so we replace it for parsing, then put it back
        parsable_url = "http://" + split[1]
    else:
        # Given current rules, this should never happen _and_ still be a valid matcher.
        # We require the protocol to be part of the match, so either the user is using
        # a glob that starts with "*" (and none of this code is running), or the user
        # actually has `something://` in `match`
        parsable_url = url
    parsed = urlparse(parsable_url, allow_fragments=True)
    if len(split) == 2:
        # Replace the scheme that we removed earlier
        parsed = parsed._replace(scheme=split[0])
    if parsed.path == "":
        parsed = parsed._replace(path="/")
        url = parsed.geturl()

    return url


class HarLookupResult(TypedDict, total=False):
    action: Literal["error", "redirect", "fulfill", "noentry"]
    message: Optional[str]
    redirectURL: Optional[str]
    status: Optional[int]
    headers: Optional["HeadersArray"]
    body: Optional[str]


class TimeoutSettings:
    """Default timeouts, with fallback to an optional parent's settings."""

    def __init__(self, parent: Optional["TimeoutSettings"]) -> None:
        self._parent = parent
        self._default_timeout: Optional[float] = None
        self._default_navigation_timeout: Optional[float] = None

    def set_default_timeout(self, timeout: Optional[float]) -> None:
        self._default_timeout = timeout

    def timeout(self, timeout: Optional[float] = None) -> float:
        """Return the effective timeout.

        Order of precedence: the explicit *timeout* (``0`` included), this
        object's default, the parent's effective timeout, then ``30000``.
        """
        if timeout is not None:
            return timeout
        if self._default_timeout is not None:
            return self._default_timeout
        if self._parent:
            return self._parent.timeout()
        return 30000

    def set_default_navigation_timeout(
        self, navigation_timeout: Optional[float]
    ) -> None:
        self._default_navigation_timeout = navigation_timeout

    def default_navigation_timeout(self) -> Optional[float]:
        return self._default_navigation_timeout

    def default_timeout(self) -> Optional[float]:
        return self._default_timeout

    def navigation_timeout(self) -> float:
        """Return the effective navigation timeout.

        Uses this object's navigation default, else the parent's effective
        navigation timeout, else ``30000``.
        """
        if self._default_navigation_timeout is not None:
            return self._default_navigation_timeout
        if self._parent:
            return self._parent.navigation_timeout()
        return 30000


def serialize_error(ex: Exception, tb: Optional[TracebackType]) -> ErrorPayload:
    """Convert an exception and traceback into an ``ErrorPayload``.

    The payload ``name`` is always ``"Error"``.
    """
    return ErrorPayload(
        message=str(ex), name="Error", stack="".join(traceback.format_tb(tb))
    )


def parse_error(error: ErrorPayload, log: Optional[str] = None) -> Error:
    """Build an ``Error`` (or subclass) from a serialized error payload.

    Args:
        error: The payload; ``message`` is required, ``name`` and ``stack``
            are optional.
        log: Optional text appended to the patched message.

    Returns:
        ``TimeoutError`` or ``TargetClosedError`` when the payload name says
        so, otherwise ``Error``. ``_name`` and ``_stack`` are copied from the
        payload and are ``None`` when the payload lacks them.

    Raises:
        KeyError: If the payload has no ``message``.
    """
    # ErrorPayload is total=False, so "name" and "stack" may legitimately be
    # absent; read them with .get() consistently instead of indexing.
    name = error.get("name")
    base_error_class = Error
    if name == "TimeoutError":
        base_error_class = TimeoutError
    elif name == "TargetClosedError":
        base_error_class = TargetClosedError
    exc = base_error_class(patch_error_message(error["message"]) + (log or ""))
    exc._name = name
    exc._stack = error.get("stack")
    return exc


def patch_error_message(message: str) -> str:
    """Rewrite a message to use Python naming.

    A leading ``word: expected ...`` has the word converted to snake_case, and
    the ``acceptDownloads`` hint is replaced with its Python spelling.
    """
    match = re.match(r"(\w+)(: expected .*)", message)
    if match:
        message = to_snake_case(match.group(1)) + match.group(2)
    message = message.replace(
        "Pass { acceptDownloads: true }", "Pass 'accept_downloads=True'"
    )
    return message


def locals_to_params(args: Dict) -> Dict:
    """Copy *args* without the ``self`` key and without ``None`` values.

    Nested dicts are filtered recursively; all other values are kept as they are.
    """
    return {
        key: locals_to_params(value) if isinstance(value, dict) else value
        for key, value in args.items()
        if key != "self" and value is not None
    }


def monotonic_time() -> int:
    """Return the monotonic clock reading in whole milliseconds."""
    return math.floor(time.monotonic() * 1000)


class RouteHandlerInvocation:
    """One in-flight handler call: its route and a future completed when it ends."""

    complete: "asyncio.Future"
    route: "Route"

    def __init__(self, complete: "asyncio.Future", route: "Route") -> None:
        self.complete = complete
        self.route = route


class RouteHandler:
    """A route callback bound to a URL matcher, with an optional call limit."""

    def __init__(
        self,
        base_url: Optional[str],
        url: URLMatch,
        handler: RouteHandlerCallback,
        is_sync: bool,
        times: Optional[int] = None,
    ):
        self._base_url = base_url
        self.url = url
        self.handler = handler
        # A falsy `times` (None or 0) means "no limit".
        self._times = times if times else math.inf
        self._handled_count = 0
        self._is_sync = is_sync
        self._ignore_exception = False
        self._active_invocations: Set[RouteHandlerInvocation] = set()

    def matches(self, request_url: str) -> bool:
        """Return whether *request_url* matches this handler's URL matcher."""
        return url_matches(self._base_url, request_url, self.url)

    async def handle(self, route: "Route") -> bool:
        """Run the callback for *route* and return the result of its handling.

        Returns ``False`` when the callback raised after the handler was
        stopped with ``ignoreErrors``. Target-closed errors are re-raised with
        a hint about ``unroute_all``; other exceptions propagate unchanged.
        """
        handler_invocation = RouteHandlerInvocation(
            asyncio.get_running_loop().create_future(), route
        )
        self._active_invocations.add(handler_invocation)
        try:
            return await self._handle_internal(route)
        except Exception as e:
            # If the handler was stopped (without waiting for completion), we ignore all exceptions.
            if self._ignore_exception:
                return False
            if is_target_closed_error(e):
                # We are failing in the handler because the target has closed.
                # Give user a hint!
                optional_async_prefix = "await " if not self._is_sync else ""
                raise rewrite_error(
                    e,
                    f"\"{str(e)}\" while running route callback.\nConsider awaiting `{optional_async_prefix}page.unroute_all(behavior='ignoreErrors')`\nbefore the end of the test to ignore remaining routes in flight.",
                )
            raise
        finally:
            # Always signal completion so that `stop("wait")` cannot hang.
            handler_invocation.complete.set_result(None)
            self._active_invocations.remove(handler_invocation)

    async def _handle_internal(self, route: "Route") -> bool:
        """Invoke the callback (sync or async flavour) and await the route outcome."""
        handled_future = route._start_handling()

        self._handled_count += 1
        if self._is_sync:
            handler_finished_future = route._loop.create_future()

            def _handler() -> None:
                try:
                    self.handler(route, route.request)  # type: ignore
                    handler_finished_future.set_result(None)
                except Exception as e:
                    handler_finished_future.set_exception(e)

            # As with event handlers, each route handler is a potentially blocking context
            # so it needs a fiber.
            g = RouteGreenlet(_handler)
            g.switch()
            await handler_finished_future
        else:
            coro_or_future = self.handler(route, route.request)  # type: ignore
            if coro_or_future:
                # separate task so that we get a proper stack trace for exceptions / tracing api_name extraction
                await asyncio.ensure_future(coro_or_future)

        return await handled_future

    async def stop(self, behavior: Literal["ignoreErrors", "wait"]) -> None:
        """Stop the handler.

        With ``"ignoreErrors"`` exceptions from in-flight callbacks are
        swallowed from now on; any other value waits for the active
        invocations whose route has not thrown.
        """
        # When a handler is manually unrouted or its page/context is closed we either
        # - wait for the current handler invocations to finish
        # - or do not wait, if the user opted out of it, but swallow all exceptions
        #   that happen after the unroute/close.
        if behavior == "ignoreErrors":
            self._ignore_exception = True
        else:
            tasks = [
                activation.complete
                for activation in self._active_invocations
                if not activation.route._did_throw
            ]
            await asyncio.gather(*tasks)

    @property
    def will_expire(self) -> bool:
        """Whether handling one more request reaches the ``times`` limit."""
        return self._handled_count + 1 >= self._times

    @staticmethod
    def prepare_interception_patterns(
        handlers: List["RouteHandler"],
    ) -> List[Dict[str, str]]:
        """Translate the handlers' URL matchers into interception patterns.

        Strings become ``glob`` entries and compiled regexes become
        ``regexSource``/``regexFlags`` entries. If any handler uses another
        kind of matcher (a predicate), the single catch-all glob ``**/*`` is
        returned instead.
        """
        patterns: List[Dict[str, str]] = []
        for handler in handlers:
            if isinstance(handler.url, str):
                patterns.append({"glob": handler.url})
            elif isinstance(handler.url, re.Pattern):
                patterns.append(
                    {
                        "regexSource": handler.url.pattern,
                        "regexFlags": escape_regex_flags(handler.url),
                    }
                )
            else:
                # A predicate cannot be expressed as a pattern: intercept everything.
                return [{"glob": "**/*"}]
        return patterns


to_snake_case_regex = re.compile("((?<=[a-z0-9])[A-Z]|(?!^)[A-Z](?=[a-z]))")


def to_snake_case(name: str) -> str:
    """Convert a camelCase/PascalCase *name* to snake_case."""
    return to_snake_case_regex.sub(r"_\1", name).lower()


def make_dirs_for_file(path: Union[Path, str]) -> None:
    """Create the parent directory of *path* (relative paths use the cwd)."""
    if not os.path.isabs(path):
        path = Path.cwd() / path
    os.makedirs(os.path.dirname(path), exist_ok=True)


async def async_writefile(file: Union[str, Path], data: Union[str, bytes]) -> None:
    """Write *data* to *file* in the loop's default executor.

    ``str`` data is written in text mode, anything else in binary mode.
    """

    def inner() -> None:
        with open(file, "w" if isinstance(data, str) else "wb") as fh:
            fh.write(data)

    loop = asyncio.get_running_loop()
    await loop.run_in_executor(None, inner)


async def async_readfile(file: Union[str, Path]) -> bytes:
    """Read *file* as bytes in the loop's default executor."""

    def inner() -> bytes:
        with open(file, "rb") as fh:
            return fh.read()

    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, inner)


T = TypeVar("T")


def to_impl(obj: T) -> T:
    """Return ``obj._impl_obj`` when that attribute exists, else *obj* itself."""
    if hasattr(obj, "_impl_obj"):
        return cast(Any, obj)._impl_obj
    return obj


def object_to_array(obj: Optional[Dict]) -> Optional[List[NameValue]]:
    """Convert a mapping to a list of name/value entries with stringified values.

    Returns ``None`` for a falsy input (``None`` or an empty dict).
    """
    if not obj:
        return None
    return [NameValue(name=key, value=str(value)) for key, value in obj.items()]


def is_file_payload(value: Optional[Any]) -> bool:
    """Return whether *value* is a dict with ``name``, ``mimeType`` and ``buffer`` keys."""
    return (
        isinstance(value, dict)
        and "name" in value
        and "mimeType" in value
        and "buffer" in value
    )


TEXTUAL_MIME_TYPE = re.compile(
    r"^(text\/.*?|application\/(json|(x-)?javascript|xml.*?|ecmascript|graphql|x-www-form-urlencoded)|image\/svg(\+xml)?|application\/.*?(\+json|\+xml))(;\s*charset=.*)?$"
)


def is_textual_mime_type(mime_type: str) -> bool:
    """Return whether *mime_type* matches ``TEXTUAL_MIME_TYPE``."""
    return bool(TEXTUAL_MIME_TYPE.match(mime_type))