@ -12,6 +12,8 @@ from . import util
log = logging . getLogger ( __name__ )
UNIX_SOCKET_TIMEOUT = 0.1
def remove_file ( path , remove = os . remove , exists = os . path . exists ) :
try :
@ -44,19 +46,31 @@ def handle_connection(conn, handler):
util . send ( conn , reply )
except EOFError :
log . debug ( ' goodbye agent ' )
except :
log . exception ( ' error ' )
raise
def server_thread ( server , handler ) :
def retry ( func , exception_type , quit_event ) :
while True :
if quit_event . is_set ( ) :
raise StopIteration
try :
return func ( )
except exception_type :
pass
def server_thread ( server , handler , quit_event ) :
log . debug ( ' server thread started ' )
def accept_connection ( ) :
conn , _ = server . accept ( )
return conn
while True :
log . debug ( ' waiting for connection on %s ' , server . getsockname ( ) )
try :
conn , _ = server . accept ( )
except socket . error as e :
log . debug ( ' server stopped: %s ' , e )
conn = retry ( accept_connection , socket . timeout , quit_event )
except StopIteration :
log . debug ( ' server stopped ' )
break
with contextlib . closing ( conn ) :
handle_connection ( conn , handler )
@ -64,7 +78,7 @@ def server_thread(server, handler):
@contextlib.contextmanager
def spawn ( func , * * kwargs ) :
def spawn ( func , kwargs ) :
t = threading . Thread ( target = func , kwargs = kwargs )
t . start ( )
yield
@ -72,20 +86,23 @@ def spawn(func, **kwargs):
@contextlib.contextmanager
def serve ( public_keys , signer , sock_path = None ):
def serve ( public_keys , signer , sock_path = None , timeout = UNIX_SOCKET_TIMEOUT ):
if sock_path is None :
sock_path = tempfile . mktemp ( prefix = ' ssh-agent- ' )
keys = [ formats . import_public_key ( k ) for k in public_keys ]
environ = { ' SSH_AUTH_SOCK ' : sock_path , ' SSH_AGENT_PID ' : str ( os . getpid ( ) ) }
with unix_domain_socket_server ( sock_path ) as server :
server . settimeout ( timeout )
handler = protocol . Handler ( keys = keys , signer = signer )
with spawn ( server_thread , server = server , handler = handler ) :
quit_event = threading . Event ( )
kwargs = dict ( server = server , handler = handler , quit_event = quit_event )
with spawn ( server_thread , kwargs ) :
try :
yield environ
finally :
log . debug ( ' closing server ' )
server. shutdown ( socket . SHUT_RD )
quit_event. set ( )
def run_process ( command , environ , use_shell = False ) :