from tornado.concurrent import Future
from tornado import gen
from tornado.httpclient import HTTPError
from tornado.log import gen_log
from tornado.testing import AsyncHTTPTestCase, gen_test, bind_unused_port, ExpectLog
from tornado.web import Application, RequestHandler
from tornado.websocket import WebSocketHandler, websocket_connect, WebSocketError


class EchoHandler(WebSocketHandler):
    """WebSocket handler that writes every incoming message back to the client.

    Messages received as ``bytes`` are written back with ``binary=True``;
    everything else is written back as text. When the connection closes,
    the future supplied through ``initialize`` is resolved with ``None`` so a
    test can wait until the server side has observed the close.
    """

    def initialize(self, close_future):
        # Future created by the test's get_app(); resolved in on_close().
        self.close_future = close_future

    def on_message(self, message):
        # Mirror the payload type: bytes go back as a binary message.
        send_as_binary = isinstance(message, bytes)
        self.write_message(message, binary=send_as_binary)

    def on_close(self):
        # Signal the waiting test that the server saw the connection close.
        self.close_future.set_result(None)


class NonWebSocketHandler(RequestHandler):
    """Plain HTTP handler that answers GET with a fixed body.

    Used to check that a WebSocket client connecting to an ordinary
    (non-upgrading) endpoint fails cleanly.
    """

    def get(self):
        body = 'ok'
        self.write(body)


class WebSocketTest(AsyncHTTPTestCase):
    """End-to-end tests for the Tornado WebSocket client and server."""

    def get_app(self):
        # A fresh future per test; EchoHandler resolves it when a
        # server-side connection closes.
        self.close_future = Future()
        return Application([
            ('/echo', EchoHandler, dict(close_future=self.close_future)),
            ('/non_ws', NonWebSocketHandler),
        ])

    @gen_test
    def test_websocket_gen(self):
        """A coroutine-style client can send a message and read the echo."""
        ws = yield websocket_connect(
            'ws://localhost:%d/echo' % self.get_http_port(),
            io_loop=self.io_loop)
        ws.write_message('hello')
        response = yield ws.read_message()
        self.assertEqual(response, 'hello')

    def test_websocket_callbacks(self):
        """The callback-style API delivers futures to the supplied callbacks."""
        websocket_connect(
            'ws://localhost:%d/echo' % self.get_http_port(),
            io_loop=self.io_loop, callback=self.stop)
        ws = self.wait().result()
        ws.write_message('hello')
        ws.read_message(self.stop)
        response = self.wait().result()
        self.assertEqual(response, 'hello')

    @gen_test
    def test_websocket_http_fail(self):
        """Connecting to a missing path surfaces the server's HTTP 404."""
        with self.assertRaises(HTTPError) as cm:
            yield websocket_connect(
                'ws://localhost:%d/notfound' % self.get_http_port(),
                io_loop=self.io_loop)
        self.assertEqual(cm.exception.code, 404)

    @gen_test
    def test_websocket_http_success(self):
        """A 200 response without a protocol upgrade raises WebSocketError."""
        with self.assertRaises(WebSocketError):
            yield websocket_connect(
                'ws://localhost:%d/non_ws' % self.get_http_port(),
                io_loop=self.io_loop)

    @gen_test
    def test_websocket_network_fail(self):
        """A connection failure is reported as HTTPError with code 599."""
        # Bind and immediately release a port so nothing is listening on it.
        sock, port = bind_unused_port()
        sock.close()
        with self.assertRaises(HTTPError) as cm:
            # The failed connect is logged; swallow that expected output.
            with ExpectLog(gen_log, ".*"):
                yield websocket_connect(
                    'ws://localhost:%d/' % port,
                    io_loop=self.io_loop,
                    connect_timeout=0.01)
        # 599 is the code Tornado's HTTP client uses for non-HTTP failures.
        self.assertEqual(cm.exception.code, 599)

    @gen_test
    def test_websocket_close_buffered_data(self):
        """Closing the client stream with unsent writes still closes the server side."""
        # Pass io_loop explicitly, like every other test here, instead of
        # relying on this test's loop being the implicit current loop.
        ws = yield websocket_connect(
            'ws://localhost:%d/echo' % self.get_http_port(),
            io_loop=self.io_loop)
        ws.write_message('hello')
        ws.write_message('world')
        ws.stream.close()
        # The test passes once EchoHandler.on_close resolves the future;
        # a regression would show up as a gen_test timeout here.
        yield self.close_future