diff --git a/README.md b/README.md index 218fa765..3df9a16f 100644 --- a/README.md +++ b/README.md @@ -326,6 +326,7 @@ Optionally, you can use the following command-line flags: |---------------------------------------|-------------| | `--api` | Enable the API extension. | | `--public-api` | Create a public URL for the API using Cloudfare. | +| `--public-api-id PUBLIC_API_ID` | Tunnel ID for named Cloudflare Tunnel. Use together with public-api option. | | `--api-blocking-port BLOCKING_PORT` | The listening port for the blocking API. | | `--api-streaming-port STREAMING_PORT` | The listening port for the streaming API. | diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml index 46b27580..ce29f33b 100644 --- a/docker/docker-compose.yml +++ b/docker/docker-compose.yml @@ -23,6 +23,7 @@ services: - ./prompts:/app/prompts - ./softprompts:/app/softprompts - ./training:/app/training + - ./cloudflared:/etc/cloudflared deploy: resources: reservations: diff --git a/extensions/api/blocking_api.py b/extensions/api/blocking_api.py index fbbc5ec1..6b28205a 100644 --- a/extensions/api/blocking_api.py +++ b/extensions/api/blocking_api.py @@ -200,7 +200,7 @@ class Handler(BaseHTTPRequestHandler): super().end_headers() -def _run_server(port: int, share: bool = False): +def _run_server(port: int, share: bool = False, tunnel_id: str = None): address = '0.0.0.0' if shared.args.listen else '127.0.0.1' server = ThreadingHTTPServer((address, port), Handler) @@ -210,7 +210,7 @@ def _run_server(port: int, share: bool = False): if share: try: - try_start_cloudflared(port, max_attempts=3, on_start=on_start) + try_start_cloudflared(port, tunnel_id, max_attempts=3, on_start=on_start) except Exception: pass else: @@ -220,5 +220,5 @@ def _run_server(port: int, share: bool = False): server.serve_forever() -def start_server(port: int, share: bool = False): - Thread(target=_run_server, args=[port, share], daemon=True).start() +def start_server(port: int, share: bool = False, tunnel_id: str = None): + 
Thread(target=_run_server, args=[port, share, tunnel_id], daemon=True).start() diff --git a/extensions/api/requirements.txt b/extensions/api/requirements.txt index 14e29d35..e4f26c3a 100644 --- a/extensions/api/requirements.txt +++ b/extensions/api/requirements.txt @@ -1,2 +1,2 @@ -flask_cloudflared==0.0.12 +flask_cloudflared==0.0.14 websockets==11.0.2 \ No newline at end of file diff --git a/extensions/api/script.py b/extensions/api/script.py index 5d1b1a68..80617b3e 100644 --- a/extensions/api/script.py +++ b/extensions/api/script.py @@ -4,5 +4,5 @@ from modules import shared def setup(): - blocking_api.start_server(shared.args.api_blocking_port, share=shared.args.public_api) - streaming_api.start_server(shared.args.api_streaming_port, share=shared.args.public_api) + blocking_api.start_server(shared.args.api_blocking_port, share=shared.args.public_api, tunnel_id=shared.args.public_api_id) + streaming_api.start_server(shared.args.api_streaming_port, share=shared.args.public_api, tunnel_id=shared.args.public_api_id) diff --git a/extensions/api/streaming_api.py b/extensions/api/streaming_api.py index 6afa827d..9175eeb0 100644 --- a/extensions/api/streaming_api.py +++ b/extensions/api/streaming_api.py @@ -102,7 +102,7 @@ async def _run(host: str, port: int): await asyncio.Future() # run forever -def _run_server(port: int, share: bool = False): +def _run_server(port: int, share: bool = False, tunnel_id: str = None): address = '0.0.0.0' if shared.args.listen else '127.0.0.1' def on_start(public_url: str): @@ -111,7 +111,7 @@ def _run_server(port: int, share: bool = False): if share: try: - try_start_cloudflared(port, max_attempts=3, on_start=on_start) + try_start_cloudflared(port, tunnel_id, max_attempts=3, on_start=on_start) except Exception as e: print(e) else: @@ -120,5 +120,5 @@ def _run_server(port: int, share: bool = False): asyncio.run(_run(host=address, port=port)) -def start_server(port: int, share: bool = False): - Thread(target=_run_server, args=[port, share], 
daemon=True).start() +def start_server(port: int, share: bool = False, tunnel_id: str = None): + Thread(target=_run_server, args=[port, share, tunnel_id], daemon=True).start() diff --git a/extensions/api/util.py b/extensions/api/util.py index f36c070b..7ebfaa32 100644 --- a/extensions/api/util.py +++ b/extensions/api/util.py @@ -86,12 +86,12 @@ def build_parameters(body, chat=False): return generate_params -def try_start_cloudflared(port: int, max_attempts: int = 3, on_start: Optional[Callable[[str], None]] = None): +def try_start_cloudflared(port: int, tunnel_id: str, max_attempts: int = 3, on_start: Optional[Callable[[str], None]] = None): Thread(target=_start_cloudflared, args=[ - port, max_attempts, on_start], daemon=True).start() + port, tunnel_id, max_attempts, on_start], daemon=True).start() -def _start_cloudflared(port: int, max_attempts: int = 3, on_start: Optional[Callable[[str], None]] = None): +def _start_cloudflared(port: int, tunnel_id: str, max_attempts: int = 3, on_start: Optional[Callable[[str], None]] = None): try: from flask_cloudflared import _run_cloudflared except ImportError: @@ -101,7 +101,7 @@ def _start_cloudflared(port: int, max_attempts: int = 3, on_start: Optional[Call for _ in range(max_attempts): try: - public_url = _run_cloudflared(port, port + 1) + public_url = _run_cloudflared(port, port + 1, tunnel_id=tunnel_id) if on_start: on_start(public_url) diff --git a/modules/shared.py b/modules/shared.py index 30f6512c..05c402c4 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -182,6 +182,7 @@ parser.add_argument('--api', action='store_true', help='Enable the API extension parser.add_argument('--api-blocking-port', type=int, default=5000, help='The listening port for the blocking API.') parser.add_argument('--api-streaming-port', type=int, default=5005, help='The listening port for the streaming API.') parser.add_argument('--public-api', action='store_true', help='Create a public URL for the API using Cloudfare.') 
+parser.add_argument('--public-api-id', type=str, default=None, help='Tunnel ID for named Cloudflare Tunnel. Use together with public-api option.') # Multimodal parser.add_argument('--multimodal-pipeline', type=str, default=None, help='The multimodal pipeline to use. Examples: llava-7b, llava-13b.')