import MySQLdb
import os
import socket
import time
from functools import partial

from django.conf import settings as djsettings
from django.core.management.base import CommandError
from django.core.urlresolvers import reverse
from selenium.webdriver.common.action_chains import ActionChains
from selenium.webdriver.common.keys import Keys
from selenium.webdriver.support.select import Select

from orchestra.admin.utils import change_url
from orchestra.apps.accounts.models import Account
from orchestra.apps.orchestration.models import Server, Route
from orchestra.utils.system import sshrun
from orchestra.utils.tests import (BaseLiveServerTestCase, random_ascii,
    save_response_on_error, snapshot_on_error)

from ... import backends, settings
from ...models import Database, DatabaseUser


class DatabaseTestMixin(object):
    """
    Interface-agnostic functional tests for databases and database users.

    The ``test_*`` methods only describe scenarios. The operations they
    drive (``add``, ``delete``, ``add_user``, ...) and the checks they run
    (``validate_*``) are supplied by the other mixins of this module.
    """
    # Host of the server under test; overridable through the environment.
    MASTER_SERVER = os.environ.get('ORCHESTRA_SECOND_SERVER', 'localhost')
    DEPENDENCIES = (
        'orchestra.apps.orchestration',
        'orchestra.apps.databases',
    )

    def setUp(self):
        super(DatabaseTestMixin, self).setUp()
        self.add_route()
        # NOTE(review): DEBUG is switched on globally and is not restored
        # afterwards -- confirm this is intended.
        djsettings.DEBUG = True

    # Operations that concrete mixins are expected to override.

    def add_route(self):
        raise NotImplementedError

    def save(self):
        raise NotImplementedError

    def add(self, dbname, username, password):
        raise NotImplementedError

    def delete(self, dbname):
        raise NotImplementedError

    def delete_user(self, username):
        raise NotImplementedError

    def change_password(self, username, password):
        raise NotImplementedError

    def add_user(self, username, password):
        raise NotImplementedError

    def add_user_to_db(self, username, dbname):
        raise NotImplementedError

    def update(self):
        raise NotImplementedError

    def disable(self):
        raise NotImplementedError

    def add_group(self, username, groupname):
        raise NotImplementedError

    # Scenarios. Comments are used instead of docstrings so that the test
    # names reported by the runner stay unchanged.

    def test_add(self):
        # A freshly added database must be usable by its user.
        dbname = '%s_database' % random_ascii(5)
        username = '%s_dbuser' % random_ascii(5)
        password = '@!?%spppP001' % random_ascii(5)
        self.add(dbname, username, password)
        self.validate_create_table(dbname, username, password)

    def test_delete(self):
        # After deleting both database and user nothing must remain.
        dbname = '%s_database' % random_ascii(5)
        username = '%s_dbuser' % random_ascii(5)
        password = '@!?%spppP001' % random_ascii(5)
        self.add(dbname, username, password)
        self.validate_create_table(dbname, username, password)
        self.delete(dbname)
        self.delete_user(username)
        self.validate_delete(dbname, username, password)
        self.validate_delete_user(dbname, username)

    def test_change_password(self):
        # The old password must stop working and the new one must work.
        dbname = '%s_database' % random_ascii(5)
        username = '%s_dbuser' % random_ascii(5)
        password = '@!?%spppP001' % random_ascii(5)
        self.add(dbname, username, password)
        self.addCleanup(self.delete, dbname)
        self.addCleanup(self.delete_user, username)
        self.validate_create_table(dbname, username, password)
        new_password = '@!?%spppP001' % random_ascii(5)
        self.change_password(username, new_password)
        self.validate_login_error(dbname, username, password)
        self.validate_create_table(dbname, username, new_password)

    def test_add_user(self):
        # A second user gets access only once it is attached to the database.
        dbname = '%s_database' % random_ascii(5)
        username = '%s_dbuser' % random_ascii(5)
        password = '@!?%spppP001' % random_ascii(5)
        self.add(dbname, username, password)
        self.addCleanup(self.delete, dbname)
        self.addCleanup(self.delete_user, username)
        self.validate_create_table(dbname, username, password)
        username2 = '%s_dbuser' % random_ascii(5)
        password2 = '@!?%spppP001' % random_ascii(5)
        self.add_user(username2, password2)
        self.addCleanup(self.delete_user, username2)
        self.validate_login_error(dbname, username2, password2)
        self.add_user_to_db(username2, dbname)
        self.validate_create_table(dbname, username, password)
        self.validate_create_table(dbname, username2, password2)

    def test_delete_user(self):
        # Deleting one user must not affect the other user of the database.
        dbname = '%s_database' % random_ascii(5)
        username = '%s_dbuser' % random_ascii(5)
        password = '@!?%spppP001' % random_ascii(5)
        self.add(dbname, username, password)
        self.addCleanup(self.delete, dbname)
        self.validate_create_table(dbname, username, password)
        username2 = '%s_dbuser' % random_ascii(5)
        password2 = '@!?%spppP001' % random_ascii(5)
        self.add_user(username2, password2)
        self.add_user_to_db(username2, dbname)
        self.delete_user(username)
        self.validate_login_error(dbname, username, password)
        self.validate_create_table(dbname, username2, password2)
        self.delete_user(username2)
        self.validate_login_error(dbname, username2, password2)


class MySQLBackendMixin(object):
    """
    MySQL flavour of the routing setup and of the ``validate_*`` checks.

    Validation is done against the real server at ``MASTER_SERVER``: either
    by logging in with MySQLdb or by running ``mysql`` there via ``sshrun``.
    """
    db_type = 'mysql'

    def setUp(self):
        super(MySQLBackendMixin, self).setUp()
        # Get local ip address used to reach self.MASTER_SERVER
        s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        try:
            s.connect((self.MASTER_SERVER, 22))
            settings.DATABASES_DEFAULT_HOST = s.getsockname()[0]
        finally:
            # Release the socket even when the address lookup fails.
            s.close()

    def add_route(self):
        """Route database and database user operations to MASTER_SERVER."""
        server = Server.objects.create(name=self.MASTER_SERVER)
        backend = backends.MySQLBackend.get_name()
        match = "database.type == '%s'" % self.db_type
        Route.objects.create(backend=backend, match=match, host=server)
        match = "databaseuser.type == '%s'" % self.db_type
        backend = backends.MySQLUserBackend.get_name()
        Route.objects.create(backend=backend, match=match, host=server)

    def validate_create_table(self, name, username, password):
        """
        Log into database ``name`` as ``username`` and create a table.

        Exceptions raised by MySQLdb while connecting or executing are
        propagated; ``validate_login_error`` and ``validate_delete`` rely
        on that.
        """
        db = MySQLdb.connect(host=self.MASTER_SERVER, port=3306, user=username,
                passwd=password, db=name)
        try:
            cur = db.cursor()
            try:
                cur.execute('CREATE TABLE table_%s ( id INT ) ;' % random_ascii(10))
            finally:
                cur.close()
        finally:
            # Do not leave one open connection behind per validation.
            db.close()

    def validate_login_error(self, dbname, username, password):
        """Assert that using ``dbname`` with these credentials fails."""
        self.assertRaises(MySQLdb.OperationalError,
            self.validate_create_table, dbname, username, password
        )

    def validate_delete(self, name, username, password):
        """Assert that database ``name`` can no longer be used."""
        self.assertRaises(MySQLdb.OperationalError,
            self.validate_create_table, name, username, password
        )

    def validate_delete_user(self, name, username):
        """
        Assert that the ``db`` and ``user`` tables of the ``mysql`` schema
        on MASTER_SERVER have no rows left for ``name`` and ``username``.
        """
        context = {
            'name': name,
            'username': username,
        }
        self.assertEqual('', sshrun(self.MASTER_SERVER,
            """mysql mysql -e 'SELECT * FROM db WHERE db="%(name)s";'""" % context,
            display=False).stdout)
        self.assertEqual('', sshrun(self.MASTER_SERVER,
            """mysql mysql -e 'SELECT * FROM user WHERE user="%(username)s";'""" % context,
            display=False).stdout)


# TODO remove used from database
class RESTDatabaseMixin(DatabaseTestMixin):
    """Implements the database operations through ``self.rest``."""

    def setUp(self):
        super(RESTDatabaseMixin, self).setUp()
        self.rest_login()

    @save_response_on_error
    def add(self, dbname, username, password):
        # The user is created first so the database can reference it.
        user = self.rest.databaseusers.create(username=username,
                password=password, type=self.db_type)
        users = [{
            'username': user.username
        }]
        self.rest.databases.create(name=dbname, users=users, type=self.db_type)

    @save_response_on_error
    def delete(self, dbname):
        self.rest.databases.retrieve(name=dbname).delete()

    @save_response_on_error
    def change_password(self, username, password):
        user = self.rest.databaseusers.retrieve(username=username).get()
        user.set_password(password)

    @save_response_on_error
    def add_user(self, username, password):
        self.rest.databaseusers.create(username=username, password=password,
                type=self.db_type)

    @save_response_on_error
    def add_user_to_db(self, username, dbname):
        user = self.rest.databaseusers.retrieve(username=username).get()
        db = self.rest.databases.retrieve(name=dbname).get()
        db.users.append(user)
        db.save()

    @save_response_on_error
    def delete_user(self, username):
        self.rest.databaseusers.retrieve(username=username).delete()


class AdminDatabaseMixin(DatabaseTestMixin):
    """Implements the database operations through the admin site."""

    def setUp(self):
        super(AdminDatabaseMixin, self).setUp()
        self.admin_login()

    @snapshot_on_error
    def add(self, dbname, username, password):
        url = self.live_server_url + reverse('admin:databases_database_add')
        self.selenium.get(url)

        type_input = self.selenium.find_element_by_id('id_type')
        type_select = Select(type_input)
        type_select.select_by_value(self.db_type)

        name_field = self.selenium.find_element_by_id('id_name')
        name_field.send_keys(dbname)

        username_field = self.selenium.find_element_by_id('id_username')
        username_field.send_keys(username)

        password_field = self.selenium.find_element_by_id('id_password1')
        password_field.send_keys(password)
        password_field = self.selenium.find_element_by_id('id_password2')
        password_field.send_keys(password)

        name_field.submit()
        # Still being on the add form means the submission was not accepted.
        self.assertNotEqual(url, self.selenium.current_url)

    @snapshot_on_error
    def delete(self, dbname):
        db = Database.objects.get(name=dbname)
        self.admin_delete(db)

    @snapshot_on_error
    def change_password(self, username, password):
        user = DatabaseUser.objects.get(username=username)
        self.admin_change_password(user, password)

    @snapshot_on_error
    def add_user(self, username, password):
        url = self.live_server_url + reverse('admin:databases_databaseuser_add')
        self.selenium.get(url)

        type_input = self.selenium.find_element_by_id('id_type')
        type_select = Select(type_input)
        type_select.select_by_value(self.db_type)

        username_field = self.selenium.find_element_by_id('id_username')
        username_field.send_keys(username)

        password_field = self.selenium.find_element_by_id('id_password1')
        password_field.send_keys(password)
        password_field = self.selenium.find_element_by_id('id_password2')
        password_field.send_keys(password)

        username_field.submit()
        # Still being on the add form means the submission was not accepted.
        self.assertNotEqual(url, self.selenium.current_url)

    @snapshot_on_error
    def add_user_to_db(self, username, dbname):
        database = Database.objects.get(name=dbname, type=self.db_type)
        url = self.live_server_url + change_url(database)
        self.selenium.get(url)

        user = DatabaseUser.objects.get(username=username, type=self.db_type)
        users_input = self.selenium.find_element_by_id('id_users')
        users_select = Select(users_input)
        users_select.select_by_value(str(user.pk))

        save = self.selenium.find_element_by_name('_save')
        save.submit()
        # Still being on the change form means the submission was not accepted.
        self.assertNotEqual(url, self.selenium.current_url)

    @snapshot_on_error
    def delete_user(self, username):
        user = DatabaseUser.objects.get(username=username)
        self.admin_delete(user)


class RESTMysqlDatabaseTest(MySQLBackendMixin, RESTDatabaseMixin, BaseLiveServerTestCase):
    pass


class AdminMysqlDatabaseTest(MySQLBackendMixin, AdminDatabaseMixin, BaseLiveServerTestCase):
    pass