// Copyright 2014 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "extensions/browser/api/sockets_tcp/sockets_tcp_api.h"

#include <memory>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

#include "base/bind.h"
#include "content/public/browser/browser_context.h"
#include "content/public/browser/browser_thread.h"
#include "content/public/browser/storage_partition.h"
#include "content/public/common/socket_permission_request.h"
#include "extensions/browser/api/socket/tcp_socket.h"
#include "extensions/browser/api/socket/tls_socket.h"
#include "extensions/browser/api/sockets_tcp/tcp_socket_event_dispatcher.h"
#include "extensions/common/api/sockets/sockets_manifest_data.h"
#include "net/base/net_errors.h"

using extensions::ResumableTCPSocket;
using extensions::api::sockets_tcp::SocketInfo;
using extensions::api::sockets_tcp::SocketProperties;

namespace {

// Error strings returned to the calling extension by the functions below.
constexpr char kSocketNotFoundError[] = "Socket not found";
constexpr char kPermissionError[] = "Does not have permission";
constexpr char kInvalidSocketStateError[] =
    "Socket must be a connected client TCP socket.";
constexpr char kSocketNotConnectedError[] = "Socket not connected";

// Builds the SocketInfo reported to extensions for |socket|, which is
// registered under |socket_id|. Optional fields are left unset when the
// value is absent: the name when empty, the buffer size when not positive,
// and the local/peer endpoint when the socket cannot report it.
SocketInfo CreateSocketInfo(int socket_id, ResumableTCPSocket* socket) {
  SocketInfo info;

  // State tracked by the socket object itself; nothing in this section calls
  // through to the system.
  info.socket_id = socket_id;
  const auto& name = socket->name();
  if (!name.empty())
    info.name = std::make_unique<std::string>(name);
  info.persistent = socket->persistent();
  const auto buffer_size = socket->buffer_size();
  if (buffer_size > 0)
    info.buffer_size = std::make_unique<int>(buffer_size);
  info.paused = socket->paused();
  info.connected = socket->IsConnected();

  // Local endpoint as known by the OS.
  net::IPEndPoint local_endpoint;
  if (socket->GetLocalAddress(&local_endpoint)) {
    info.local_address =
        std::make_unique<std::string>(local_endpoint.ToStringWithoutPort());
    info.local_port = std::make_unique<int>(local_endpoint.port());
  }

  // Peer endpoint as known by the OS. The endpoint lookups keep succeeding
  // while the socket is connected, even if the peer has already closed its
  // side; only reading from the socket reveals that it should be closed
  // locally.
  net::IPEndPoint peer_endpoint;
  if (socket->GetPeerAddress(&peer_endpoint)) {
    info.peer_address =
        std::make_unique<std::string>(peer_endpoint.ToStringWithoutPort());
    info.peer_port = std::make_unique<int>(peer_endpoint.port());
  }

  return info;
}

// Copies every property that is present in |properties| onto |socket|.
// Properties the caller did not supply leave the socket's current value
// untouched.
void SetSocketProperties(ResumableTCPSocket* socket,
                         SocketProperties* properties) {
  if (properties->name)
    socket->set_name(*properties->name);
  if (properties->persistent)
    socket->set_persistent(*properties->persistent);
  // No validation here: the buffer size is checked when the actual Recv
  // operation is issued on the socket.
  if (properties->buffer_size)
    socket->set_buffer_size(*properties->buffer_size);
}

}  // namespace

namespace extensions {
namespace api {

using content::SocketPermissionRequest;

TCPSocketApiFunction::~TCPSocketApiFunction() = default;

std::unique_ptr<SocketResourceManagerInterface>
TCPSocketApiFunction::CreateSocketResourceManager() {
  return std::unique_ptr<SocketResourceManagerInterface>(
      new SocketResourceManager<ResumableTCPSocket>());
}

ResumableTCPSocket* TCPSocketApiFunction::GetTcpSocket(int socket_id) {
  return static_cast<ResumableTCPSocket*>(GetSocket(socket_id));
}

TCPSocketExtensionWithDnsLookupFunction::
    ~TCPSocketExtensionWithDnsLookupFunction() = default;

std::unique_ptr<SocketResourceManagerInterface>
TCPSocketExtensionWithDnsLookupFunction::CreateSocketResourceManager() {
  return std::unique_ptr<SocketResourceManagerInterface>(
      new SocketResourceManager<ResumableTCPSocket>());
}

ResumableTCPSocket* TCPSocketExtensionWithDnsLookupFunction::GetTcpSocket(
    int socket_id) {
  return static_cast<ResumableTCPSocket*>(GetSocket(socket_id));
}

SocketsTcpCreateFunction::SocketsTcpCreateFunction() = default;

SocketsTcpCreateFunction::~SocketsTcpCreateFunction() = default;

// Creates a new TCP socket for the calling extension, applies any initial
// properties from the request, and responds with the new socket's id.
ExtensionFunction::ResponseAction SocketsTcpCreateFunction::Work() {
  auto params = sockets_tcp::Create::Params::Create(*args_);
  EXTENSION_FUNCTION_VALIDATE(params);

  auto socket =
      std::make_unique<ResumableTCPSocket>(browser_context(), extension_id());
  if (sockets_tcp::SocketProperties* properties = params->properties.get())
    SetSocketProperties(socket.get(), properties);

  // AddSocket() receives the raw pointer and registers the socket; it is not
  // freed here.
  sockets_tcp::CreateInfo create_info;
  create_info.socket_id = AddSocket(socket.release());
  return RespondNow(
      ArgumentList(sockets_tcp::Create::Results::Create(create_info)));
}

SocketsTcpUpdateFunction::SocketsTcpUpdateFunction() = default;

SocketsTcpUpdateFunction::~SocketsTcpUpdateFunction() = default;

// Applies the properties supplied in the request to an existing socket.
ExtensionFunction::ResponseAction SocketsTcpUpdateFunction::Work() {
  auto params = sockets_tcp::Update::Params::Create(*args_);
  EXTENSION_FUNCTION_VALIDATE(params);

  if (ResumableTCPSocket* socket = GetTcpSocket(params->socket_id)) {
    SetSocketProperties(socket, &params->properties);
    return RespondNow(NoArguments());
  }
  return RespondNow(Error(kSocketNotFoundError));
}

SocketsTcpSetPausedFunction::SocketsTcpSetPausedFunction() = default;

SocketsTcpSetPausedFunction::~SocketsTcpSetPausedFunction() = default;

// Sets the socket's paused flag. When a connected socket goes from paused to
// unpaused, the event dispatcher is notified through OnSocketResume().
ExtensionFunction::ResponseAction SocketsTcpSetPausedFunction::Work() {
  auto params = api::sockets_tcp::SetPaused::Params::Create(*args_);
  EXTENSION_FUNCTION_VALIDATE(params);

  TCPSocketEventDispatcher* dispatcher =
      TCPSocketEventDispatcher::Get(browser_context());
  DCHECK(dispatcher)
      << "There is no socket event dispatcher. "
         "If this assertion is failing during a test, then it is likely that "
         "TestExtensionSystem is failing to provide an instance of "
         "TCPSocketEventDispatcher.";

  ResumableTCPSocket* socket = GetTcpSocket(params->socket_id);
  if (!socket)
    return RespondNow(Error(kSocketNotFoundError));

  const bool pause = params->paused;
  // Requesting the state the socket is already in is a no-op.
  if (socket->paused() == pause)
    return RespondNow(NoArguments());

  socket->set_paused(pause);
  if (socket->IsConnected() && !pause)
    dispatcher->OnSocketResume(extension_id(), params->socket_id);

  return RespondNow(NoArguments());
}

SocketsTcpSetKeepAliveFunction::SocketsTcpSetKeepAliveFunction() = default;

SocketsTcpSetKeepAliveFunction::~SocketsTcpSetKeepAliveFunction() = default;

// Forwards the keep-alive request to the socket; the outcome is reported
// asynchronously through OnCompleted().
ExtensionFunction::ResponseAction SocketsTcpSetKeepAliveFunction::Work() {
  auto params = api::sockets_tcp::SetKeepAlive::Params::Create(*args_);
  EXTENSION_FUNCTION_VALIDATE(params);

  ResumableTCPSocket* socket = GetTcpSocket(params->socket_id);
  if (!socket)
    return RespondNow(ErrorWithCode(net::ERR_FAILED, kSocketNotFoundError));

  // An omitted delay is forwarded as 0.
  int delay = 0;
  if (params->delay)
    delay = *params->delay;

  socket->SetKeepAlive(
      params->enable, delay,
      base::BindOnce(&SocketsTcpSetKeepAliveFunction::OnCompleted, this));
  return RespondLater();
}

// Responds with net::OK on success. Failure only carries a bool, so it is
// reported as a generic net::ERR_FAILED.
void SocketsTcpSetKeepAliveFunction::OnCompleted(bool success) {
  if (!success) {
    Respond(
        ErrorWithCode(net::ERR_FAILED, net::ErrorToString(net::ERR_FAILED)));
    return;
  }
  Respond(OneArgument(base::Value(net::OK)));
}

SocketsTcpSetNoDelayFunction::SocketsTcpSetNoDelayFunction() = default;

SocketsTcpSetNoDelayFunction::~SocketsTcpSetNoDelayFunction() = default;

// Forwards the request's no_delay flag to ResumableTCPSocket::SetNoDelay();
// the outcome is reported asynchronously through OnCompleted().
ExtensionFunction::ResponseAction SocketsTcpSetNoDelayFunction::Work() {
  auto params = api::sockets_tcp::SetNoDelay::Params::Create(*args_);
  EXTENSION_FUNCTION_VALIDATE(params);

  ResumableTCPSocket* socket = GetTcpSocket(params->socket_id);
  if (!socket)
    return RespondNow(ErrorWithCode(net::ERR_FAILED, kSocketNotFoundError));

  auto on_done =
      base::BindOnce(&SocketsTcpSetNoDelayFunction::OnCompleted, this);
  socket->SetNoDelay(params->no_delay, std::move(on_done));
  return RespondLater();
}

// Responds with net::OK on success. Failure only carries a bool, so it is
// reported as a generic net::ERR_FAILED.
void SocketsTcpSetNoDelayFunction::OnCompleted(bool success) {
  if (!success) {
    Respond(
        ErrorWithCode(net::ERR_FAILED, net::ErrorToString(net::ERR_FAILED)));
    return;
  }
  Respond(OneArgument(base::Value(net::OK)));
}

SocketsTcpConnectFunction::SocketsTcpConnectFunction() = default;

SocketsTcpConnectFunction::~SocketsTcpConnectFunction() = default;

// Validates the request and checks the manifest's socket permission for the
// target host/port, then starts an asynchronous DNS lookup. The connection
// itself is started from AfterDnsLookup() (presumably invoked by the base
// class once the lookup finishes).
ExtensionFunction::ResponseAction SocketsTcpConnectFunction::Work() {
  params_ = sockets_tcp::Connect::Params::Create(*args_);
  EXTENSION_FUNCTION_VALIDATE(params_.get());

  socket_event_dispatcher_ = TCPSocketEventDispatcher::Get(browser_context());
  DCHECK(socket_event_dispatcher_)
      << "There is no socket event dispatcher. "
         "If this assertion is failing during a test, then it is likely that "
         "TestExtensionSystem is failing to provide an instance of "
         "TCPSocketEventDispatcher.";

  ResumableTCPSocket* socket = GetTcpSocket(params_->socket_id);
  if (!socket) {
    return RespondNow(Error(kSocketNotFoundError));
  }

  content::SocketPermissionRequest param(SocketPermissionRequest::TCP_CONNECT,
                                         params_->peer_address,
                                         params_->peer_port);
  if (!SocketsManifestData::CheckRequest(extension(), param)) {
    return RespondNow(Error(kPermissionError));
  }

  // Record the target hostname only after the connect is permitted, so a
  // rejected request leaves the socket's existing state untouched.
  socket->set_hostname(params_->peer_address);

  StartDnsLookup(net::HostPortPair(params_->peer_address, params_->peer_port));
  return RespondLater();
}

// Continues a connect request once DNS resolution has finished. A failed
// lookup is reported exactly like a failed connect.
void SocketsTcpConnectFunction::AfterDnsLookup(int lookup_result) {
  if (lookup_result != net::OK) {
    OnCompleted(lookup_result);
    return;
  }
  StartConnect();
}

// Connects the socket to the resolved |addresses_|. The socket is fetched
// again by id because it may no longer exist at this point.
void SocketsTcpConnectFunction::StartConnect() {
  ResumableTCPSocket* socket = GetTcpSocket(params_->socket_id);
  if (socket == nullptr) {
    Respond(Error(kSocketNotFoundError));
    return;
  }

  socket->Connect(addresses_,
                  base::BindOnce(&SocketsTcpConnectFunction::OnCompleted, this));
}

// Final step of a connect request. On success, the event dispatcher is told
// about the new connection before the extension receives net::OK. Any other
// result is returned as an error carrying the net error code.
void SocketsTcpConnectFunction::OnCompleted(int net_result) {
  if (net_result != net::OK) {
    Respond(ErrorWithCode(net_result, net::ErrorToString(net_result)));
    return;
  }

  socket_event_dispatcher_->OnSocketConnect(extension_id(),
                                            params_->socket_id);
  Respond(OneArgument(base::Value(net_result)));
}

SocketsTcpDisconnectFunction::SocketsTcpDisconnectFunction() = default;

SocketsTcpDisconnectFunction::~SocketsTcpDisconnectFunction() = default;

// Disconnects the socket. Unlike close, the socket stays registered under
// its id afterwards.
ExtensionFunction::ResponseAction SocketsTcpDisconnectFunction::Work() {
  auto params = sockets_tcp::Disconnect::Params::Create(*args_);
  EXTENSION_FUNCTION_VALIDATE(params);

  if (ResumableTCPSocket* socket = GetTcpSocket(params->socket_id)) {
    socket->Disconnect(/*socket_destroying=*/false);
    return RespondNow(NoArguments());
  }
  return RespondNow(Error(kSocketNotFoundError));
}

SocketsTcpSendFunction::SocketsTcpSendFunction() = default;

SocketsTcpSendFunction::~SocketsTcpSendFunction() = default;

// Writes the request payload to the socket. The outcome is reported
// asynchronously through OnCompleted().
ExtensionFunction::ResponseAction SocketsTcpSendFunction::Work() {
  std::unique_ptr<sockets_tcp::Send::Params> params =
      sockets_tcp::Send::Params::Create(*args_);
  EXTENSION_FUNCTION_VALIDATE(params.get());

  // Resolve the socket before copying the payload, so a request for an
  // unknown socket id does not pay for a buffer allocation and copy.
  ResumableTCPSocket* socket = GetTcpSocket(params->socket_id);
  if (!socket) {
    return RespondNow(Error(kSocketNotFoundError));
  }

  const size_t io_buffer_size = params->data.size();
  scoped_refptr<net::IOBuffer> io_buffer =
      base::MakeRefCounted<net::IOBuffer>(io_buffer_size);
  base::ranges::copy(params->data, io_buffer->data());

  socket->Write(io_buffer, io_buffer_size,
                base::BindOnce(&SocketsTcpSendFunction::OnCompleted, this));
  return RespondLater();
}

// Write completion. A non-negative result is the number of bytes written;
// a negative one is a net error code.
void SocketsTcpSendFunction::OnCompleted(int net_result) {
  const bool succeeded = net_result >= net::OK;
  SetSendResult(succeeded ? net::OK : net_result, succeeded ? net_result : -1);
}

// Responds with a SendInfo carrying |net_result|. |bytes_sent| is included
// only when |net_result| is net::OK. Errors still attach the SendInfo as
// arguments alongside the error string.
void SocketsTcpSendFunction::SetSendResult(int net_result, int bytes_sent) {
  CHECK(net_result <= net::OK) << "Network status code must be <= net::OK";

  const bool ok = net_result == net::OK;

  sockets_tcp::SendInfo send_info;
  send_info.result_code = net_result;
  if (ok)
    send_info.bytes_sent = std::make_unique<int>(bytes_sent);

  auto results = sockets_tcp::Send::Results::Create(send_info);
  if (!ok) {
    Respond(ErrorWithArguments(std::move(results),
                               net::ErrorToString(net_result)));
    return;
  }
  Respond(ArgumentList(std::move(results)));
}

SocketsTcpCloseFunction::SocketsTcpCloseFunction() = default;

SocketsTcpCloseFunction::~SocketsTcpCloseFunction() = default;

// Unregisters the socket via RemoveSocket(). Unknown ids are reported as
// "Socket not found".
ExtensionFunction::ResponseAction SocketsTcpCloseFunction::Work() {
  auto params = sockets_tcp::Close::Params::Create(*args_);
  EXTENSION_FUNCTION_VALIDATE(params);

  if (!GetTcpSocket(params->socket_id))
    return RespondNow(Error(kSocketNotFoundError));

  RemoveSocket(params->socket_id);
  return RespondNow(NoArguments());
}

SocketsTcpGetInfoFunction::SocketsTcpGetInfoFunction() = default;

SocketsTcpGetInfoFunction::~SocketsTcpGetInfoFunction() = default;

// Responds with the SocketInfo of a single socket.
ExtensionFunction::ResponseAction SocketsTcpGetInfoFunction::Work() {
  auto params = sockets_tcp::GetInfo::Params::Create(*args_);
  EXTENSION_FUNCTION_VALIDATE(params);

  ResumableTCPSocket* socket = GetTcpSocket(params->socket_id);
  if (!socket)
    return RespondNow(Error(kSocketNotFoundError));

  const sockets_tcp::SocketInfo info =
      CreateSocketInfo(params->socket_id, socket);
  return RespondNow(ArgumentList(sockets_tcp::GetInfo::Results::Create(info)));
}

SocketsTcpGetSocketsFunction::SocketsTcpGetSocketsFunction() = default;

SocketsTcpGetSocketsFunction::~SocketsTcpGetSocketsFunction() = default;

// Responds with a SocketInfo for every socket id known to this function that
// still resolves to a socket. A null id set yields an empty list.
ExtensionFunction::ResponseAction SocketsTcpGetSocketsFunction::Work() {
  std::vector<sockets_tcp::SocketInfo> infos;
  if (std::unordered_set<int>* ids = GetSocketIds()) {
    for (const int id : *ids) {
      if (ResumableTCPSocket* socket = GetTcpSocket(id))
        infos.push_back(CreateSocketInfo(id, socket));
    }
  }
  return RespondNow(
      ArgumentList(sockets_tcp::GetSockets::Results::Create(infos)));
}

SocketsTcpSecureFunction::SocketsTcpSecureFunction() = default;

SocketsTcpSecureFunction::~SocketsTcpSecureFunction() = default;

// Starts upgrading a connected TCP client socket to TLS. The socket's paused
// and persistent flags are captured up front so TlsConnectDone() can apply
// them to the replacement TLS socket.
ExtensionFunction::ResponseAction SocketsTcpSecureFunction::Work() {
  DCHECK_CURRENTLY_ON(content::BrowserThread::UI);
  params_ = api::sockets_tcp::Secure::Params::Create(*args_);
  EXTENSION_FUNCTION_VALIDATE(params_);

  ResumableTCPSocket* socket = GetTcpSocket(params_->socket_id);
  if (!socket) {
    return RespondNow(
        ErrorWithCode(net::ERR_INVALID_ARGUMENT, kSocketNotFoundError));
  }

  paused_ = socket->paused();
  persistent_ = socket->persistent();

  // Only a plain TCP socket may be upgraded; a socket that has already been
  // secured is no longer TYPE_TCP and is rejected here.
  if (socket->GetSocketType() != Socket::TYPE_TCP) {
    return RespondNow(
        ErrorWithCode(net::ERR_INVALID_ARGUMENT, kInvalidSocketStateError));
  }
  if (!socket->IsConnected()) {
    return RespondNow(
        ErrorWithCode(net::ERR_INVALID_ARGUMENT, kSocketNotConnectedError));
  }

  // UpgradeToTLS() takes the older chrome.socket API's SecureOptions. The
  // only values it holds are the optional TLSVersionConstraints bounds |min|
  // and |max|, so translate just those.
  api::socket::SecureOptions legacy_params;
  const auto* options = params_->options.get();
  if (options && options->tls_version) {
    const auto& requested = *options->tls_version;
    legacy_params.tls_version =
        std::make_unique<api::socket::TLSVersionConstraints>();
    if (requested.min) {
      legacy_params.tls_version->min =
          std::make_unique<std::string>(*requested.min);
    }
    if (requested.max) {
      legacy_params.tls_version->max =
          std::make_unique<std::string>(*requested.max);
    }
  }

  socket->UpgradeToTLS(
      &legacy_params,
      base::BindOnce(&SocketsTcpSecureFunction::TlsConnectDone, this));
  return RespondLater();
}

void SocketsTcpSecureFunction::TlsConnectDone(
    int result,
    mojo::PendingRemote<network::mojom::TLSClientSocket> tls_socket,
    const net::IPEndPoint& local_addr,
    const net::IPEndPoint& peer_addr,
    mojo::ScopedDataPipeConsumerHandle receive_pipe_handle,
    mojo::ScopedDataPipeProducerHandle send_pipe_handle) {
  if (result != net::OK) {
    RemoveSocket(params_->socket_id);
    Respond(ErrorWithCode(result, net::ErrorToString(result)));
    return;
  }
  auto socket =
      std::make_unique<TLSSocket>(std::move(tls_socket), local_addr, peer_addr,
                                  std::move(receive_pipe_handle),
                                  std::move(send_pipe_handle), extension_id());
  socket->set_persistent(persistent_);
  socket->set_paused(paused_);
  ReplaceSocket(params_->socket_id, socket.release());
  Respond(OneArgument(base::Value(result)));
}

}  // namespace api
}  // namespace extensions
