/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Threading;
using Avro.Generic;
using Avro.IO;
using Avro.Specific;
using org.apache.avro.ipc;

namespace Avro.ipc
{
    /// <summary>
    /// Client side of an Avro RPC exchange: serializes requests, performs the Avro
    /// handshake through a <see cref="Transceiver"/>, and decodes responses and errors
    /// via the abstract hooks that subclasses implement.
    /// </summary>
    public abstract class Requestor
    {
        // Schema of call/handshake metadata: a map whose values are Avro "bytes".
        private static readonly Schema Meta = MapSchema.CreateMap(PrimitiveSchema.NewInstance("bytes"));

        private static readonly GenericReader<Dictionary<String, object>> MetaReader =
            new GenericReader<Dictionary<String, object>>(Meta, Meta);

        public static readonly GenericWriter<IDictionary<String, object>> MetaWriter =
            new GenericWriter<IDictionary<String, object>>(Meta);

        // Static caches shared by every Requestor instance:
        //   RemoteHashes:    transceiver remote name -> server protocol hash
        //   RemoteProtocols: server protocol hash    -> parsed server protocol
        // Each dictionary is guarded by its own lock object below.
        private static readonly Dictionary<String, MD5> RemoteHashes = new Dictionary<String, MD5>();
        private static readonly Dictionary<MD5, Protocol> RemoteProtocols = new Dictionary<MD5, Protocol>();

        private static readonly object remoteHashLock = new object();
        private static readonly object remoteProtocolsLock = new object();

        private static readonly SpecificWriter<HandshakeRequest> HandshakeWriter =
            new SpecificWriter<HandshakeRequest>(new HandshakeRequest().Schema);

        private static readonly SpecificReader<HandshakeResponse> HandshakeReader =
            new SpecificReader<HandshakeResponse>(new HandshakeResponse().Schema, new HandshakeResponse().Schema);

        protected readonly Transceiver transceiver;
        private Protocol localProtocol;
        private volatile Protocol remoteProtocol;
        private volatile bool sendLocalText;

        // Serializes requests issued while the transceiver is not yet connected.
        private readonly object handshakeLock = new object();

        // Thread that entered handshakeLock in Request<T>; only that thread releases it.
        private Thread handshakeThread;

        /// <summary>
        /// Creates a requestor that sends messages of <paramref name="protocol"/>
        /// through <paramref name="transceiver"/>.
        /// </summary>
        protected Requestor(Transceiver transceiver, Protocol protocol)
        {
            this.transceiver = transceiver;
            localProtocol = protocol;
        }

        /// <summary>The local (client-side) protocol used to look up and encode messages.</summary>
        public Protocol Local
        {
            get { return localProtocol; }
            protected set { localProtocol = value; }
        }

        /// <summary>The transceiver requests are sent through.</summary>
        public Transceiver Transceiver
        {
            get { return transceiver; }
        }

        /// <summary>
        /// Sends <paramref name="request"/> as message <paramref name="messageName"/> and
        /// blocks until the reply arrives.
        /// </summary>
        /// <returns>The decoded response, or null for a one-way message.</returns>
        public object Request(string messageName, object request)
        {
            transceiver.VerifyConnection();

            var rpcRequest = new RpcRequest(messageName, request, new RpcContext());

            // One-way messages have no reply, so there is nothing to wait on.
            CallFuture<Object> future = GetMessage(rpcRequest).Oneway.GetValueOrDefault()
                ? null
                : new CallFuture<Object>();

            Request(rpcRequest, future);
            return future == null ? null : future.WaitForResult();
        }

        /// <summary>
        /// Sends <paramref name="request"/> as message <paramref name="messageName"/>; the
        /// result or error is delivered to <paramref name="callback"/>, which may be null.
        /// </summary>
        public void Request<T>(String messageName, Object request, ICallback<T> callback)
        {
            var rpcRequest = new RpcRequest(messageName, request, new RpcContext());
            Request(rpcRequest, callback);
        }

        /// <summary>
        /// Core send path. While the transceiver is unconnected the first request is sent
        /// synchronously under <see cref="handshakeLock"/> so that only one thread performs the
        /// handshake; afterwards one-way messages are written directly and two-way messages
        /// are sent asynchronously with a <see cref="TransceiverCallback{T}"/>.
        /// </summary>
        private void Request<T>(RpcRequest request, ICallback<T> callback)
        {
            Transceiver t = transceiver;
            if (!t.IsConnected)
            {
                Monitor.Enter(handshakeLock);
                handshakeThread = Thread.CurrentThread;
                try
                {
                    // Re-check: another thread may have completed the handshake while we waited.
                    if (!t.IsConnected)
                    {
                        var callFuture = new CallFuture<T>(callback);
                        IList<MemoryStream> bytes = request.GetBytes(Local, this);
                        var transceiverCallback = new TransceiverCallback<T>(this, request, callFuture, Local);
                        t.Transceive(bytes, transceiverCallback);

                        // Block until the handshake (and this first call) completes.
                        callFuture.Wait();

                        Message message = GetMessage(request);
                        if (message.Oneway.GetValueOrDefault())
                        {
                            Exception error = callFuture.Error;
                            if (error != null)
                            {
                                throw error;
                            }
                        }
                        return;
                    }
                }
                finally
                {
                    if (Thread.CurrentThread == handshakeThread)
                    {
                        handshakeThread = null;
                        Monitor.Exit(handshakeLock);
                    }
                }
            }

            if (GetMessage(request).Oneway.GetValueOrDefault())
            {
                // One-way on an established connection: write the request; no reply is read.
                t.LockChannel();
                try
                {
                    IList<MemoryStream> bytes = request.GetBytes(Local, this);
                    t.WriteBuffers(bytes);
                    if (callback != null)
                    {
                        callback.HandleResult(default(T));
                    }
                }
                finally
                {
                    t.UnlockChannel();
                }
            }
            else
            {
                IList<MemoryStream> bytes = request.GetBytes(Local, this);
                var transceiverCallback = new TransceiverCallback<T>(this, request, callback, Local);
                t.Transceive(bytes, transceiverCallback);
            }
        }

        private Message GetMessage(RpcRequest request)
        {
            return request.GetMessage(Local);
        }

        /// <summary>Encodes the request parameters described by <paramref name="schema"/>; implemented by subclasses.</summary>
        public abstract void WriteRequest(RecordSchema schema, Object request, Encoder encoder);

        /// <summary>
        /// Decodes a successful response. <paramref name="writer"/> is the remote message's
        /// response schema, <paramref name="reader"/> the local one.
        /// </summary>
        public abstract object ReadResponse(Schema writer, Schema reader, Decoder decoder);

        /// <summary>
        /// Decodes an error response. <paramref name="writer"/> is the remote message's
        /// error union, <paramref name="reader"/> the local one.
        /// </summary>
        public abstract Exception ReadError(Schema writer, Schema reader, Decoder decoder);

        /// <summary>
        /// Writes a handshake request to <paramref name="outEncoder"/> unless the transceiver
        /// is already connected. Uses the cached server hash for this remote if one is known;
        /// otherwise guesses that the server speaks the local protocol.
        /// </summary>
        public void WriteHandshake(Encoder outEncoder)
        {
            if (transceiver.IsConnected) return;

            var localHash = new MD5 {Value = localProtocol.MD5};

            String remoteName = transceiver.RemoteName;
            MD5 remoteHash;
            lock (remoteHashLock)
            {
                if (!RemoteHashes.TryGetValue(remoteName, out remoteHash))
                {
                    // guess remote is local
                    remoteHash = localHash;
                    remoteProtocol = localProtocol;
                }
            }

            if (remoteProtocol == null)
            {
                lock (remoteProtocolsLock)
                {
                    // TryGetValue, not the indexer: setRemote publishes the hash and the protocol
                    // under separate locks, so the protocol may not be cached yet.
                    Protocol cached;
                    if (RemoteProtocols.TryGetValue(remoteHash, out cached))
                    {
                        remoteProtocol = cached;
                    }
                }
            }

            var handshake = new HandshakeRequest {clientHash = localHash, serverHash = remoteHash};
            if (sendLocalText)
            {
                handshake.clientProtocol = localProtocol.ToString();
            }

            var context = new RpcContext {HandshakeRequest = handshake};
            handshake.meta = context.RequestHandshakeMeta;

            HandshakeWriter.Write(handshake, outEncoder);
        }

        /// <summary>Parses the server protocol from <paramref name="handshake"/> and caches it and its hash.</summary>
        private void setRemote(HandshakeResponse handshake)
        {
            remoteProtocol = Protocol.Parse(handshake.serverProtocol);

            MD5 remoteHash = handshake.serverHash;
            lock (remoteHashLock)
            {
                RemoteHashes[transceiver.RemoteName] = remoteHash;
            }
            lock (remoteProtocolsLock)
            {
                RemoteProtocols[remoteHash] = remoteProtocol;
            }
        }

        /// <summary>
        /// Returns the remote protocol, taken from this instance, then from the static caches,
        /// and otherwise obtained by sending a handshake-only request.
        /// </summary>
        public Protocol GetRemote()
        {
            if (remoteProtocol != null) return remoteProtocol; // already have it

            lock (remoteHashLock)
            {
                MD5 remoteHash;
                if (RemoteHashes.TryGetValue(transceiver.RemoteName, out remoteHash))
                {
                    lock (remoteProtocolsLock)
                    {
                        // The hash may be cached before its protocol (see setRemote), so avoid the
                        // throwing indexer and fall through to a handshake when it is missing.
                        Protocol cached;
                        if (RemoteProtocols.TryGetValue(remoteHash, out cached))
                        {
                            remoteProtocol = cached;
                            if (cached != null) return cached; // already cached
                        }
                    }
                }
            }

            lock (handshakeLock)
            {
                // Force a handshake: a request with only the handshake, empty call metadata and an
                // empty message name. ReadHandshake records the remote protocol from the reply.
                var bbo = new ByteBufferOutputStream();
                Encoder outp = new BinaryEncoder(bbo);
                WriteHandshake(outp);
                outp.WriteInt(0); // empty metadata
                outp.WriteString(""); // bogus message name

                IList<MemoryStream> response = Transceiver.Transceive(bbo.GetBufferList());

                var bbi = new ByteBufferInputStream(response);
                var inp = new BinaryDecoder(bbi);
                ReadHandshake(inp);
                return remoteProtocol;
            }
        }

        /// <summary>
        /// Reads a handshake response (unless already connected) and updates the cached
        /// remote protocol and whether the local protocol text must be sent next time.
        /// </summary>
        /// <returns>True when the connection is established; false when the handshake must be resent.</returns>
        /// <exception cref="AvroRuntimeException">The response has an unknown match value.</exception>
        private bool ReadHandshake(BinaryDecoder input)
        {
            if (Transceiver.IsConnected) return true;

            bool established = false;

            HandshakeResponse handshake = HandshakeReader.Read(null, input);

            switch (handshake.match)
            {
                case HandshakeMatch.BOTH:
                    established = true;
                    sendLocalText = false;
                    break;
                case HandshakeMatch.CLIENT:
                    setRemote(handshake);
                    established = true;
                    sendLocalText = false;
                    break;
                case HandshakeMatch.NONE:
                    // Server did not recognise our hash: resend with the full local protocol text.
                    setRemote(handshake);
                    sendLocalText = true;
                    break;
                default:
                    throw new AvroRuntimeException("Unexpected match: " + handshake.match);
            }

            if (established)
                transceiver.Remote = remoteProtocol;
            return established;
        }

        /// <summary>Decodes the call-level part of a reply that follows the handshake.</summary>
        private class Response
        {
            private readonly Requestor requestor;
            private readonly RpcRequest request;
            private readonly BinaryDecoder input;

            public Response(Requestor requestor, RpcRequest request, BinaryDecoder input)
            {
                this.requestor = requestor;
                this.request = request;
                this.input = input;
            }

            /// <summary>
            /// Returns the decoded response, null for a one-way message on a connected
            /// transceiver, or throws the decoded remote error.
            /// </summary>
            public Object getResponse()
            {
                Message lm = request.GetMessage(requestor.Local);
                Message rm;
                if (!requestor.remoteProtocol.Messages.TryGetValue(lm.Name, out rm))
                    throw new AvroRuntimeException("Not a remote message: " + lm.Name);

                Transceiver t = requestor.Transceiver;
                if ((lm.Oneway.GetValueOrDefault() != rm.Oneway.GetValueOrDefault()) && t.IsConnected)
                    throw new AvroRuntimeException("Not both one-way messages: " + lm);

                if (lm.Oneway.GetValueOrDefault() && t.IsConnected)
                    return null; // one-way w/ handshake

                RpcContext context = request.Context;
                context.ResponseCallMeta = MetaReader.Read(null, input);

                if (!input.ReadBoolean())
                {
                    // no error
                    Object response = requestor.ReadResponse(rm.Response, lm.Response, input);
                    context.Response = response;
                    return response;
                }

                Exception error = requestor.ReadError(rm.SupportedErrors, lm.SupportedErrors, input);
                context.Error = error;
                throw error;
            }
        }

        /// <summary>
        /// Receives the raw reply buffers from the transceiver, completes (or resends) the
        /// handshake, decodes the response and forwards the outcome to the user callback.
        /// </summary>
        private class TransceiverCallback<T> : ICallback<IList<MemoryStream>>
        {
            private readonly Requestor requestor;
            private readonly RpcRequest request;
            private readonly ICallback<T> callback;
            private readonly Protocol local;

            public TransceiverCallback(Requestor requestor, RpcRequest request, ICallback<T> callback,
                                       Protocol local)
            {
                this.requestor = requestor;
                this.request = request;
                this.callback = callback;
                this.local = local;
            }

            public void HandleResult(IList<MemoryStream> result)
            {
                var bbi = new ByteBufferInputStream(result);
                var input = new BinaryDecoder(bbi);

                if (!requestor.ReadHandshake(input))
                {
                    // Resend the handshake and return
                    var handshake = new RpcRequest(request);
                    IList<MemoryStream> requestBytes = handshake.GetBytes(requestor.Local, requestor);
                    var transceiverCallback = new TransceiverCallback<T>(requestor, handshake, callback, local);

                    requestor.Transceiver.Transceive(requestBytes, transceiverCallback);
                    return;
                }

                // Read response; invoke callback
                var response = new Response(requestor, request, input);
                try
                {
                    Object responseObject;
                    try
                    {
                        responseObject = response.getResponse();
                    }
                    catch (Exception e)
                    {
                        if (callback != null)
                        {
                            callback.HandleException(e);
                        }
                        return;
                    }

                    if (callback != null)
                    {
                        callback.HandleResult((T) responseObject);
                    }
                }
                catch (Exception e)
                {
                    // A failing callback must not propagate back into whoever invoked this
                    // callback; report it instead of swallowing it silently.
                    Trace.TraceError("Error in callback handler: " + e);
                }
            }

            public void HandleException(Exception exception)
            {
                // The user callback is optional (see Request<T>), so guard like HandleResult does.
                if (callback != null)
                {
                    callback.HandleException(exception);
                }
            }
        }
    }
}