Abstract SSH setup to support GIT_SSH

In order to honor GIT_SSH the TransportGitSsh class needs to run the
process named by the GIT_SSH environment variable and use that as the
pipes for connectivity to the remote peer.  Refactor the current
transport code to support a different type of pipe connectivity, so we
can later add GIT_SSH.

Bug: 321062
Change-Id: I9d8ee1a95f1bac5013b33a4a42dcf1f98f92172f
Signed-off-by: Shawn O. Pearce <spearce@spearce.org>
This commit is contained in:
Shawn O. Pearce 2010-12-03 16:14:46 -08:00
parent 37001ddc8d
commit 04b289cc42
1 changed files with 130 additions and 93 deletions

View File

@ -97,12 +97,16 @@ static boolean canHandle(final URIish uri) {
@Override @Override
public FetchConnection openFetch() throws TransportException { public FetchConnection openFetch() throws TransportException {
return new SshFetchConnection(); return new SshFetchConnection(newConnection());
} }
@Override @Override
public PushConnection openPush() throws TransportException { public PushConnection openPush() throws TransportException {
return new SshPushConnection(); return new SshPushConnection(newConnection());
}
private Connection newConnection() {
return new JschConnection();
} }
private static void sqMinimal(final StringBuilder cmd, final String val) { private static void sqMinimal(final StringBuilder cmd, final String val) {
@ -128,7 +132,7 @@ private static void sq(final StringBuilder cmd, final String val) {
cmd.append(QuotedString.BOURNE.quote(val)); cmd.append(QuotedString.BOURNE.quote(val));
} }
private String commandFor(final String exe) { String commandFor(final String exe) {
String path = uri.getPath(); String path = uri.getPath();
if (uri.getScheme() != null && uri.getPath().startsWith("/~")) if (uri.getScheme() != null && uri.getPath().startsWith("/~"))
path = (uri.getPath().substring(1)); path = (uri.getPath().substring(1));
@ -146,28 +150,6 @@ private String commandFor(final String exe) {
return cmd.toString(); return cmd.toString();
} }
ChannelExec exec(final String exe) throws TransportException {
initSession();
try {
final ChannelExec channel = (ChannelExec) sock.openChannel("exec");
channel.setCommand(commandFor(exe));
return channel;
} catch (JSchException je) {
throw new TransportException(uri, je.getMessage(), je);
}
}
private void connect(ChannelExec channel) throws TransportException {
try {
channel.connect(getTimeout() > 0 ? getTimeout() * 1000 : 0);
if (!channel.isConnected())
throw new TransportException(uri, "connection failed");
} catch (JSchException e) {
throw new TransportException(uri, e.getMessage(), e);
}
}
void checkExecFailure(int status, String exe, String why) void checkExecFailure(int status, String exe, String why)
throws TransportException { throws TransportException {
if (status == 127) { if (status == 127) {
@ -198,60 +180,132 @@ NoRemoteRepositoryException cleanNotFound(NoRemoteRepositoryException nf,
return new NoRemoteRepositoryException(uri, why); return new NoRemoteRepositoryException(uri, why);
} }
// JSch won't let us interrupt writes when we use our InterruptTimer to private abstract class Connection {
// break out of a long-running write operation. To work around that we abstract void exec(String commandName) throws TransportException;
// spawn a background thread to shuttle data through a pipe, as we can
// issue an interrupted write out of that. Its slower, so we only use abstract void connect() throws TransportException;
// this route if there is a timeout.
abstract InputStream getInputStream() throws IOException;
abstract OutputStream getOutputStream() throws IOException;
abstract InputStream getErrorStream() throws IOException;
abstract int getExitStatus();
abstract void close();
}
private class JschConnection extends Connection {
private ChannelExec channel;
private int exitStatus;
@Override
void exec(String commandName) throws TransportException {
initSession();
try {
channel = (ChannelExec) sock.openChannel("exec");
channel.setCommand(commandFor(commandName));
} catch (JSchException je) {
throw new TransportException(uri, je.getMessage(), je);
}
}
@Override
void connect() throws TransportException {
try {
channel.connect(getTimeout() > 0 ? getTimeout() * 1000 : 0);
if (!channel.isConnected())
throw new TransportException(uri, "connection failed");
} catch (JSchException e) {
throw new TransportException(uri, e.getMessage(), e);
}
}
@Override
InputStream getInputStream() throws IOException {
return channel.getInputStream();
}
@Override
OutputStream getOutputStream() throws IOException {
// JSch won't let us interrupt writes when we use our InterruptTimer
// to break out of a long-running write operation. To work around
// that we spawn a background thread to shuttle data through a pipe,
// as we can issue an interrupted write out of that. Its slower, so
// we only use this route if there is a timeout.
// //
private OutputStream outputStream(ChannelExec channel) throws IOException {
final OutputStream out = channel.getOutputStream(); final OutputStream out = channel.getOutputStream();
if (getTimeout() <= 0) if (getTimeout() <= 0)
return out; return out;
final PipedInputStream pipeIn = new PipedInputStream(); final PipedInputStream pipeIn = new PipedInputStream();
final StreamCopyThread copyThread = new StreamCopyThread(pipeIn, out); final StreamCopyThread copier = new StreamCopyThread(pipeIn, out);
final PipedOutputStream pipeOut = new PipedOutputStream(pipeIn) { final PipedOutputStream pipeOut = new PipedOutputStream(pipeIn) {
@Override @Override
public void flush() throws IOException { public void flush() throws IOException {
super.flush(); super.flush();
copyThread.flush(); copier.flush();
} }
@Override @Override
public void close() throws IOException { public void close() throws IOException {
super.close(); super.close();
try { try {
copyThread.join(getTimeout() * 1000); copier.join(getTimeout() * 1000);
} catch (InterruptedException e) { } catch (InterruptedException e) {
// Just wake early, the thread will terminate anyway. // Just wake early, the thread will terminate anyway.
} }
} }
}; };
copyThread.start(); copier.start();
return pipeOut; return pipeOut;
} }
@Override
InputStream getErrorStream() throws IOException {
return channel.getErrStream();
}
@Override
int getExitStatus() {
return exitStatus;
}
@Override
void close() {
if (channel != null) {
try {
exitStatus = channel.getExitStatus();
if (channel.isConnected())
channel.disconnect();
} finally {
channel = null;
}
}
}
}
class SshFetchConnection extends BasePackFetchConnection { class SshFetchConnection extends BasePackFetchConnection {
private ChannelExec channel; private Connection conn;
private StreamCopyThread errorThread; private StreamCopyThread errorThread;
private int exitStatus; SshFetchConnection(Connection conn) throws TransportException {
SshFetchConnection() throws TransportException {
super(TransportGitSsh.this); super(TransportGitSsh.this);
this.conn = conn;
try { try {
final MessageWriter msg = new MessageWriter(); final MessageWriter msg = new MessageWriter();
setMessageWriter(msg); setMessageWriter(msg);
channel = exec(getOptionUploadPack()); conn.exec(getOptionUploadPack());
final InputStream upErr = channel.getErrStream(); final InputStream upErr = conn.getErrorStream();
errorThread = new StreamCopyThread(upErr, msg.getRawStream()); errorThread = new StreamCopyThread(upErr, msg.getRawStream());
errorThread.start(); errorThread.start();
init(channel.getInputStream(), outputStream(channel)); init(conn.getInputStream(), conn.getOutputStream());
connect(channel); conn.connect();
} catch (TransportException err) { } catch (TransportException err) {
close(); close();
@ -266,7 +320,8 @@ class SshFetchConnection extends BasePackFetchConnection {
readAdvertisedRefs(); readAdvertisedRefs();
} catch (NoRemoteRepositoryException notFound) { } catch (NoRemoteRepositoryException notFound) {
final String msgs = getMessages(); final String msgs = getMessages();
checkExecFailure(exitStatus, getOptionUploadPack(), msgs); checkExecFailure(conn.getExitStatus(), getOptionUploadPack(),
msgs);
throw cleanNotFound(notFound, msgs); throw cleanNotFound(notFound, msgs);
} }
} }
@ -286,40 +341,30 @@ public void close() {
} }
super.close(); super.close();
conn.close();
if (channel != null) {
try {
exitStatus = channel.getExitStatus();
if (channel.isConnected())
channel.disconnect();
} finally {
channel = null;
}
}
} }
} }
class SshPushConnection extends BasePackPushConnection { class SshPushConnection extends BasePackPushConnection {
private ChannelExec channel; private Connection conn;
private StreamCopyThread errorThread; private StreamCopyThread errorThread;
private int exitStatus; SshPushConnection(Connection conn) throws TransportException {
SshPushConnection() throws TransportException {
super(TransportGitSsh.this); super(TransportGitSsh.this);
this.conn = conn;
try { try {
final MessageWriter msg = new MessageWriter(); final MessageWriter msg = new MessageWriter();
setMessageWriter(msg); setMessageWriter(msg);
channel = exec(getOptionReceivePack()); conn.exec(getOptionReceivePack());
final InputStream rpErr = channel.getErrStream(); final InputStream rpErr = conn.getErrorStream();
errorThread = new StreamCopyThread(rpErr, msg.getRawStream()); errorThread = new StreamCopyThread(rpErr, msg.getRawStream());
errorThread.start(); errorThread.start();
init(channel.getInputStream(), outputStream(channel)); init(conn.getInputStream(), conn.getOutputStream());
connect(channel); conn.connect();
} catch (TransportException err) { } catch (TransportException err) {
close(); close();
@ -334,7 +379,8 @@ class SshPushConnection extends BasePackPushConnection {
readAdvertisedRefs(); readAdvertisedRefs();
} catch (NoRemoteRepositoryException notFound) { } catch (NoRemoteRepositoryException notFound) {
final String msgs = getMessages(); final String msgs = getMessages();
checkExecFailure(exitStatus, getOptionReceivePack(), msgs); checkExecFailure(conn.getExitStatus(), getOptionReceivePack(),
msgs);
throw cleanNotFound(notFound, msgs); throw cleanNotFound(notFound, msgs);
} }
} }
@ -354,16 +400,7 @@ public void close() {
} }
super.close(); super.close();
conn.close();
if (channel != null) {
try {
exitStatus = channel.getExitStatus();
if (channel.isConnected())
channel.disconnect();
} finally {
channel = null;
}
}
} }
} }
} }